diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4fdcb582..c6a47e9a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,10 +1,11 @@ +--- name: CI Pipeline -on: +on: # yamllint disable-line rule:truthy push: - branches: [ main, develop ] + branches: [main, develop] pull_request: - branches: [ main, develop ] + branches: [main, develop] env: PYTHON_VERSION: '3.9' @@ -16,7 +17,7 @@ jobs: backend: name: Backend CI runs-on: ubuntu-latest - + services: postgres: image: postgres:15-alpine @@ -31,7 +32,7 @@ jobs: --health-retries 5 ports: - 5432:5432 - + redis: image: redis:7-alpine options: >- @@ -60,21 +61,21 @@ jobs: ${{ runner.os }}-pip- - name: Install dependencies - working-directory: ./backend run: | python -m pip install --upgrade pip pip install -r requirements.txt - pip install pytest pytest-cov pytest-asyncio black flake8 mypy bandit safety + pip install pytest pytest-cov pytest-asyncio black flake8 mypy \ + bandit safety - name: Run linting working-directory: ./backend run: | echo "Running Black formatter check..." black --check app/ - + echo "Running Flake8 linter..." flake8 app/ --max-line-length=120 --extend-ignore=E203,W503 - + echo "Running type checking with mypy..." mypy app/ --ignore-missing-imports || true @@ -83,14 +84,15 @@ jobs: run: | echo "Running Bandit security linter..." bandit -r app/ -f json -o bandit-report.json || true - + echo "Checking dependencies for vulnerabilities..." safety check --json || true - name: Run tests working-directory: ./backend env: - DATABASE_URL: postgresql://openwatch:openwatch_test@localhost:5432/openwatch_test + DATABASE_URL: > + postgresql://openwatch:openwatch_test@localhost:5432/openwatch_test REDIS_URL: redis://localhost:6379 JWT_SECRET_KEY: test_secret_key_for_ci ENVIRONMENT: test @@ -98,7 +100,8 @@ jobs: # Check if tests directory exists if [ -d "tests" ] && [ "$(find tests -name '*.py' | head -1)" ]; then echo "Running pytest tests..." - pytest tests/ -v --cov=app --cov-report=xml --cov-report=html || echo "Some tests failed but continuing..." + pytest tests/ -v --cov=app --cov-report=xml \ + --cov-report=html || echo "Some tests failed but continuing..." else echo "No test files found in tests/ directory, skipping pytest" echo "This is normal for early development stages" @@ -113,13 +116,14 @@ jobs: - name: Build Docker image run: | - docker build -f docker/Dockerfile.backend -t openwatch-backend:${{ github.sha }} . + docker build -f docker/Dockerfile.backend \ + -t openwatch-backend:${{ github.sha }} . # Frontend Testing and Building frontend: name: Frontend CI runs-on: ubuntu-latest - + steps: - name: Checkout code uses: actions/checkout@v4 @@ -146,7 +150,7 @@ jobs: run: | echo "Running ESLint..." npm run lint || true - + echo "Running TypeScript type check..." npx tsc --noEmit @@ -168,14 +172,15 @@ jobs: - name: Build Docker image run: | - docker build -f docker/Dockerfile.frontend -t openwatch-frontend:${{ github.sha }} . + docker build -f docker/Dockerfile.frontend \ + -t openwatch-frontend:${{ github.sha }} . # Integration Tests integration: name: Integration Tests runs-on: ubuntu-latest needs: [backend, frontend] - + steps: - name: Checkout code uses: actions/checkout@v4 @@ -190,28 +195,29 @@ jobs: # Start services with docker compose echo "Starting services with docker-compose..." docker compose up -d - + # Wait for services to be ready echo "Waiting for services to be ready..." sleep 45 - + # Check if health check script exists and use it if [ -f "./scripts/production-health-check.sh" ]; then echo "Running health checks..." - ./scripts/production-health-check.sh --local || echo "Health checks completed with warnings" + ./scripts/production-health-check.sh --local || \ + echo "Health checks completed with warnings" else echo "Running basic connectivity tests..." # Check basic connectivity with retries for i in {1..3}; do if curl -f --max-time 10 http://localhost:3001 >/dev/null 2>&1; then - echo "Frontend connectivity: OK" + echo "Frontend OK" break else echo "Frontend connectivity attempt $i failed, retrying..." sleep 10 fi done - + for i in {1..3}; do if curl -f --max-time 10 http://localhost:8000/health >/dev/null 2>&1; then echo "Backend connectivity: OK" @@ -222,15 +228,15 @@ jobs: fi done fi - + # Show service status echo "Service status:" docker compose ps - + # Show logs for debugging echo "Recent logs:" docker compose logs --tail=20 - + # Clean up docker compose down -v @@ -239,7 +245,7 @@ jobs: name: E2E Tests runs-on: ubuntu-latest needs: [backend, frontend] - + services: postgres: image: postgres:15-alpine @@ -254,7 +260,7 @@ jobs: --health-retries 5 ports: - 5432:5432 - + redis: image: redis:7-alpine options: >- @@ -264,7 +270,7 @@ jobs: --health-retries 5 ports: - 6379:6379 - + steps: - name: Checkout code uses: actions/checkout@v4 @@ -288,7 +294,6 @@ jobs: python-version: ${{ env.PYTHON_VERSION }} - name: Install backend dependencies - working-directory: ./backend run: | python -m pip install --upgrade pip pip install -r requirements.txt @@ -296,14 +301,15 @@ jobs: - name: Start backend service working-directory: ./backend env: - DATABASE_URL: postgresql://openwatch:openwatch_test@localhost:5432/openwatch_test + DATABASE_URL: > + postgresql://openwatch:openwatch_test@localhost:5432/openwatch_test REDIS_URL: redis://localhost:6379 JWT_SECRET_KEY: test_secret_key_for_e2e ENVIRONMENT: test run: | # Start backend in background python -m uvicorn app.main:app --host 0.0.0.0 --port 8000 & - + # Wait for backend to start timeout 60 bash -c 'until curl -f http://localhost:8000/health; do sleep 2; done' @@ -326,10 +332,10 @@ jobs: run: | # Start frontend in background npm run dev & - + # Wait for frontend to be ready timeout 60 bash -c 'until curl -f http://localhost:3001; do sleep 2; done' - + # Run E2E tests npx playwright test --reporter=html,junit @@ -356,7 +362,7 @@ jobs: runs-on: ubuntu-latest needs: [backend, frontend, integration, e2e] if: github.ref == 'refs/heads/main' && github.event_name == 'push' - + steps: - name: Checkout code uses: actions/checkout@v4 @@ -393,4 +399,4 @@ jobs: ghcr.io/${{ github.repository_owner }}/openwatch-frontend:latest ghcr.io/${{ github.repository_owner }}/openwatch-frontend:${{ github.sha }} cache-from: type=gha - cache-to: type=gha,mode=max \ No newline at end of file + cache-to: type=gha,mode=max diff --git a/.github/workflows/code-quality.yml b/.github/workflows/code-quality.yml index d942e733..427dee04 100644 --- a/.github/workflows/code-quality.yml +++ b/.github/workflows/code-quality.yml @@ -24,7 +24,7 @@ jobs: run: | python -m pip install --upgrade pip pip install black flake8 pylint mypy bandit vulture radon - cd backend && pip install -r requirements.txt + pip install -r requirements.txt - name: Black formatter run: | @@ -183,7 +183,7 @@ jobs: - name: Generate Python coverage working-directory: ./backend run: | - pip install -r requirements.txt + pip install -r ../requirements.txt pip install pytest pytest-cov pytest tests/ --cov=app --cov-report=xml --cov-report=term || true diff --git a/.github/workflows/test-ci-fixes.yml b/.github/workflows/test-ci-fixes.yml new file mode 100644 index 00000000..d2344402 --- /dev/null +++ b/.github/workflows/test-ci-fixes.yml @@ -0,0 +1,40 @@ +name: Test CI Fixes + +on: + workflow_dispatch: + +jobs: + test-backend-deps: + name: Test Backend Dependencies + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.9' + + - name: Test requirements.txt location + run: | + echo "Checking if requirements.txt exists at root..." + if [ -f requirements.txt ]; then + echo "✓ requirements.txt found at root" + else + echo "✗ requirements.txt not found at root" + exit 1 + fi + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + echo "✓ Dependencies installed successfully" + + - name: Test backend directory + run: | + echo "Checking backend directory structure..." + ls -la backend/ + echo "✓ Backend directory verified" \ No newline at end of file diff --git a/.gitignore b/.gitignore index 2d9e552e..ce46474e 100644 --- a/.gitignore +++ b/.gitignore @@ -562,7 +562,7 @@ nginx.conf # Gitignore analysis files gitignore_*.md -.github/ +# .github/ # Needed for GitHub Actions # ====================================== # RPM PACKAGE BUILDING FILES diff --git a/CI_PIPELINE_FIXES.md b/CI_PIPELINE_FIXES.md new file mode 100644 index 00000000..64c91fa3 --- /dev/null +++ b/CI_PIPELINE_FIXES.md @@ -0,0 +1,88 @@ +# CI Pipeline Fixes Report + +## Issues Identified + +### 1. Backend Dependencies Path Issue +**Problem**: The CI workflows were looking for `backend/requirements.txt`, but the file is located at the repository root. +**Status**: Fixed in `ci.yml` and `code-quality.yml` + +### 2. Missing Repository Secrets +**Problem**: SonarCloud analysis fails due to missing `SONAR_TOKEN` +**Action Required**: Add the following secrets to the repository: +- `SONAR_TOKEN` - Get from https://sonarcloud.io/account/security + +### 3. Code Quality Failures +**Problem**: Multiple code quality checks are failing +- Prettier formatting issues in frontend +- Python linting may fail on backend code + +**Recommended Actions**: +1. Run `npm run lint:fix` in frontend directory locally +2. Run `black backend/app/` to format Python code +3. Ensure all code passes linting before committing + +### 4. Container Security Scans +**Problem**: Trivy and Grype scans fail during Docker build +**Root Cause**: Dependencies installation failures cascade to Docker build failures +**Status**: Should be resolved once backend dependencies path is fixed + +### 5. Documentation Generation +**Problem**: Documentation jobs are failing +**Action Required**: Review documentation generation scripts and ensure dependencies are available + +## Changes Made + +1. Updated `ci.yml`: + - Fixed backend dependencies installation path (lines 62-66) + - Fixed E2E test backend dependencies path (lines 289-292) + +2. Updated `code-quality.yml`: + - Fixed Python dependencies installation path (lines 23-27) + +3. Created `test-ci-fixes.yml`: + - Added a manual workflow to test the fixes + +## Next Steps + +1. **Add Repository Secrets**: + ```bash + gh secret set SONAR_TOKEN --body "your-sonar-token" + ``` + +2. **Fix Code Quality Issues Locally**: + ```bash + # Frontend + cd frontend + npm run lint:fix + + # Backend + cd ../ + pip install black + black backend/app/ + ``` + +3. **Test the Fixes**: + - Commit these changes + - Create a PR to test the CI pipeline + - Run the test workflow: `gh workflow run test-ci-fixes.yml` + +4. **Monitor Results**: + - Check if backend dependencies install correctly + - Verify Docker builds succeed + - Ensure all security scans complete + +## Additional Recommendations + +1. Consider adding a `backend/requirements.txt` that references the root `requirements.txt`: + ``` + # backend/requirements.txt + -r ../requirements.txt + ``` + +2. Add workflow status badges to README.md to monitor CI health + +3. Set up branch protection rules to require CI passes before merging + +4. Configure Dependabot for automated dependency updates + +5. Add a pre-commit hook to run linting locally before commits \ No newline at end of file diff --git a/backend/app/audit_db.py b/backend/app/audit_db.py index 0a9aed6c..fcb4bfd5 100644 --- a/backend/app/audit_db.py +++ b/backend/app/audit_db.py @@ -2,6 +2,7 @@ Database Audit Logging Module Provides functions to write audit events directly to the database """ + from sqlalchemy.orm import Session from sqlalchemy import text from datetime import datetime @@ -11,7 +12,7 @@ logger = logging.getLogger(__name__) -async def log_audit_event( +def log_audit_event( db: Session, action: str, resource_type: str, @@ -19,11 +20,11 @@ async def log_audit_event( user_id: Optional[int] = None, ip_address: str = "0.0.0.0", user_agent: Optional[str] = None, - details: Optional[str] = None + details: Optional[str] = None, ) -> bool: """ Log an audit event to the database - + Args: db: Database session action: Action performed (e.g., LOGIN_SUCCESS, SCAN_CREATED) @@ -33,12 +34,13 @@ async def log_audit_event( ip_address: IP address of the client user_agent: User agent string (optional) details: Additional details about the event (optional) - + Returns: bool: True if successful, False otherwise """ try: - query = text(""" + query = text( + """ INSERT INTO audit_logs ( user_id, action, resource_type, resource_id, ip_address, user_agent, details, timestamp @@ -46,157 +48,165 @@ async def log_audit_event( :user_id, :action, :resource_type, :resource_id, :ip_address, :user_agent, :details, :timestamp ) - """) - - db.execute(query, { - "user_id": user_id, - "action": action, - "resource_type": resource_type, - "resource_id": resource_id, - "ip_address": ip_address, - "user_agent": user_agent, - "details": details, - "timestamp": datetime.utcnow() - }) - + """ + ) + + db.execute( + query, + { + "user_id": user_id, + "action": action, + "resource_type": resource_type, + "resource_id": resource_id, + "ip_address": ip_address, + "user_agent": user_agent, + "details": details, + "timestamp": datetime.utcnow(), + }, + ) + db.commit() return True - + except Exception as e: logger.error(f"Failed to log audit event: {e}") db.rollback() return False -async def log_login_event( +def log_login_event( db: Session, username: str, user_id: Optional[int], success: bool, ip_address: str, user_agent: Optional[str] = None, - failure_reason: Optional[str] = None + failure_reason: Optional[str] = None, ) -> bool: """Log login attempt to database""" action = "LOGIN_SUCCESS" if success else "LOGIN_FAILED" - details = f"User {username} logged in successfully" if success else f"Failed login attempt for {username}" + details = ( + f"User {username} logged in successfully" + if success + else f"Failed login attempt for {username}" + ) if failure_reason and not success: details += f" - Reason: {failure_reason}" - - return await log_audit_event( + + return log_audit_event( db=db, action=action, resource_type="auth", user_id=user_id if success else None, ip_address=ip_address, user_agent=user_agent, - details=details + details=details, ) -async def log_scan_event( +def log_scan_event( db: Session, action: str, scan_id: Optional[str], user_id: int, ip_address: str, host_name: Optional[str] = None, - details: Optional[str] = None + details: Optional[str] = None, ) -> bool: """Log scan-related events to database""" scan_details = details or f"Scan operation: {action}" if host_name: scan_details += f" on host {host_name}" - - return await log_audit_event( + + return log_audit_event( db=db, action=f"SCAN_{action.upper()}", resource_type="scan", resource_id=scan_id, user_id=user_id, ip_address=ip_address, - details=scan_details + details=scan_details, ) -async def log_host_event( +def log_host_event( db: Session, action: str, host_id: Optional[str], host_name: str, user_id: int, ip_address: str, - details: Optional[str] = None + details: Optional[str] = None, ) -> bool: """Log host-related events to database""" host_details = details or f"{action.title()} host: {host_name}" - - return await log_audit_event( + + return log_audit_event( db=db, action=f"HOST_{action.upper()}", resource_type="host", resource_id=host_id, user_id=user_id, ip_address=ip_address, - details=host_details + details=host_details, ) -async def log_user_event( +def log_user_event( db: Session, action: str, target_user_id: Optional[str], target_username: str, user_id: int, ip_address: str, - details: Optional[str] = None + details: Optional[str] = None, ) -> bool: """Log user management events to database""" user_details = details or f"{action.title()} user: {target_username}" - - return await log_audit_event( + + return log_audit_event( db=db, action=f"USER_{action.upper()}", resource_type="user", resource_id=target_user_id, user_id=user_id, ip_address=ip_address, - details=user_details + details=user_details, ) -async def log_security_event( +def log_security_event( db: Session, event_type: str, ip_address: str, user_id: Optional[int] = None, - details: Optional[str] = None + details: Optional[str] = None, ) -> bool: """Log security-related events to database""" - return await log_audit_event( + return log_audit_event( db=db, action=f"SECURITY_{event_type.upper()}", resource_type="security", user_id=user_id, ip_address=ip_address, - details=details + details=details, ) -async def log_admin_event( +def log_admin_event( db: Session, action: str, user_id: int, ip_address: str, resource_type: str = "system", - details: Optional[str] = None + details: Optional[str] = None, ) -> bool: """Log administrative actions to database""" - return await log_audit_event( + return log_audit_event( db=db, action=f"ADMIN_{action.upper()}", resource_type=resource_type, user_id=user_id, ip_address=ip_address, - details=details - ) \ No newline at end of file + details=details, + ) diff --git a/backend/app/auth.py b/backend/app/auth.py index 97607ac7..5b1b86e3 100644 --- a/backend/app/auth.py +++ b/backend/app/auth.py @@ -2,6 +2,7 @@ FIPS-compliant authentication and authorization system for OpenWatch Uses RSA-PSS signatures and secure password hashing """ + import os import jwt from datetime import datetime, timedelta @@ -29,7 +30,7 @@ argon2__time_cost=3, argon2__parallelism=1, argon2__hash_len=32, - argon2__salt_len=16 + argon2__salt_len=16, ) security = HTTPBearer() @@ -37,12 +38,12 @@ class FIPSJWTManager: """FIPS-compliant JWT token management using RSA-PSS""" - + def __init__(self): self.private_key = None self.public_key = None self._load_or_generate_keys() - + def _load_or_generate_keys(self): """Load existing RSA keys or generate new FIPS-compliant ones""" # Use relative paths for development, absolute for production @@ -53,168 +54,160 @@ def _load_or_generate_keys(self): else: private_key_path = "/app/security/keys/jwt_private.pem" public_key_path = "/app/security/keys/jwt_public.pem" - + try: # Try to load existing keys if os.path.exists(private_key_path) and os.path.exists(public_key_path): with open(private_key_path, "rb") as f: self.private_key = serialization.load_pem_private_key( - f.read(), - password=None, - backend=default_backend() + f.read(), password=None, backend=default_backend() ) - + with open(public_key_path, "rb") as f: self.public_key = serialization.load_pem_public_key( - f.read(), - backend=default_backend() + f.read(), backend=default_backend() ) logger.info("Loaded existing RSA keys for JWT signing") else: # Generate new FIPS-compliant RSA keys self._generate_keys(private_key_path, public_key_path) - + except Exception as e: logger.error(f"Error loading RSA keys: {e}") self._generate_keys(private_key_path, public_key_path) - + def _generate_keys(self, private_path: str, public_path: str): """Generate FIPS-compliant RSA-2048 key pair""" try: # Generate RSA-2048 key pair (FIPS approved) self.private_key = rsa.generate_private_key( - public_exponent=65537, - key_size=2048, - backend=default_backend() + public_exponent=65537, key_size=2048, backend=default_backend() ) self.public_key = self.private_key.public_key() - + # Ensure directory exists os.makedirs(os.path.dirname(private_path), exist_ok=True) - + # Save private key with open(private_path, "wb") as f: - f.write(self.private_key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption() - )) - + f.write( + self.private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + ) + # Save public key with open(public_path, "wb") as f: - f.write(self.public_key.public_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo - )) - + f.write( + self.public_key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + ) + # Set secure permissions os.chmod(private_path, 0o600) os.chmod(public_path, 0o644) - + logger.info("Generated new FIPS-compliant RSA keys for JWT signing") - + except Exception as e: logger.error(f"Failed to generate RSA keys: {e}") raise - - def create_access_token(self, data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str: + + def create_access_token( + self, data: Dict[str, Any], expires_delta: Optional[timedelta] = None + ) -> str: """Create JWT access token with RSA-PSS signature""" to_encode = data.copy() - + if expires_delta: expire = datetime.utcnow() + expires_delta else: expire = datetime.utcnow() + timedelta(minutes=settings.access_token_expire_minutes) - - to_encode.update({ - "exp": expire, - "iat": datetime.utcnow(), - "jti": secrets.token_urlsafe(32) # JWT ID for revocation - }) - + + to_encode.update( + { + "exp": expire, + "iat": datetime.utcnow(), + "jti": secrets.token_urlsafe(32), # JWT ID for revocation + } + ) + try: # Use RS256 with FIPS-compliant RSA-PSS padding - token = jwt.encode( - to_encode, - self.private_key, - algorithm="RS256" - ) + token = jwt.encode(to_encode, self.private_key, algorithm="RS256") return token except Exception as e: logger.error(f"Failed to create JWT token: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Could not create access token" + detail="Could not create access token", ) - - def create_refresh_token(self, data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str: + + def create_refresh_token( + self, data: Dict[str, Any], expires_delta: Optional[timedelta] = None + ) -> str: """Create JWT refresh token with longer expiration""" to_encode = data.copy() - + if expires_delta: expire = datetime.utcnow() + expires_delta else: expire = datetime.utcnow() + timedelta(days=settings.refresh_token_expire_days) - - to_encode.update({ - "exp": expire, - "iat": datetime.utcnow(), - "jti": secrets.token_urlsafe(32), - "type": "refresh" # Token type identifier - }) - + + to_encode.update( + { + "exp": expire, + "iat": datetime.utcnow(), + "jti": secrets.token_urlsafe(32), + "type": "refresh", # Token type identifier + } + ) + try: - token = jwt.encode( - to_encode, - self.private_key, - algorithm="RS256" - ) + token = jwt.encode(to_encode, self.private_key, algorithm="RS256") return token except Exception as e: logger.error(f"Failed to create refresh token: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Could not create refresh token" + detail="Could not create refresh token", ) - + def verify_token(self, token: str) -> Dict[str, Any]: """Verify JWT token with RSA-PSS signature""" try: - payload = jwt.decode( - token, - self.public_key, - algorithms=["RS256"] - ) + payload = jwt.decode(token, self.public_key, algorithms=["RS256"]) return payload except jwt.ExpiredSignatureError: raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Token has expired" + status_code=status.HTTP_401_UNAUTHORIZED, detail="Token has expired" ) except jwt.InvalidTokenError as e: logger.warning(f"Invalid token: {e}") raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Could not validate credentials" + status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials" ) - + def validate_access_token(self, token: str) -> Dict[str, Any]: """Validate access token specifically""" payload = self.verify_token(token) if payload.get("type") == "refresh": raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="Refresh token cannot be used for API access" + detail="Refresh token cannot be used for API access", ) return payload - + def validate_refresh_token(self, token: str) -> Dict[str, Any]: """Validate refresh token specifically""" payload = self.verify_token(token) if payload.get("type") != "refresh": raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid refresh token" + status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token" ) return payload @@ -225,17 +218,17 @@ def validate_refresh_token(self, token: str) -> Dict[str, Any]: class PasswordManager: """FIPS-compliant password management""" - + @staticmethod def hash_password(password: str) -> str: """Hash password using Argon2id (FIPS approved)""" return pwd_context.hash(password) - + @staticmethod def verify_password(plain_password: str, hashed_password: str) -> bool: """Verify password against hash""" return pwd_context.verify(plain_password, hashed_password) - + @staticmethod def generate_secure_password(length: int = 16) -> str: """Generate cryptographically secure password""" @@ -244,38 +237,38 @@ def generate_secure_password(length: int = 16) -> str: class SecurityAuditLogger: """Security event audit logging""" - + def __init__(self): self.audit_logger = logging.getLogger("openwatch.audit") handler = logging.FileHandler(settings.audit_log_file) - formatter = logging.Formatter( - '%(asctime)s - %(levelname)s - %(message)s' - ) + formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") handler.setFormatter(formatter) self.audit_logger.addHandler(handler) self.audit_logger.setLevel(logging.INFO) - + def log_login_attempt(self, username: str, success: bool, ip_address: str): """Log authentication attempts""" status = "SUCCESS" if success else "FAILED" - self.audit_logger.info( - f"LOGIN_{status} - User: {username}, IP: {ip_address}" - ) - + self.audit_logger.info(f"LOGIN_{status} - User: {username}, IP: {ip_address}") + def log_scan_action(self, username: str, action: str, target: str, ip_address: str): """Log scan-related actions""" self.audit_logger.info( f"SCAN_{action} - User: {username}, Target: {target}, IP: {ip_address}" ) - + def log_security_event(self, event_type: str, details: str, ip_address: str): """Log security events""" - self.audit_logger.warning( - f"SECURITY_{event_type} - Details: {details}, IP: {ip_address}" - ) - - async def log_api_key_action(self, user_id: str, action: str, api_key_id: str, - api_key_name: str, details: Optional[Dict] = None): + self.audit_logger.warning(f"SECURITY_{event_type} - Details: {details}, IP: {ip_address}") + + def log_api_key_action( + self, + user_id: str, + action: str, + api_key_id: str, + api_key_name: str, + details: Optional[Dict] = None, + ): """Log API key related actions""" self.audit_logger.info( f"API_KEY_{action} - User: {user_id}, Key: {api_key_name} ({api_key_id}), " @@ -286,15 +279,17 @@ async def log_api_key_action(self, user_id: str, action: str, api_key_id: str, audit_logger = SecurityAuditLogger() -async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)) -> Dict[str, Any]: +def get_current_user( + credentials: HTTPAuthorizationCredentials = Depends(security), +) -> Dict[str, Any]: """Get current authenticated user from JWT token or API key""" from sqlalchemy.orm import Session from .database import get_db, ApiKey import hashlib - + try: token = credentials.credentials - + # For development, allow demo token if token == "demo-token": return { @@ -302,9 +297,9 @@ async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(s "id": 1, "username": "admin", "email": "admin@openwatch.local", - "role": UserRole.SUPER_ADMIN.value + "role": UserRole.SUPER_ADMIN.value, } - + # Check if it's an API key (starts with "owk_") if token.startswith("owk_"): # Get database session @@ -312,30 +307,29 @@ async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(s try: # Hash the API key key_hash = hashlib.sha256(token.encode()).hexdigest() - + # Find the API key in database - api_key = db.query(ApiKey).filter( - ApiKey.key_hash == key_hash, - ApiKey.is_active == True - ).first() - + api_key = ( + db.query(ApiKey) + .filter(ApiKey.key_hash == key_hash, ApiKey.is_active == True) + .first() + ) + if not api_key: raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid API key" + status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key" ) - + # Check expiration if api_key.expires_at and api_key.expires_at < datetime.utcnow(): raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="API key expired" + status_code=status.HTTP_401_UNAUTHORIZED, detail="API key expired" ) - + # Update last used timestamp api_key.last_used_at = datetime.utcnow() db.commit() - + # Return API key info as user context return { "sub": f"api_key_{api_key.id}", @@ -345,18 +339,17 @@ async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(s "role": "api_key", "permissions": api_key.permissions, "api_key_id": str(api_key.id), - "api_key_name": api_key.name + "api_key_name": api_key.name, } finally: db.close() - + # Otherwise, it's a JWT token payload = jwt_manager.verify_token(token) username = payload.get("sub") if username is None: raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Could not validate credentials" + status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials" ) return payload except HTTPException: @@ -364,8 +357,7 @@ async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(s except Exception as e: logger.error(f"Authentication error: {e}") raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Could not validate credentials" + status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials" ) @@ -382,34 +374,28 @@ def decode_token(token: str) -> Optional[Dict[str, Any]]: "id": 1, "username": "admin", "email": "admin@openwatch.local", - "role": UserRole.SUPER_ADMIN.value + "role": UserRole.SUPER_ADMIN.value, } - + # Handle API keys if token.startswith("owk_"): # For middleware, we don't want to update database # Just return basic API key info - return { - "sub": "api_key", - "role": "api_key", - "username": "API Key", - "api_key": True - } - + return {"sub": "api_key", "role": "api_key", "username": "API Key", "api_key": True} + # Decode JWT token payload = jwt_manager.verify_token(token) return payload - + except Exception as e: logger.debug(f"Token decode failed: {e}") return None -async def require_admin(current_user: Dict[str, Any] = Depends(get_current_user)) -> Dict[str, Any]: +def require_admin(current_user: Dict[str, Any] = Depends(get_current_user)) -> Dict[str, Any]: """Require admin role for protected endpoints""" if current_user.get("role") != UserRole.SUPER_ADMIN.value: raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Admin privileges required" + status_code=status.HTTP_403_FORBIDDEN, detail="Admin privileges required" ) - return current_user \ No newline at end of file + return current_user diff --git a/backend/app/celery_app.py b/backend/app/celery_app.py index 976edd63..b86075ad 100644 --- a/backend/app/celery_app.py +++ b/backend/app/celery_app.py @@ -2,6 +2,7 @@ FIPS-compliant Celery configuration for secure task processing Redis with TLS and encrypted message passing """ + import os import ssl import logging @@ -16,27 +17,29 @@ logger = logging.getLogger(__name__) settings = get_settings() + # FIPS-compliant SSL context for Redis def create_redis_ssl_context(): """Create FIPS-compliant SSL context for Redis connections""" context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) - + # FIPS-approved settings context.minimum_version = ssl.TLSVersion.TLSv1_2 context.maximum_version = ssl.TLSVersion.TLSv1_3 - + # FIPS-approved cipher suites context.set_ciphers("ECDHE+AESGCM:ECDHE+CHACHA20:DHE+AESGCM:DHE+CHACHA20:!aNULL:!MD5:!DSS") - + # Certificate verification if settings.redis_ssl_ca: context.load_verify_locations(settings.redis_ssl_ca) - + if settings.redis_ssl_cert and settings.redis_ssl_key: context.load_cert_chain(settings.redis_ssl_cert, settings.redis_ssl_key) - + return context + # Redis connection configuration redis_ssl_context = create_redis_ssl_context() if settings.redis_ssl else None @@ -47,91 +50,85 @@ def create_redis_ssl_context(): # Create Celery app with FIPS-compliant configuration celery_app = Celery( - "openwatch", - broker=broker_url, - backend=broker_url, - include=[] # No tasks module for now + "openwatch", broker=broker_url, backend=broker_url, include=[] # No tasks module for now ) # FIPS-compliant Celery configuration celery_app.conf.update( # Security settings (Note: ssl_ciphers not supported by redis-py) - broker_use_ssl={ - "ssl_cert_reqs": ssl.CERT_REQUIRED, - "ssl_ca_certs": settings.redis_ssl_ca, - "ssl_certfile": settings.redis_ssl_cert, - "ssl_keyfile": settings.redis_ssl_key - } if settings.redis_ssl else None, - - redis_backend_use_ssl={ - "ssl_cert_reqs": ssl.CERT_REQUIRED, - "ssl_ca_certs": settings.redis_ssl_ca, - "ssl_certfile": settings.redis_ssl_cert, - "ssl_keyfile": settings.redis_ssl_key - } if settings.redis_ssl else None, - + broker_use_ssl=( + { + "ssl_cert_reqs": ssl.CERT_REQUIRED, + "ssl_ca_certs": settings.redis_ssl_ca, + "ssl_certfile": settings.redis_ssl_cert, + "ssl_keyfile": settings.redis_ssl_key, + } + if settings.redis_ssl + else None + ), + redis_backend_use_ssl=( + { + "ssl_cert_reqs": ssl.CERT_REQUIRED, + "ssl_ca_certs": settings.redis_ssl_ca, + "ssl_certfile": settings.redis_ssl_cert, + "ssl_keyfile": settings.redis_ssl_key, + } + if settings.redis_ssl + else None + ), # Task settings task_serializer="json", accept_content=["json"], result_serializer="json", timezone="UTC", enable_utc=True, - # Security and reliability task_reject_on_worker_lost=True, task_acks_late=True, worker_prefetch_multiplier=1, - # Task routing task_routes={ "backend.app.tasks.scan_host": {"queue": "scans"}, "backend.app.tasks.process_scan_result": {"queue": "results"}, - "backend.app.tasks.cleanup_old_files": {"queue": "maintenance"} + "backend.app.tasks.cleanup_old_files": {"queue": "maintenance"}, }, - # Queue configuration task_default_queue="default", task_queues=[ Queue("default", routing_key="default"), Queue("scans", routing_key="scans"), Queue("results", routing_key="results"), - Queue("maintenance", routing_key="maintenance") + Queue("maintenance", routing_key="maintenance"), ], - # Result backend settings result_expires=3600, # 1 hour - result_backend_transport_options={ - "retry_policy": { - "timeout": 5.0 - } - }, - + result_backend_transport_options={"retry_policy": {"timeout": 5.0}}, # Worker settings worker_max_tasks_per_child=1000, worker_disable_rate_limits=False, worker_send_task_events=True, task_send_sent_event=True, - # Security: Disable pickle serialization task_always_eager=False, - task_eager_propagates=True if settings.debug else False + task_eager_propagates=True if settings.debug else False, ) class SecureCeleryManager: """Secure Celery task management with audit logging""" - + def __init__(self): self.app = celery_app - - def submit_scan_task(self, scan_id: int, host_data: dict, content_data: dict, - profile_id: str, user_id: int) -> str: + + def submit_scan_task( + self, scan_id: int, host_data: dict, content_data: dict, profile_id: str, user_id: int + ) -> str: """Submit scan task with security validation""" try: # Validate inputs if not all([scan_id, host_data, content_data, profile_id, user_id]): raise ValueError("Missing required parameters for scan task") - + # Submit task task = self.app.send_task( "backend.app.tasks.scan_host", @@ -142,17 +139,17 @@ def submit_scan_task(self, scan_id: int, host_data: dict, content_data: dict, "max_retries": 3, "interval_start": 0, "interval_step": 0.2, - "interval_max": 0.2 - } + "interval_max": 0.2, + }, ) - + logger.info(f"Submitted scan task {task.id} for scan {scan_id}") return task.id - + except Exception as e: logger.error(f"Failed to submit scan task: {e}") raise - + def get_task_status(self, task_id: str) -> dict: """Get task status with security checks""" try: @@ -161,12 +158,12 @@ def get_task_status(self, task_id: str) -> dict: "task_id": task_id, "status": result.status, "result": result.result if result.ready() else None, - "traceback": result.traceback if result.failed() else None + "traceback": result.traceback if result.failed() else None, } except Exception as e: logger.error(f"Failed to get task status: {e}") return {"task_id": task_id, "status": "UNKNOWN", "error": str(e)} - + def revoke_task(self, task_id: str, terminate: bool = True) -> bool: """Revoke task with audit logging""" try: @@ -182,13 +179,14 @@ def revoke_task(self, task_id: str, terminate: bool = True) -> bool: celery_manager = SecureCeleryManager() -async def check_redis_health() -> bool: +def check_redis_health() -> bool: """Check Redis connectivity for health checks""" try: # Parse Redis URL import urllib.parse + parsed = urllib.parse.urlparse(settings.redis_url) - + # Create Redis connection redis_client = redis.Redis( host=parsed.hostname, @@ -200,14 +198,14 @@ async def check_redis_health() -> bool: ssl_certfile=settings.redis_ssl_cert if settings.redis_ssl else None, ssl_keyfile=settings.redis_ssl_key if settings.redis_ssl else None, socket_timeout=5, - socket_connect_timeout=5 + socket_connect_timeout=5, ) - + # Test connection redis_client.ping() redis_client.close() return True - + except Exception as e: logger.error(f"Redis health check failed: {e}") return False @@ -217,11 +215,12 @@ async def check_redis_health() -> bool: def worker_ready_handler(sender=None, **kwargs): """Handle worker ready signal""" logger.info(f"Celery worker ready: {sender}") - + # Log FIPS mode status if settings.fips_mode: try: from security.config.fips_config import FIPSConfig + fips_enabled = FIPSConfig.validate_fips_mode() logger.info(f"FIPS mode enabled: {fips_enabled}") except ImportError: @@ -235,4 +234,4 @@ def worker_shutdown_handler(sender=None, **kwargs): # Export Celery app for worker startup -__all__ = ["celery_app", "celery_manager", "check_redis_health"] \ No newline at end of file +__all__ = ["celery_app", "celery_manager", "check_redis_health"] diff --git a/backend/app/cli_interface.py b/backend/app/cli_interface.py index e2830fff..507e9061 100644 --- a/backend/app/cli_interface.py +++ b/backend/app/cli_interface.py @@ -19,184 +19,194 @@ # Configure logging logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) logger = logging.getLogger(__name__) class OpenWatchCLI: """Main CLI interface for OpenWatch SCAP scanning""" - + def __init__(self): self.settings = get_settings() self.scanner = SCAPCLIScanner( content_dir=self.settings.scap_content_dir, results_dir=self.settings.scan_results_dir, - max_parallel_scans=100 # Support 100+ parallel scans + max_parallel_scans=100, # Support 100+ parallel scans ) - - async def scan_local(self, profile_id: str, content_path: str = None, - rule_id: str = None, output_file: str = None) -> int: + + async def scan_local( + self, + profile_id: str, + content_path: str = None, + rule_id: str = None, + output_file: str = None, + ) -> int: """Execute local SCAP scan""" try: print(f"[OpenWatch] Starting local scan with profile: {profile_id}") - + # Use default content if not specified if not content_path: content_path = self.scanner.get_default_content_path() print(f"[OpenWatch] Using default content: {content_path}") - + # Validate content file if not self.scanner.validate_content_file(content_path): print(f"[OpenWatch] ERROR: Invalid SCAP content file: {content_path}") return 1 - + # Configure host for local scan host_config = { - 'hostname': 'localhost', - 'port': 22, - 'username': 'root', - 'auth_method': 'local', - 'credential': '' + "hostname": "localhost", + "port": 22, + "username": "root", + "auth_method": "local", + "credential": "", } - + # Execute scan if rule_id: print(f"[OpenWatch] Scanning specific rule: {rule_id}") - + result = await self.scanner.scan_single_host( host_config, profile_id, content_path, rule_id ) - + # Display results self._print_scan_result(result) - + # Export results if requested if output_file: self.scanner.export_results_json([result], output_file) print(f"[OpenWatch] Results exported to: {output_file}") - - return 0 if result.get('status') == 'completed' else 1 - + + return 0 if result.get("status") == "completed" else 1 + except Exception as e: print(f"[OpenWatch] ERROR: Local scan failed: {e}") logger.error(f"Local scan error: {e}") return 1 - - async def scan_remote(self, targets: List[str], profile_id: str, - content_path: str = None, rule_id: str = None, - output_file: str = None, parallel: int = 5) -> int: + + async def scan_remote( + self, + targets: List[str], + profile_id: str, + content_path: str = None, + rule_id: str = None, + output_file: str = None, + parallel: int = 5, + ) -> int: """Execute remote SCAP scan on one or more hosts""" try: print(f"[OpenWatch] Starting remote scan on {len(targets)} target(s)") print(f"[OpenWatch] Profile: {profile_id}") print(f"[OpenWatch] Max parallel: {parallel}") - - # Use default content if not specified + + # Use default content if not specified if not content_path: content_path = self.scanner.get_default_content_path() print(f"[OpenWatch] Using default content: {content_path}") - + # Validate content file if not self.scanner.validate_content_file(content_path): print(f"[OpenWatch] ERROR: Invalid SCAP content file: {content_path}") return 1 - + # Note: For demo purposes, remote scanning needs proper credential management # In production, this would integrate with the credential storage system print("[OpenWatch] NOTE: Remote scanning requires SSH credentials") print("[OpenWatch] For demo purposes, showing scan initiation workflow") - + default_credentials = { - 'username': 'root', - 'auth_method': 'password', - 'credential': '' # Would be loaded from secure storage + "username": "root", + "auth_method": "password", + "credential": "", # Would be loaded from secure storage } - + # Update scanner concurrency self.scanner.max_parallel_scans = parallel - + if rule_id: print(f"[OpenWatch] Scanning specific rule: {rule_id}") - + # Execute batch scan results = await self.scanner.batch_scan_from_targets( targets, profile_id, content_path, rule_id, default_credentials ) - + # Display summary summary = self.scanner.generate_scan_summary(results) self._print_scan_summary(summary, targets) - + # Export results if requested if output_file: self.scanner.export_results_json(results, output_file) print(f"[OpenWatch] Results exported to: {output_file}") - + # Return success if all scans completed - successful_scans = summary['scan_summary']['successful_scans'] + successful_scans = summary["scan_summary"]["successful_scans"] return 0 if successful_scans == len(targets) else 1 - + except Exception as e: print(f"[OpenWatch] ERROR: Remote scan failed: {e}") logger.error(f"Remote scan error: {e}") return 1 - + def list_profiles(self, content_path: str = None) -> int: """List available SCAP profiles""" try: if not content_path: content_path = self.scanner.get_default_content_path() - + print(f"[OpenWatch] Listing profiles from: {content_path}") - + profiles = self.scanner.get_available_profiles(content_path) - + if not profiles: print("[OpenWatch] No profiles found in content file") return 1 - + print(f"\n[OpenWatch] Available Profiles ({len(profiles)}):") print("=" * 60) - + for i, profile in enumerate(profiles, 1): print(f"{i}. {profile.get('id', 'Unknown ID')}") print(f" Title: {profile.get('title', 'No title')}") print(f" Description: {profile.get('description', 'No description')[:100]}...") print() - + return 0 - + except Exception as e: print(f"[OpenWatch] ERROR: Failed to list profiles: {e}") return 1 - + def _print_scan_result(self, result: Dict): """Print formatted scan result""" - hostname = result.get('hostname', 'unknown') - status = result.get('status', 'unknown') - + hostname = result.get("hostname", "unknown") + status = result.get("status", "unknown") + print(f"\n[OpenWatch] Scan Results for {hostname}") print("=" * 50) print(f"Status: {status.upper()}") - - if 'rules_total' in result: + + if "rules_total" in result: print(f"Rules Total: {result['rules_total']}") print(f"Rules Passed: {result['rules_passed']}") print(f"Rules Failed: {result['rules_failed']}") print(f"Compliance Score: {result.get('score', 0):.1f}%") - - if result.get('error'): + + if result.get("error"): print(f"Error: {result['error']}") - + print() - + def _print_scan_summary(self, summary: Dict, targets: List[str]): """Print formatted scan summary""" - scan_sum = summary['scan_summary'] - comp_sum = summary['compliance_summary'] - + scan_sum = summary["scan_summary"] + comp_sum = summary["compliance_summary"] + print(f"\n[OpenWatch] Batch Scan Summary") print("=" * 50) print(f"Total Targets: {len(targets)}") @@ -215,7 +225,7 @@ def _print_scan_summary(self, summary: Dict, targets: List[str]): async def main(): """Main CLI entry point""" parser = argparse.ArgumentParser( - description='OpenWatch SCAP Compliance Scanner CLI', + description="OpenWatch SCAP Compliance Scanner CLI", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: @@ -223,58 +233,60 @@ async def main(): python cli_interface.py scan-remote --targets host1,host2,host3 --profile cis-ubuntu --parallel 10 python cli_interface.py list-profiles python cli_interface.py scan-local --profile custom --content /path/to/content.xml --rule specific_rule_id - """ + """, ) - - subparsers = parser.add_subparsers(dest='command', help='Available commands') - + + subparsers = parser.add_subparsers(dest="command", help="Available commands") + # Local scan command - local_parser = subparsers.add_parser('scan-local', help='Execute local SCAP scan') - local_parser.add_argument('--profile', '-p', required=True, help='SCAP profile ID') - local_parser.add_argument('--content', '-c', help='SCAP content file path') - local_parser.add_argument('--rule', '-r', help='Specific rule ID to scan') - local_parser.add_argument('--output', '-o', help='Output file for results (JSON)') - + local_parser = subparsers.add_parser("scan-local", help="Execute local SCAP scan") + local_parser.add_argument("--profile", "-p", required=True, help="SCAP profile ID") + local_parser.add_argument("--content", "-c", help="SCAP content file path") + local_parser.add_argument("--rule", "-r", help="Specific rule ID to scan") + local_parser.add_argument("--output", "-o", help="Output file for results (JSON)") + # Remote scan command - remote_parser = subparsers.add_parser('scan-remote', help='Execute remote SCAP scan') - remote_parser.add_argument('--targets', '-t', required=True, help='Comma-separated list of target hosts') - remote_parser.add_argument('--profile', '-p', required=True, help='SCAP profile ID') - remote_parser.add_argument('--content', '-c', help='SCAP content file path') - remote_parser.add_argument('--rule', '-r', help='Specific rule ID to scan') - remote_parser.add_argument('--parallel', type=int, default=5, help='Max parallel scans (default: 5)') - remote_parser.add_argument('--output', '-o', help='Output file for results (JSON)') - + remote_parser = subparsers.add_parser("scan-remote", help="Execute remote SCAP scan") + remote_parser.add_argument( + "--targets", "-t", required=True, help="Comma-separated list of target hosts" + ) + remote_parser.add_argument("--profile", "-p", required=True, help="SCAP profile ID") + remote_parser.add_argument("--content", "-c", help="SCAP content file path") + remote_parser.add_argument("--rule", "-r", help="Specific rule ID to scan") + remote_parser.add_argument( + "--parallel", type=int, default=5, help="Max parallel scans (default: 5)" + ) + remote_parser.add_argument("--output", "-o", help="Output file for results (JSON)") + # List profiles command - list_parser = subparsers.add_parser('list-profiles', help='List available SCAP profiles') - list_parser.add_argument('--content', '-c', help='SCAP content file path') - + list_parser = subparsers.add_parser("list-profiles", help="List available SCAP profiles") + list_parser.add_argument("--content", "-c", help="SCAP content file path") + args = parser.parse_args() - + if not args.command: parser.print_help() return 1 - + cli = OpenWatchCLI() - + try: - if args.command == 'scan-local': - return await cli.scan_local( - args.profile, args.content, args.rule, args.output - ) - - elif args.command == 'scan-remote': - targets = [t.strip() for t in args.targets.split(',') if t.strip()] + if args.command == "scan-local": + return await cli.scan_local(args.profile, args.content, args.rule, args.output) + + elif args.command == "scan-remote": + targets = [t.strip() for t in args.targets.split(",") if t.strip()] return await cli.scan_remote( targets, args.profile, args.content, args.rule, args.output, args.parallel ) - - elif args.command == 'list-profiles': + + elif args.command == "list-profiles": return cli.list_profiles(args.content) - + else: print(f"[OpenWatch] ERROR: Unknown command: {args.command}") return 1 - + except KeyboardInterrupt: print("\n[OpenWatch] Scan interrupted by user") return 1 @@ -284,6 +296,6 @@ async def main(): return 1 -if __name__ == '__main__': +if __name__ == "__main__": exit_code = asyncio.run(main()) - sys.exit(exit_code) \ No newline at end of file + sys.exit(exit_code) diff --git a/backend/app/config.py b/backend/app/config.py index f4236581..4b9c996c 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -2,6 +2,7 @@ OpenWatch Application Configuration FIPS-compliant security settings and environment configuration """ + import os from typing import Optional, List from pydantic import validator @@ -11,79 +12,79 @@ class Settings(BaseSettings): """Application settings with FIPS compliance""" - + # Application app_name: str = "OpenWatch" app_version: str = "1.2.0" debug: bool = False - + # Security secret_key: str algorithm: str = "RS256" # FIPS-approved RSA signature access_token_expire_minutes: int = 30 refresh_token_expire_days: int = 7 - + # Database (with TDE support) database_url: str database_ssl_mode: str = "require" database_ssl_cert: Optional[str] = None database_ssl_key: Optional[str] = None database_ssl_ca: Optional[str] = None - + # Redis/Celery (secure configuration) redis_url: str = "redis://localhost:6379" redis_ssl: bool = False # Disabled for Docker development redis_ssl_cert: Optional[str] = None redis_ssl_key: Optional[str] = None redis_ssl_ca: Optional[str] = None - + # OpenSCAP openscap_timeout: int = 3600 # 1 hour max scan time max_concurrent_scans: int = 5 scap_content_dir: str = os.getenv("SCAP_CONTENT_DIR", "/app/data/scap") scan_results_dir: str = os.getenv("SCAN_RESULTS_DIR", "/app/data/results") - + # FIPS Configuration fips_mode: bool = True master_key: str # For credential encryption - + # TLS/HTTPS tls_cert_file: Optional[str] = None tls_key_file: Optional[str] = None tls_ca_file: Optional[str] = None require_https: bool = True - + # Allowed hosts for CORS allowed_origins: List[str] = ["https://localhost:3001"] - + # File upload limits max_upload_size: int = 100 * 1024 * 1024 # 100MB allowed_file_types: List[str] = [".xml", ".zip", ".bz2", ".gz"] - + # Logging log_level: str = "INFO" log_file: Optional[str] = None audit_log_file: str = "/app/logs/audit.log" - + @validator("secret_key") def secret_key_must_be_strong(cls, v): if len(v) < 32: raise ValueError("Secret key must be at least 32 characters long") return v - + @validator("master_key") def master_key_must_be_strong(cls, v): if len(v) < 32: raise ValueError("Master key must be at least 32 characters long") return v - + @validator("allowed_origins") def validate_origins(cls, v): for origin in v: if not origin.startswith(("https://", "http://localhost")): raise ValueError("All origins must use HTTPS (except localhost)") return v - + class Config: env_file = ".env" env_prefix = "OPENWATCH_" @@ -113,7 +114,7 @@ def get_settings() -> Settings: "object-src 'none'" ), "Referrer-Policy": "strict-origin-when-cross-origin", - "Permissions-Policy": "geolocation=(), microphone=(), camera=()" + "Permissions-Policy": "geolocation=(), microphone=(), camera=()", } # FIPS-approved cipher suites for TLS @@ -123,5 +124,5 @@ def get_settings() -> Settings: "ECDHE-RSA-AES256-GCM-SHA384", "ECDHE-RSA-AES128-GCM-SHA256", "DHE-RSA-AES256-GCM-SHA384", - "DHE-RSA-AES128-GCM-SHA256" -] \ No newline at end of file + "DHE-RSA-AES128-GCM-SHA256", +] diff --git a/backend/app/database.py b/backend/app/database.py index 0cd7adec..a7bd13de 100644 --- a/backend/app/database.py +++ b/backend/app/database.py @@ -2,8 +2,24 @@ FIPS-compliant database configuration with encryption support PostgreSQL with TLS and encrypted connections """ + import logging -from sqlalchemy import create_engine, MetaData, Table, Column, Integer, String, DateTime, Text, Boolean, LargeBinary, Float, JSON, ForeignKey, Enum +from sqlalchemy import ( + create_engine, + MetaData, + Table, + Column, + Integer, + String, + DateTime, + Text, + Boolean, + LargeBinary, + Float, + JSON, + ForeignKey, + Enum, +) from sqlalchemy.dialects.postgresql import UUID from uuid import uuid4 from sqlalchemy.ext.declarative import declarative_base @@ -27,12 +43,14 @@ ssl_params = {} if settings.database_ssl_mode and not settings.debug: # Only use SSL in production with certificates - ssl_params.update({ - "sslmode": settings.database_ssl_mode, - "sslcert": settings.database_ssl_cert, - "sslkey": settings.database_ssl_key, - "sslrootcert": settings.database_ssl_ca - }) + ssl_params.update( + { + "sslmode": settings.database_ssl_mode, + "sslcert": settings.database_ssl_cert, + "sslkey": settings.database_ssl_key, + "sslrootcert": settings.database_ssl_ca, + } + ) elif settings.debug: # Development mode - disable SSL ssl_params.update({"sslmode": "disable"}) @@ -45,11 +63,7 @@ max_overflow=20, pool_pre_ping=True, pool_recycle=3600, # Recycle connections every hour - connect_args={ - **ssl_params, - "connect_timeout": 10, - "options": "-c application_name=openwatch" - } + connect_args={**ssl_params, "connect_timeout": 10, "options": "-c application_name=openwatch"}, ) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) @@ -59,19 +73,32 @@ # Database Models class User(Base): """User model with secure password storage""" + __tablename__ = "users" - + id = Column(Integer, primary_key=True, index=True) username = Column(String(50), unique=True, index=True, nullable=False) email = Column(String(100), unique=True, index=True, nullable=False) hashed_password = Column(String(255), nullable=False) # Argon2id hash - role = Column(Enum('super_admin', 'security_admin', 'security_analyst', 'compliance_officer', 'auditor', 'guest', name='user_roles'), default='guest', nullable=False) + role = Column( + Enum( + "super_admin", + "security_admin", + "security_analyst", + "compliance_officer", + "auditor", + "guest", + name="user_roles", + ), + default="guest", + nullable=False, + ) is_active = Column(Boolean, default=True, nullable=False) created_at = Column(DateTime, default=datetime.utcnow, nullable=False) last_login = Column(DateTime, nullable=True) failed_login_attempts = Column(Integer, default=0, nullable=False) locked_until = Column(DateTime, nullable=True) - + # MFA Support mfa_enabled = Column(Boolean, default=False, nullable=False) mfa_secret = Column(Text, nullable=True) # Encrypted TOTP secret @@ -83,8 +110,9 @@ class User(Base): class MFAAuditLog(Base): """MFA audit log for security monitoring""" + __tablename__ = "mfa_audit_log" - + id = Column(Integer, primary_key=True, index=True) user_id = Column(Integer, ForeignKey("users.id"), nullable=False) action = Column(String(50), nullable=False) # enroll, validate, disable, etc. @@ -98,8 +126,9 @@ class MFAAuditLog(Base): class MFAUsedCodes(Base): """TOTP replay protection""" + __tablename__ = "mfa_used_codes" - + id = Column(Integer, primary_key=True, index=True) user_id = Column(Integer, ForeignKey("users.id"), nullable=False) code_hash = Column(String(64), nullable=False) # SHA-256 hash @@ -108,8 +137,9 @@ class MFAUsedCodes(Base): class Host(Base): """Host model with encrypted credential storage""" + __tablename__ = "hosts" - + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4, index=True) # Native UUID hostname = Column(String(255), nullable=False) ip_address = Column(String(45), nullable=False) # IPv4 or IPv6 @@ -125,15 +155,18 @@ class Host(Base): tags = Column(String(500), nullable=True) # Added for bulk import (comma-separated) owner = Column(String(100), nullable=True) # Added for bulk import is_active = Column(Boolean, default=True, nullable=False) - created_by = Column(Integer, ForeignKey("users.id"), nullable=True) # Made optional for development + created_by = Column( + Integer, ForeignKey("users.id"), nullable=True + ) # Made optional for development created_at = Column(DateTime, default=datetime.utcnow, nullable=False) updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False) class ScapContent(Base): """SCAP content metadata""" + __tablename__ = "scap_content" - + id = Column(Integer, primary_key=True, index=True) name = Column(String(100), nullable=False) filename = Column(String(255), nullable=False) @@ -149,28 +182,39 @@ class ScapContent(Base): class Scan(Base): """Scan job tracking""" + __tablename__ = "scans" - + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4, index=True) # Native UUID name = Column(String(100), nullable=False) - host_id = Column(UUID(as_uuid=True), ForeignKey("hosts.id"), nullable=False) # Updated to match Host.id + host_id = Column( + UUID(as_uuid=True), ForeignKey("hosts.id"), nullable=False + ) # Updated to match Host.id content_id = Column(Integer, ForeignKey("scap_content.id"), nullable=False) profile_id = Column(String(100), nullable=False) - status = Column(String(20), default="pending", nullable=False) # pending, running, completed, failed + status = Column( + String(20), default="pending", nullable=False + ) # pending, running, completed, failed progress = Column(Integer, default=0, nullable=False) # 0-100 result_file = Column(String(500), nullable=True) report_file = Column(String(500), nullable=True) error_message = Column(Text, nullable=True) scan_options = Column(Text, nullable=True) # JSON options - started_by = Column(Integer, ForeignKey("users.id"), nullable=True) # Made optional for development + started_by = Column( + Integer, ForeignKey("users.id"), nullable=True + ) # Made optional for development started_at = Column(DateTime, default=datetime.utcnow, nullable=False) completed_at = Column(DateTime, nullable=True) celery_task_id = Column(String(100), nullable=True) - + # AEGIS Integration Fields remediation_requested = Column(Boolean, default=False, nullable=False) - aegis_remediation_id = Column(UUID(as_uuid=True), nullable=True) # Link to AEGIS remediation job - verification_scan = Column(Boolean, default=False, nullable=False) # True if this is a verification scan + aegis_remediation_id = Column( + UUID(as_uuid=True), nullable=True + ) # Link to AEGIS remediation job + verification_scan = Column( + Boolean, default=False, nullable=False + ) # True if this is a verification scan remediation_status = Column(String(20), nullable=True) # completed, failed, partial remediation_completed_at = Column(DateTime, nullable=True) scan_metadata = Column(JSON, nullable=True) # Additional metadata including remediation results @@ -178,8 +222,9 @@ class Scan(Base): class ScanResult(Base): """Scan results summary""" + __tablename__ = "scan_results" - + id = Column(Integer, primary_key=True, index=True) scan_id = Column(UUID(as_uuid=True), ForeignKey("scans.id"), nullable=False) total_rules = Column(Integer, nullable=False) @@ -197,8 +242,9 @@ class ScanResult(Base): class SystemCredentials(Base): """System-wide SSH credentials for enterprise environments""" + __tablename__ = "system_credentials" - + id = Column(Integer, primary_key=True, index=True) name = Column(String(100), nullable=False) # e.g., "Default Admin Account" description = Column(Text, nullable=True) @@ -221,8 +267,9 @@ class SystemCredentials(Base): class Role(Base): """Role definitions with permissions""" + __tablename__ = "roles" - + id = Column(Integer, primary_key=True, index=True) name = Column(String(50), unique=True, nullable=False) # super_admin, security_admin, etc. display_name = Column(String(100), nullable=False) # "Super Administrator" @@ -235,8 +282,9 @@ class Role(Base): class UserGroup(Base): """User groups for organizing access to hosts and resources""" + __tablename__ = "user_groups" - + id = Column(Integer, primary_key=True, index=True) name = Column(String(100), nullable=False) description = Column(Text, nullable=True) @@ -247,8 +295,9 @@ class UserGroup(Base): class UserGroupMembership(Base): """Many-to-many relationship between users and groups""" + __tablename__ = "user_group_memberships" - + id = Column(Integer, primary_key=True, index=True) user_id = Column(Integer, ForeignKey("users.id"), nullable=False) group_id = Column(Integer, ForeignKey("user_groups.id"), nullable=False) @@ -258,13 +307,16 @@ class UserGroupMembership(Base): class HostAccess(Base): """Host access control for users and groups""" + __tablename__ = "host_access" - + id = Column(Integer, primary_key=True, index=True) host_id = Column(UUID(as_uuid=True), ForeignKey("hosts.id"), nullable=False) user_id = Column(Integer, ForeignKey("users.id"), nullable=True) # Direct user access group_id = Column(Integer, ForeignKey("user_groups.id"), nullable=True) # Group access - access_level = Column(Enum('read', 'write', 'admin', name='access_levels'), default='read', nullable=False) + access_level = Column( + Enum("read", "write", "admin", name="access_levels"), default="read", nullable=False + ) granted_by = Column(Integer, ForeignKey("users.id"), nullable=False) granted_at = Column(DateTime, default=datetime.utcnow, nullable=False) expires_at = Column(DateTime, nullable=True) # Optional expiration @@ -272,8 +324,9 @@ class HostAccess(Base): class HostGroup(Base): """Host groups for organizing hosts""" + __tablename__ = "host_groups" - + id = Column(Integer, primary_key=True, index=True) name = Column(String(100), nullable=False, unique=True) description = Column(Text, nullable=True) @@ -296,8 +349,9 @@ class HostGroup(Base): class HostGroupMembership(Base): """Host group membership mapping""" + __tablename__ = "host_group_memberships" - + id = Column(Integer, primary_key=True, index=True) host_id = Column(UUID(as_uuid=True), ForeignKey("hosts.id"), nullable=False) group_id = Column(Integer, ForeignKey("host_groups.id"), nullable=False) @@ -307,8 +361,9 @@ class HostGroupMembership(Base): class AuditLog(Base): """Security audit log""" + __tablename__ = "audit_logs" - + id = Column(Integer, primary_key=True, index=True) user_id = Column(Integer, nullable=True) # User ID if authenticated action = Column(String(50), nullable=False) @@ -322,8 +377,9 @@ class AuditLog(Base): class WebhookEndpoint(Base): """Webhook endpoint management for AEGIS integration""" + __tablename__ = "webhook_endpoints" - + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4, index=True) name = Column(String(100), nullable=False) url = Column(String(500), nullable=False) @@ -337,13 +393,16 @@ class WebhookEndpoint(Base): class WebhookDelivery(Base): """Webhook delivery tracking""" + __tablename__ = "webhook_deliveries" - + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4, index=True) webhook_id = Column(UUID(as_uuid=True), ForeignKey("webhook_endpoints.id"), nullable=False) event_type = Column(String(50), nullable=False) event_data = Column(JSON, nullable=False) - delivery_status = Column(String(20), default="pending", nullable=False) # pending, delivered, failed + delivery_status = Column( + String(20), default="pending", nullable=False + ) # pending, delivered, failed http_status_code = Column(Integer, nullable=True) response_body = Column(Text, nullable=True) error_message = Column(Text, nullable=True) @@ -356,8 +415,9 @@ class WebhookDelivery(Base): class ApiKey(Base): """API keys for service-to-service authentication""" + __tablename__ = "api_keys" - + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4, index=True) name = Column(String(100), nullable=False) key_hash = Column(String(128), nullable=False) # Hashed API key @@ -371,8 +431,9 @@ class ApiKey(Base): class IntegrationAuditLog(Base): """Audit log for cross-service operations""" + __tablename__ = "integration_audit_log" - + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4, index=True) event_type = Column(String(50), nullable=False) # scan.completed, remediation.requested, etc. source_service = Column(String(20), nullable=False) # openwatch, aegis @@ -397,7 +458,7 @@ def get_db() -> Session: db.close() -async def create_tables(): +def create_tables(): """Create database tables if they don't exist""" try: Base.metadata.create_all(bind=engine) @@ -407,10 +468,11 @@ async def create_tables(): raise -async def check_database_health() -> bool: +def check_database_health() -> bool: """Check database connectivity for health checks""" try: from sqlalchemy import text + db = SessionLocal() # Simple query to test connection db.execute(text("SELECT 1")) @@ -423,38 +485,43 @@ async def check_database_health() -> bool: class DatabaseManager: """Database operations with security logging""" - + def __init__(self, db: Session): self.db = db - - def create_user(self, username: str, email: str, hashed_password: str, role: str = "user") -> User: + + def create_user( + self, username: str, email: str, hashed_password: str, role: str = "user" + ) -> User: """Create new user with audit logging""" - user = User( - username=username, - email=email, - hashed_password=hashed_password, - role=role - ) + user = User(username=username, email=email, hashed_password=hashed_password, role=role) self.db.add(user) self.db.commit() self.db.refresh(user) - + # Audit log self.log_audit("CREATE", "USER", str(user.id), f"Created user: {username}") - + return user - + def get_user_by_username(self, username: str) -> Optional[User]: """Get user by username""" return self.db.query(User).filter(User.username == username).first() - + def get_user_by_email(self, email: str) -> Optional[User]: """Get user by email""" return self.db.query(User).filter(User.email == email).first() - - def create_host(self, name: str, hostname: str, port: int, username: str, - auth_method: str, encrypted_credentials: bytes, - created_by: int, description: str = None) -> Host: + + def create_host( + self, + name: str, + hostname: str, + port: int, + username: str, + auth_method: str, + encrypted_credentials: bytes, + created_by: int, + description: str = None, + ) -> Host: """Create new host with encrypted credentials""" host = Host( name=name, @@ -464,19 +531,26 @@ def create_host(self, name: str, hostname: str, port: int, username: str, auth_method=auth_method, encrypted_credentials=encrypted_credentials, description=description, - created_by=created_by + created_by=created_by, ) self.db.add(host) self.db.commit() self.db.refresh(host) - + # Audit log self.log_audit("CREATE", "HOST", str(host.id), f"Created host: {name}") - + return host - - def log_audit(self, action: str, resource_type: str, resource_id: str, - details: str, user_id: int = None, ip_address: str = "unknown"): + + def log_audit( + self, + action: str, + resource_type: str, + resource_id: str, + details: str, + user_id: int = None, + ip_address: str = "unknown", + ): """Log audit event""" audit_log = AuditLog( user_id=user_id, @@ -484,7 +558,7 @@ def log_audit(self, action: str, resource_type: str, resource_id: str, resource_type=resource_type, resource_id=resource_id, ip_address=ip_address, - details=details + details=details, ) self.db.add(audit_log) self.db.commit() @@ -498,12 +572,12 @@ async def init_database(): healthy = await check_database_health() if not healthy: raise Exception("Database connection failed") - + # Create tables await create_tables() - + logger.info("Database initialized successfully with FIPS-compliant configuration") - + except Exception as e: logger.error(f"Database initialization failed: {e}") - raise \ No newline at end of file + raise diff --git a/backend/app/examples/group_scan_api_usage.py b/backend/app/examples/group_scan_api_usage.py index 5fe49646..e82a9af4 100644 --- a/backend/app/examples/group_scan_api_usage.py +++ b/backend/app/examples/group_scan_api_usage.py @@ -2,6 +2,7 @@ Example usage of the Group Scan Progress API Demonstrates how to use the new endpoints for tracking group scan progress """ + import asyncio import aiohttp import json @@ -11,43 +12,47 @@ class GroupScanAPIClient: """Example client for Group Scan API endpoints""" - + def __init__(self, base_url: str = "http://localhost:8000", auth_token: str = None): self.base_url = base_url self.auth_token = auth_token self.headers = { "Content-Type": "application/json", - "Authorization": f"Bearer {auth_token}" if auth_token else "" + "Authorization": f"Bearer {auth_token}" if auth_token else "", } - + async def initiate_group_scan(self, group_id: int, scan_config: Dict[str, Any] = None) -> Dict: """Initiate a group scan for all hosts in a group""" url = f"{self.base_url}/api/host-groups/{group_id}/scan" - + async with aiohttp.ClientSession() as session: async with session.post(url, headers=self.headers, json=scan_config or {}) as response: if response.status == 200: return await response.json() else: error_text = await response.text() - raise Exception(f"Failed to initiate group scan: {response.status} - {error_text}") - + raise Exception( + f"Failed to initiate group scan: {response.status} - {error_text}" + ) + async def get_scan_progress(self, session_id: str) -> Dict: """Get real-time progress of a group scan""" url = f"{self.base_url}/api/host-groups/scan-sessions/{session_id}/progress" - + async with aiohttp.ClientSession() as session: async with session.get(url, headers=self.headers) as response: if response.status == 200: return await response.json() else: error_text = await response.text() - raise Exception(f"Failed to get scan progress: {response.status} - {error_text}") - + raise Exception( + f"Failed to get scan progress: {response.status} - {error_text}" + ) + async def get_host_scan_details(self, session_id: str) -> list: """Get detailed status of each host in a group scan""" url = f"{self.base_url}/api/host-groups/scan-sessions/{session_id}/hosts" - + async with aiohttp.ClientSession() as session: async with session.get(url, headers=self.headers) as response: if response.status == 200: @@ -55,23 +60,25 @@ async def get_host_scan_details(self, session_id: str) -> list: else: error_text = await response.text() raise Exception(f"Failed to get host details: {response.status} - {error_text}") - + async def cancel_group_scan(self, session_id: str) -> Dict: """Cancel an ongoing group scan""" url = f"{self.base_url}/api/host-groups/scan-sessions/{session_id}/cancel" - + async with aiohttp.ClientSession() as session: async with session.post(url, headers=self.headers) as response: if response.status == 200: return await response.json() else: error_text = await response.text() - raise Exception(f"Failed to cancel group scan: {response.status} - {error_text}") - + raise Exception( + f"Failed to cancel group scan: {response.status} - {error_text}" + ) + async def get_active_scans(self) -> list: """Get all active scan sessions""" url = f"{self.base_url}/api/host-groups/scan-sessions/active" - + async with aiohttp.ClientSession() as session: async with session.get(url, headers=self.headers) as response: if response.status == 200: @@ -79,11 +86,11 @@ async def get_active_scans(self) -> list: else: error_text = await response.text() raise Exception(f"Failed to get active scans: {response.status} - {error_text}") - + async def get_group_scan_summary(self, session_id: str) -> Dict: """Get comprehensive summary of a completed group scan""" url = f"{self.base_url}/api/host-groups/scan-sessions/{session_id}/summary" - + async with aiohttp.ClientSession() as session: async with session.get(url, headers=self.headers) as response: if response.status == 200: @@ -91,24 +98,28 @@ async def get_group_scan_summary(self, session_id: str) -> Dict: else: error_text = await response.text() raise Exception(f"Failed to get scan summary: {response.status} - {error_text}") - - async def list_scan_sessions(self, status: str = None, group_id: int = None, limit: int = 20) -> Dict: + + async def list_scan_sessions( + self, status: str = None, group_id: int = None, limit: int = 20 + ) -> Dict: """List group scan sessions with filtering""" url = f"{self.base_url}/api/host-groups/scan-sessions" params = {"limit": limit} - + if status: params["status"] = status if group_id: params["group_id"] = group_id - + async with aiohttp.ClientSession() as session: async with session.get(url, headers=self.headers, params=params) as response: if response.status == 200: return await response.json() else: error_text = await response.text() - raise Exception(f"Failed to list scan sessions: {response.status} - {error_text}") + raise Exception( + f"Failed to list scan sessions: {response.status} - {error_text}" + ) async def example_group_scan_workflow(): @@ -116,11 +127,8 @@ async def example_group_scan_workflow(): Example workflow demonstrating the complete group scan process """ # Initialize client (you would need a valid auth token) - client = GroupScanAPIClient( - base_url="http://localhost:8000", - auth_token="your_jwt_token_here" - ) - + client = GroupScanAPIClient(base_url="http://localhost:8000", auth_token="your_jwt_token_here") + try: # 1. Initiate a group scan print("🚀 Initiating group scan...") @@ -131,54 +139,60 @@ async def example_group_scan_workflow(): "priority": "high", "stagger_delay": 30, # 30 seconds between scans "max_concurrent": 3, # Maximum 3 concurrent scans - "email_notify": True + "email_notify": True, } - + scan_response = await client.initiate_group_scan(group_id, scan_config) session_id = scan_response["session_id"] - + print(f"✅ Group scan initiated successfully!") print(f" Session ID: {session_id}") print(f" Group: {scan_response['group_name']}") print(f" Total hosts: {scan_response['total_hosts']}") print(f" Status: {scan_response['status']}") print(f" Estimated completion: {scan_response.get('estimated_completion', 'Unknown')}") - + # 2. Monitor progress in real-time print(f"\n📊 Monitoring scan progress...") while True: progress = await client.get_scan_progress(session_id) - print(f" Progress: {progress['progress_percentage']:.1f}% " - f"({progress['hosts_completed']}/{progress['total_hosts']} hosts)") + print( + f" Progress: {progress['progress_percentage']:.1f}% " + f"({progress['hosts_completed']}/{progress['total_hosts']} hosts)" + ) print(f" Status: {progress['status']}") - print(f" Scanning: {progress['hosts_scanning']}, " - f"Pending: {progress['hosts_pending']}, " - f"Failed: {progress['hosts_failed']}") - + print( + f" Scanning: {progress['hosts_scanning']}, " + f"Pending: {progress['hosts_pending']}, " + f"Failed: {progress['hosts_failed']}" + ) + if progress["status"] in ["completed", "failed", "cancelled"]: break - + await asyncio.sleep(10) # Check every 10 seconds - + # 3. Get detailed host results print(f"\n🔍 Getting detailed host scan results...") host_details = await client.get_host_scan_details(session_id) - + for host in host_details: print(f" Host: {host['host_name']} ({host['hostname']})") print(f" Status: {host['status']}") - if host['scan_results']: - results = host['scan_results'] - print(f" Results: {results['passed_rules']}/{results['total_rules']} passed, " - f"Score: {results['score']}") - if host['error_message']: + if host["scan_results"]: + results = host["scan_results"] + print( + f" Results: {results['passed_rules']}/{results['total_rules']} passed, " + f"Score: {results['score']}" + ) + if host["error_message"]: print(f" Error: {host['error_message']}") print() - + # 4. Get comprehensive summary print(f"\n📈 Getting scan summary...") summary = await client.get_group_scan_summary(session_id) - + print(f" Final Status: {summary['status']}") print(f" Duration: {summary['scan_duration_minutes']} minutes") print(f" Successful scans: {summary['successful_scans']}/{summary['total_hosts']}") @@ -186,9 +200,9 @@ async def example_group_scan_workflow(): print(f" Average compliance score: {summary['average_compliance_score']}%") print(f" Total rules checked: {summary['total_rules_checked']}") print(f" Total failed rules: {summary['total_failed_rules']}") - + print(f"\n✅ Group scan workflow completed successfully!") - + except Exception as e: print(f"❌ Error in group scan workflow: {e}") @@ -198,29 +212,33 @@ async def example_scan_management(): Example of scan session management operations """ client = GroupScanAPIClient(auth_token="your_jwt_token_here") - + try: # List all active scans print("📋 Getting active scans...") active_scans = await client.get_active_scans() - + if active_scans: print(f"Found {len(active_scans)} active scans:") for scan in active_scans: - print(f" {scan['session_id']}: {scan['group_name']} " - f"({scan['progress_percentage']:.1f}% complete)") + print( + f" {scan['session_id']}: {scan['group_name']} " + f"({scan['progress_percentage']:.1f}% complete)" + ) else: print("No active scans found.") - + # List recent scan sessions print(f"\n📜 Getting recent scan sessions...") sessions = await client.list_scan_sessions(limit=10) - + print(f"Found {len(sessions['sessions'])} recent sessions:") for session in sessions["sessions"]: - print(f" {session['session_id'][:8]}... - {session['group_name']} " - f"({session['status']}, {session['progress_percentage']:.1f}%)") - + print( + f" {session['session_id'][:8]}... - {session['group_name']} " + f"({session['status']}, {session['progress_percentage']:.1f}%)" + ) + except Exception as e: print(f"❌ Error in scan management: {e}") @@ -228,20 +246,20 @@ async def example_scan_management(): if __name__ == "__main__": print("🔧 Group Scan API Example Usage") print("=" * 50) - + # Note: This example requires a running OpenWatch backend with valid authentication # Replace 'your_jwt_token_here' with an actual JWT token - + print("\n⚠️ This example requires:") print(" - OpenWatch backend running on localhost:8000") print(" - Valid JWT authentication token") print(" - At least one host group with active hosts") print(" - SCAP content uploaded to the system") - + # Uncomment to run the examples (with proper authentication) # asyncio.run(example_group_scan_workflow()) # asyncio.run(example_scan_management()) - + print("\n💡 API Endpoints implemented:") print(" POST /api/host-groups/{group_id}/scan - Initiate group scan") print(" GET /api/host-groups/scan-sessions/{session_id}/progress - Get progress") @@ -249,4 +267,4 @@ async def example_scan_management(): print(" POST /api/host-groups/scan-sessions/{session_id}/cancel - Cancel scan") print(" GET /api/host-groups/scan-sessions/active - Get active scans") print(" GET /api/host-groups/scan-sessions/{session_id}/summary - Get summary") - print(" GET /api/host-groups/scan-sessions - List scan sessions") \ No newline at end of file + print(" GET /api/host-groups/scan-sessions - List scan sessions") diff --git a/backend/app/init_admin.py b/backend/app/init_admin.py index 3c19c923..bd31d6b4 100644 --- a/backend/app/init_admin.py +++ b/backend/app/init_admin.py @@ -9,7 +9,9 @@ from rbac import UserRole # Database URL -DATABASE_URL = os.getenv("OPENWATCH_DATABASE_URL", "postgresql://openwatch:OpenWatch2025@localhost:5432/openwatch") +DATABASE_URL = os.getenv( + "OPENWATCH_DATABASE_URL", "postgresql://openwatch:OpenWatch2025@localhost:5432/openwatch" +) # Password hasher pwd_context = CryptContext( @@ -20,32 +22,39 @@ argon2__parallelism=1, ) + def create_admin_user(): """Create default admin user if it doesn't exist""" engine = create_engine(DATABASE_URL) - + with engine.connect() as conn: # Check if admin user exists result = conn.execute(text("SELECT id FROM users WHERE username = 'admin'")) if result.fetchone(): print("Admin user already exists") return - + # Create admin user hashed_password = pwd_context.hash("admin123") - conn.execute(text(""" + conn.execute( + text( + """ INSERT INTO users (username, email, hashed_password, role, is_active, created_at, failed_login_attempts, mfa_enabled) VALUES ('admin', 'admin@example.com', :password, :role, true, CURRENT_TIMESTAMP, 0, false) - """), {"password": hashed_password, "role": UserRole.SUPER_ADMIN.value}) + """ + ), + {"password": hashed_password, "role": UserRole.SUPER_ADMIN.value}, + ) conn.commit() - + print("Admin user created successfully") print("Username: admin") print("Password: admin123") + if __name__ == "__main__": try: create_admin_user() except Exception as e: print(f"Error: {e}") - sys.exit(1) \ No newline at end of file + sys.exit(1) diff --git a/backend/app/init_roles.py b/backend/app/init_roles.py index 0c580157..33c98456 100644 --- a/backend/app/init_roles.py +++ b/backend/app/init_roles.py @@ -1,6 +1,7 @@ """ Initialize roles and permissions in the database """ + import asyncio from sqlalchemy.orm import Session from sqlalchemy import text @@ -16,75 +17,90 @@ def init_roles(db: Session): """Initialize roles in the database""" - + role_definitions = { UserRole.SUPER_ADMIN: { "display_name": "Super Administrator", - "description": "Full system access with user management capabilities" + "description": "Full system access with user management capabilities", }, UserRole.SECURITY_ADMIN: { - "display_name": "Security Administrator", - "description": "Security-focused administration without user management" + "display_name": "Security Administrator", + "description": "Security-focused administration without user management", }, UserRole.SECURITY_ANALYST: { "display_name": "Security Analyst", - "description": "Day-to-day security operations and scan execution" + "description": "Day-to-day security operations and scan execution", }, UserRole.COMPLIANCE_OFFICER: { "display_name": "Compliance Officer", - "description": "Compliance reporting and read-only access to results" + "description": "Compliance reporting and read-only access to results", }, UserRole.AUDITOR: { - "display_name": "Auditor", - "description": "External audit support with read-only access" + "display_name": "Auditor", + "description": "External audit support with read-only access", }, UserRole.GUEST: { "display_name": "Guest", - "description": "Limited read-only access to assigned resources" - } + "description": "Limited read-only access to assigned resources", + }, } - + try: for role_name, role_info in role_definitions.items(): # Check if role already exists - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT id FROM roles WHERE name = :name - """), {"name": role_name.value}) - + """ + ), + {"name": role_name.value}, + ) + if result.fetchone(): logger.info(f"Role {role_name.value} already exists, updating permissions...") # Update existing role permissions permissions_json = json.dumps([p.value for p in ROLE_PERMISSIONS[role_name]]) - db.execute(text(""" + db.execute( + text( + """ UPDATE roles SET permissions = :permissions, display_name = :display_name, description = :description, updated_at = CURRENT_TIMESTAMP WHERE name = :name - """), { - "name": role_name.value, - "permissions": permissions_json, - "display_name": role_info["display_name"], - "description": role_info["description"] - }) + """ + ), + { + "name": role_name.value, + "permissions": permissions_json, + "display_name": role_info["display_name"], + "description": role_info["description"], + }, + ) else: logger.info(f"Creating role {role_name.value}...") # Create new role permissions_json = json.dumps([p.value for p in ROLE_PERMISSIONS[role_name]]) - db.execute(text(""" + db.execute( + text( + """ INSERT INTO roles (name, display_name, description, permissions, is_active, created_at, updated_at) VALUES (:name, :display_name, :description, :permissions, true, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) - """), { - "name": role_name.value, - "display_name": role_info["display_name"], - "description": role_info["description"], - "permissions": permissions_json - }) - + """ + ), + { + "name": role_name.value, + "display_name": role_info["display_name"], + "description": role_info["description"], + "permissions": permissions_json, + }, + ) + db.commit() logger.info("Roles initialized successfully") - + except Exception as e: logger.error(f"Error initializing roles: {e}") db.rollback() @@ -97,10 +113,10 @@ def create_default_super_admin(db: Session): # Check if there's already a user with ID 1 result = db.execute(text("SELECT id, role FROM users WHERE id = 1")) existing_user = result.fetchone() - + if existing_user: # Update existing user to super_admin role - if existing_user.role != 'super_admin': + if existing_user.role != "super_admin": db.execute(text("UPDATE users SET role = 'super_admin' WHERE id = 1")) logger.info("Updated existing user (ID=1) to super_admin role") else: @@ -108,16 +124,22 @@ def create_default_super_admin(db: Session): else: # Create new super admin user from .auth import pwd_context + hashed_password = pwd_context.hash("admin123") # Default password - should be changed - - db.execute(text(""" + + db.execute( + text( + """ INSERT INTO users (id, username, email, hashed_password, role, is_active, created_at, failed_login_attempts, mfa_enabled) VALUES (1, 'admin', 'admin@example.com', :password, 'super_admin', true, CURRENT_TIMESTAMP, 0, false) - """), {"password": hashed_password}) + """ + ), + {"password": hashed_password}, + ) logger.info("Created new super admin user (username: admin, password: admin123)") - + db.commit() - + except Exception as e: logger.error(f"Error creating default super admin: {e}") db.rollback() @@ -128,29 +150,37 @@ def init_default_system_credentials(db: Session): """Initialize default system SSH credentials for frictionless onboarding""" try: # Check if any system credentials already exist - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT COUNT(*) as count FROM system_credentials WHERE is_active = true - """)) - + """ + ) + ) + existing_count = result.fetchone().count - + if existing_count > 0: - logger.info(f"Found {existing_count} existing system credentials, skipping initialization") + logger.info( + f"Found {existing_count} existing system credentials, skipping initialization" + ) return - + logger.info("No system credentials found - creating placeholder credentials for easy setup") - + # Create placeholder credentials that guide users to configure actual credentials placeholder_description = ( "Default placeholder credentials - PLEASE UPDATE with your actual SSH credentials. " "This entry provides a starting point for SSH-based scanning and monitoring. " "Update the username, password, or SSH key to match your environment." ) - + current_time = datetime.utcnow() - + # Insert placeholder credentials (no actual sensitive data) - db.execute(text(""" + db.execute( + text( + """ INSERT INTO system_credentials (name, description, username, auth_method, encrypted_password, encrypted_private_key, private_key_passphrase, is_default, is_active, @@ -158,26 +188,33 @@ def init_default_system_credentials(db: Session): VALUES (:name, :description, :username, :auth_method, :encrypted_password, :encrypted_private_key, :private_key_passphrase, :is_default, :is_active, :created_by, :created_at, :updated_at) - """), { - "name": "Setup Required - Default SSH Credentials", - "description": placeholder_description, - "username": "root", - "auth_method": "password", - "encrypted_password": encrypt_data(b"CHANGE_ME_PLEASE"), # Obvious placeholder - "encrypted_private_key": None, - "private_key_passphrase": None, - "is_default": True, - "is_active": True, - "created_by": 1, # Created by default admin user - "created_at": current_time, - "updated_at": current_time - }) - + """ + ), + { + "name": "Setup Required - Default SSH Credentials", + "description": placeholder_description, + "username": "root", + "auth_method": "password", + "encrypted_password": encrypt_data(b"CHANGE_ME_PLEASE"), # Obvious placeholder + "encrypted_private_key": None, + "private_key_passphrase": None, + "is_default": True, + "is_active": True, + "created_by": 1, # Created by default admin user + "created_at": current_time, + "updated_at": current_time, + }, + ) + db.commit() - - logger.info("Created placeholder system credentials - users should update these in Settings") - logger.warning("SECURITY NOTICE: Default SSH credentials created with placeholder password. Users must update these credentials in Settings before performing SSH operations.") - + + logger.info( + "Created placeholder system credentials - users should update these in Settings" + ) + logger.warning( + "SECURITY NOTICE: Default SSH credentials created with placeholder password. Users must update these credentials in Settings before performing SSH operations." + ) + except Exception as e: logger.error(f"Error creating default system credentials: {e}") db.rollback() @@ -189,7 +226,7 @@ async def initialize_rbac_system(): try: # Ensure tables exist await create_tables() - + # Initialize roles and system components db = SessionLocal() try: @@ -199,7 +236,7 @@ async def initialize_rbac_system(): logger.info("RBAC system and default credentials initialized successfully") finally: db.close() - + except Exception as e: logger.error(f"Failed to initialize RBAC system: {e}") raise @@ -207,4 +244,4 @@ async def initialize_rbac_system(): if __name__ == "__main__": # Run initialization - asyncio.run(initialize_rbac_system()) \ No newline at end of file + asyncio.run(initialize_rbac_system()) diff --git a/backend/app/main.py b/backend/app/main.py index 5c4050be..a6458f78 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -2,6 +2,7 @@ OpenWatch FastAPI Application - FIPS Compliant Security Scanner Main application with comprehensive security middleware """ + import logging import os import asyncio @@ -17,9 +18,32 @@ from .config import get_settings, SECURITY_HEADERS from .auth import jwt_manager, audit_logger from .database import engine, create_tables, get_db -from .routes import auth, hosts, scans, content, scap_content, monitoring, users, audit, host_groups, scan_templates, webhooks, mfa +from .routes import ( + auth, + hosts, + scans, + content, + scap_content, + monitoring, + users, + audit, + host_groups, + scan_templates, + webhooks, + mfa, +) from .routes.system_settings_unified import router as system_settings_router -from .routes import credentials, api_keys, remediation_callback, integration_metrics, bulk_operations, compliance, rule_scanning, capabilities +from .routes import ( + credentials, + api_keys, + remediation_callback, + integration_metrics, + bulk_operations, + compliance, + rule_scanning, + capabilities, +) + # Import security routes only if available try: from .routes import automated_fixes @@ -38,12 +62,12 @@ from .middleware.metrics import PrometheusMiddleware, background_updater from .middleware.rate_limiting import get_rate_limiting_middleware from .services.prometheus_metrics import get_metrics_instance + # from .services.tracing import initialize_tracing, instrument_fastapi_app, instrument_database_engine # Disabled for now # Configure logging logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) logger = logging.getLogger(__name__) @@ -55,61 +79,68 @@ async def lifespan(app: FastAPI): """Application lifespan management""" # Startup logger.info("Starting OpenWatch application...") - + # Verify FIPS mode if required if settings.fips_mode: try: from security.config.fips_config import FIPSConfig + if not FIPSConfig.validate_fips_mode(): logger.warning("FIPS mode is not enabled in the system") else: logger.info("FIPS mode validated successfully") except ImportError: logger.warning("FIPS configuration module not found - using development mode") - + # Create database tables with retry logic (skip in development if fails) max_retries = 3 retry_delay = 5 - + for attempt in range(max_retries): try: await create_tables() logger.info("Database tables created successfully") - + # Initialize RBAC system try: from .init_roles import initialize_rbac_system + await initialize_rbac_system() logger.info("RBAC system initialized successfully") except Exception as rbac_error: logger.warning(f"RBAC initialization failed: {rbac_error}") if not settings.debug: raise - + # Initialize scheduler state from database try: from .routes.system_settings_unified import restore_scheduler_state + await restore_scheduler_state() logger.info("Scheduler state restored from database") except Exception as scheduler_error: logger.warning(f"Scheduler restoration failed: {scheduler_error}") # Don't raise - scheduler can be started manually from UI - + break except Exception as e: if attempt < max_retries - 1: - logger.warning(f"Database connection attempt {attempt + 1} failed: {e}. Retrying in {retry_delay} seconds...") + logger.warning( + f"Database connection attempt {attempt + 1} failed: {e}. Retrying in {retry_delay} seconds..." + ) await asyncio.sleep(retry_delay) else: if settings.debug: - logger.warning(f"Database connection failed in debug mode, continuing without DB: {e}") + logger.warning( + f"Database connection failed in debug mode, continuing without DB: {e}" + ) else: logger.error(f"Failed to connect to database after {max_retries} attempts: {e}") raise - + # Initialize JWT keys logger.info("JWT manager initialized with RSA keys") - + # Initialize distributed tracing (disabled for now) # try: # tracing_success = initialize_tracing( @@ -125,19 +156,20 @@ async def lifespan(app: FastAPI): # except Exception as e: # logger.warning(f"Failed to initialize distributed tracing: {e}") logger.info("Distributed tracing disabled for initial deployment") - + # Start background metrics collection try: import asyncio + asyncio.create_task(background_updater.start_background_updates()) logger.info("Background metrics collection started") except Exception as e: logger.warning(f"Failed to start background metrics collection: {e}") - + logger.info("OpenWatch application started successfully") - + yield - + # Shutdown logger.info("Shutting down OpenWatch application...") background_updater.stop_background_updates() @@ -151,7 +183,7 @@ async def lifespan(app: FastAPI): version="1.2.0", docs_url="/docs" if settings.debug else None, redoc_url="/redoc" if settings.debug else None, - lifespan=lifespan + lifespan=lifespan, ) @@ -161,12 +193,13 @@ async def lifespan(app: FastAPI): rate_limiter = get_rate_limiting_middleware() app.middleware("http")(rate_limiter) + # Security Middleware @app.middleware("http") async def security_headers_middleware(request: Request, call_next): """Add FIPS-compliant security headers to all responses""" response = await call_next(request) - + # Add security headers with development modifications for header, value in SECURITY_HEADERS.items(): if header == "Content-Security-Policy" and settings.debug: @@ -184,7 +217,7 @@ async def security_headers_middleware(request: Request, call_next): response.headers[header] = dev_csp else: response.headers[header] = value - + return response @@ -192,97 +225,97 @@ async def security_headers_middleware(request: Request, call_next): async def audit_middleware(request: Request, call_next): """Log security-relevant requests for audit purposes""" start_time = time.time() - + # Get client IP client_ip = request.client.host if "x-forwarded-for" in request.headers: client_ip = request.headers["x-forwarded-for"].split(",")[0].strip() - + # Process request response = await call_next(request) - + # Log security events (only for non-auth endpoints to avoid double logging) process_time = time.time() - start_time - + # Get database session for audit logging db = next(get_db()) - + try: # Log scan operations if request.url.path.startswith("/api/scans"): audit_logger.log_security_event( "SCAN_OPERATION", f"Path: {request.url.path}, Method: {request.method}, Status: {response.status_code}", - client_ip + client_ip, ) await log_security_event( db=db, event_type="SCAN_OPERATION", ip_address=client_ip, - details=f"Path: {request.url.path}, Method: {request.method}, Status: {response.status_code}" + details=f"Path: {request.url.path}, Method: {request.method}, Status: {response.status_code}", ) - + # Log host operations elif request.url.path.startswith("/api/hosts"): audit_logger.log_security_event( "HOST_OPERATION", f"Path: {request.url.path}, Method: {request.method}, Status: {response.status_code}", - client_ip + client_ip, ) await log_security_event( db=db, event_type="HOST_OPERATION", ip_address=client_ip, - details=f"Path: {request.url.path}, Method: {request.method}, Status: {response.status_code}" + details=f"Path: {request.url.path}, Method: {request.method}, Status: {response.status_code}", ) - + # Log user management operations elif request.url.path.startswith("/api/users"): audit_logger.log_security_event( "USER_OPERATION", f"Path: {request.url.path}, Method: {request.method}, Status: {response.status_code}", - client_ip + client_ip, ) await log_security_event( db=db, event_type="USER_OPERATION", ip_address=client_ip, - details=f"Path: {request.url.path}, Method: {request.method}, Status: {response.status_code}" + details=f"Path: {request.url.path}, Method: {request.method}, Status: {response.status_code}", ) - + # Log webhook operations elif request.url.path.startswith("/api/v1/webhooks"): audit_logger.log_security_event( "WEBHOOK_OPERATION", f"Path: {request.url.path}, Method: {request.method}, Status: {response.status_code}", - client_ip + client_ip, ) await log_security_event( db=db, event_type="WEBHOOK_OPERATION", ip_address=client_ip, - details=f"Path: {request.url.path}, Method: {request.method}, Status: {response.status_code}" + details=f"Path: {request.url.path}, Method: {request.method}, Status: {response.status_code}", ) - + # Log unusual status codes (HTTP errors) if response.status_code >= 400: audit_logger.log_security_event( "HTTP_ERROR", f"Path: {request.url.path}, Method: {request.method}, Status: {response.status_code}", - client_ip + client_ip, ) await log_security_event( db=db, event_type="HTTP_ERROR", ip_address=client_ip, - details=f"HTTP {response.status_code} error on {request.url.path}" + details=f"HTTP {response.status_code} error on {request.url.path}", ) - + except Exception as e: logger.error(f"Error in audit middleware: {e}") finally: db.close() - + return response @@ -293,10 +326,9 @@ async def https_redirect_middleware(request: Request, call_next): if request.url.scheme != "https": https_url = request.url.replace(scheme="https") return JSONResponse( - status_code=status.HTTP_301_MOVED_PERMANENTLY, - headers={"Location": str(https_url)} + status_code=status.HTTP_301_MOVED_PERMANENTLY, headers={"Location": str(https_url)} ) - + return await call_next(request) @@ -312,7 +344,7 @@ async def https_redirect_middleware(request: Request, call_next): allow_credentials=True, allow_methods=["GET", "POST", "PUT", "DELETE"], allow_headers=["Authorization", "Content-Type"], - expose_headers=["X-Total-Count"] + expose_headers=["X-Total-Count"], ) # Trusted Host Middleware @@ -324,10 +356,7 @@ async def https_redirect_middleware(request: Request, call_next): host = origin.replace("https://", "").split(":")[0] trusted_hosts.append(host) -app.add_middleware( - TrustedHostMiddleware, - allowed_hosts=trusted_hosts -) +app.add_middleware(TrustedHostMiddleware, allowed_hosts=trusted_hosts) # Add Prometheus metrics middleware app.add_middleware(PrometheusMiddleware, service_name="openwatch") @@ -351,48 +380,45 @@ async def health_check(): "status": "healthy", "timestamp": time.time(), "version": "1.2.0", - "fips_mode": settings.fips_mode + "fips_mode": settings.fips_mode, } - + # Check database connectivity try: from .database import check_database_health + db_healthy = await check_database_health() health_status["database"] = "healthy" if db_healthy else "unhealthy" except Exception as e: logger.warning(f"Database health check failed: {e}") health_status["database"] = "unknown" db_healthy = True # Continue for development - + # Check Redis connectivity try: from .celery_app import check_redis_health + redis_healthy = await check_redis_health() health_status["redis"] = "healthy" if redis_healthy else "unhealthy" except Exception as e: logger.warning(f"Redis health check failed: {e}") health_status["redis"] = "unknown" redis_healthy = True # Continue for development - + # Overall status if not (db_healthy and redis_healthy): health_status["status"] = "degraded" return JSONResponse( - status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - content=health_status + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, content=health_status ) - + return health_status - + except Exception as e: logger.error(f"Health check failed: {e}") return JSONResponse( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - content={ - "status": "unhealthy", - "error": str(e), - "timestamp": time.time() - } + content={"status": "unhealthy", "error": str(e), "timestamp": time.time()}, ) @@ -406,7 +432,7 @@ async def security_info(): "jwt_algorithm": "RS256", "encryption": "AES-256-GCM", "hash_algorithm": "Argon2id", - "tls_version": "1.3" + "tls_version": "1.3", } @@ -415,13 +441,12 @@ async def security_info(): async def metrics(): """Prometheus metrics endpoint""" from fastapi.responses import PlainTextResponse - + metrics_instance = get_metrics_instance() metrics_data = metrics_instance.get_metrics() - + return PlainTextResponse( - content=metrics_data, - media_type="text/plain; version=0.0.4; charset=utf-8" + content=metrics_data, media_type="text/plain; version=0.0.4; charset=utf-8" ) @@ -446,7 +471,9 @@ async def metrics(): app.include_router(credentials.router, tags=["Credential Sharing"]) app.include_router(api_keys.router, prefix="/api/api-keys", tags=["API Keys"]) app.include_router(remediation_callback.router, tags=["AEGIS Integration"]) -app.include_router(integration_metrics.router, prefix="/api/integration/metrics", tags=["Integration Metrics"]) +app.include_router( + integration_metrics.router, prefix="/api/integration/metrics", tags=["Integration Metrics"] +) app.include_router(bulk_operations.router, prefix="/api/bulk", tags=["Bulk Operations"]) # app.include_router(terminal.router, tags=["Terminal"]) # Terminal module not available app.include_router(compliance.router, prefix="/api/compliance", tags=["Compliance Intelligence"]) @@ -456,7 +483,7 @@ async def metrics(): if automated_fixes: app.include_router(automated_fixes.router, tags=["Secure Automated Fixes"]) if authorization: - app.include_router(authorization.router, tags=["Authorization Management"]) + app.include_router(authorization.router, tags=["Authorization Management"]) if security_config: app.include_router(security_config.router, tags=["Security Configuration"]) @@ -468,24 +495,19 @@ async def global_exception_handler(request: Request, exc: Exception): client_ip = request.client.host if "x-forwarded-for" in request.headers: client_ip = request.headers["x-forwarded-for"].split(",")[0].strip() - + # Log the exception logger.error(f"Unhandled exception: {exc}", exc_info=True) - + # Log security event audit_logger.log_security_event( - "EXCEPTION", - f"Path: {request.url.path}, Exception: {type(exc).__name__}", - client_ip + "EXCEPTION", f"Path: {request.url.path}, Exception: {type(exc).__name__}", client_ip ) - + # Return generic error response (don't expose internal details) return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content={ - "detail": "Internal server error", - "error_id": f"{int(time.time())}" - } + content={"detail": "Internal server error", "error_id": f"{int(time.time())}"}, ) @@ -498,5 +520,5 @@ async def global_exception_handler(request: Request, exc: Exception): ssl_keyfile=settings.tls_key_file if settings.require_https else None, ssl_certfile=settings.tls_cert_file if settings.require_https else None, log_level=settings.log_level.lower(), - reload=settings.debug - ) \ No newline at end of file + reload=settings.debug, + ) diff --git a/backend/app/middleware/authorization_middleware.py b/backend/app/middleware/authorization_middleware.py index 176a4094..45b42ff7 100644 --- a/backend/app/middleware/authorization_middleware.py +++ b/backend/app/middleware/authorization_middleware.py @@ -9,13 +9,14 @@ ZERO TRUST IMPLEMENTATION: - All requests validated before processing -- No implicit trust or bypass mechanisms +- No implicit trust or bypass mechanisms - Comprehensive audit trail - Fail-secure behavior on errors - Resource-level permission validation Design by Emily (Security Engineer) & Implementation by Daniel (Backend Engineer) """ + import logging import time import json @@ -28,8 +29,12 @@ from ..services.authorization_service import AuthorizationService, get_authorization_service from ..models.authorization_models import ( - ResourceType, ActionType, ResourceIdentifier, AuthorizationContext, - AuthorizationDecision, BulkAuthorizationRequest + ResourceType, + ActionType, + ResourceIdentifier, + AuthorizationContext, + AuthorizationDecision, + BulkAuthorizationRequest, ) from ..auth import get_current_user from ..database import get_db @@ -40,7 +45,7 @@ class AuthorizationMiddleware(BaseHTTPMiddleware): """ Authorization middleware that validates all requests against resource permissions - + SECURITY FEATURES: 1. Request Interception - Validates all scan-related operations 2. Resource Identification - Extracts resource information from requests @@ -48,312 +53,401 @@ class AuthorizationMiddleware(BaseHTTPMiddleware): 4. Audit Logging - Records all authorization decisions 5. Fail-Secure - Denies access on any error or uncertainty """ - + def __init__(self, app, authorization_service_factory: Callable = None): super().__init__(app) - self.authorization_service_factory = authorization_service_factory or get_authorization_service - + self.authorization_service_factory = ( + authorization_service_factory or get_authorization_service + ) + # Define which endpoints require authorization and what resource/action they map to self.protected_endpoints = { # Scan operations - 'POST /api/v1/scans': {'resource_type': ResourceType.HOST, 'action': ActionType.SCAN, 'bulk': False}, - 'PUT /api/v1/scans/{scan_id}': {'resource_type': ResourceType.SCAN, 'action': ActionType.WRITE, 'bulk': False}, - 'DELETE /api/v1/scans/{scan_id}': {'resource_type': ResourceType.SCAN, 'action': ActionType.DELETE, 'bulk': False}, - 'GET /api/v1/scans/{scan_id}': {'resource_type': ResourceType.SCAN, 'action': ActionType.READ, 'bulk': False}, - 'POST /api/v1/scans/{scan_id}/execute': {'resource_type': ResourceType.SCAN, 'action': ActionType.EXECUTE, 'bulk': False}, - + "POST /api/v1/scans": { + "resource_type": ResourceType.HOST, + "action": ActionType.SCAN, + "bulk": False, + }, + "PUT /api/v1/scans/{scan_id}": { + "resource_type": ResourceType.SCAN, + "action": ActionType.WRITE, + "bulk": False, + }, + "DELETE /api/v1/scans/{scan_id}": { + "resource_type": ResourceType.SCAN, + "action": ActionType.DELETE, + "bulk": False, + }, + "GET /api/v1/scans/{scan_id}": { + "resource_type": ResourceType.SCAN, + "action": ActionType.READ, + "bulk": False, + }, + "POST /api/v1/scans/{scan_id}/execute": { + "resource_type": ResourceType.SCAN, + "action": ActionType.EXECUTE, + "bulk": False, + }, # Bulk scan operations - CRITICAL VULNERABILITY PREVENTION - 'POST /api/v1/bulk-scans': {'resource_type': ResourceType.HOST, 'action': ActionType.SCAN, 'bulk': True}, - 'POST /api/v1/bulk-scans/{session_id}/start': {'resource_type': ResourceType.HOST, 'action': ActionType.EXECUTE, 'bulk': True}, - 'DELETE /api/v1/bulk-scans/{session_id}': {'resource_type': ResourceType.HOST, 'action': ActionType.DELETE, 'bulk': True}, - + "POST /api/v1/bulk-scans": { + "resource_type": ResourceType.HOST, + "action": ActionType.SCAN, + "bulk": True, + }, + "POST /api/v1/bulk-scans/{session_id}/start": { + "resource_type": ResourceType.HOST, + "action": ActionType.EXECUTE, + "bulk": True, + }, + "DELETE /api/v1/bulk-scans/{session_id}": { + "resource_type": ResourceType.HOST, + "action": ActionType.DELETE, + "bulk": True, + }, # Host operations - 'GET /api/v1/hosts/{host_id}': {'resource_type': ResourceType.HOST, 'action': ActionType.READ, 'bulk': False}, - 'PUT /api/v1/hosts/{host_id}': {'resource_type': ResourceType.HOST, 'action': ActionType.WRITE, 'bulk': False}, - 'DELETE /api/v1/hosts/{host_id}': {'resource_type': ResourceType.HOST, 'action': ActionType.DELETE, 'bulk': False}, - 'POST /api/v1/hosts/{host_id}/scan': {'resource_type': ResourceType.HOST, 'action': ActionType.SCAN, 'bulk': False}, - + "GET /api/v1/hosts/{host_id}": { + "resource_type": ResourceType.HOST, + "action": ActionType.READ, + "bulk": False, + }, + "PUT /api/v1/hosts/{host_id}": { + "resource_type": ResourceType.HOST, + "action": ActionType.WRITE, + "bulk": False, + }, + "DELETE /api/v1/hosts/{host_id}": { + "resource_type": ResourceType.HOST, + "action": ActionType.DELETE, + "bulk": False, + }, + "POST /api/v1/hosts/{host_id}/scan": { + "resource_type": ResourceType.HOST, + "action": ActionType.SCAN, + "bulk": False, + }, # Host group operations - 'GET /api/v1/host-groups/{group_id}': {'resource_type': ResourceType.HOST_GROUP, 'action': ActionType.READ, 'bulk': False}, - 'PUT /api/v1/host-groups/{group_id}': {'resource_type': ResourceType.HOST_GROUP, 'action': ActionType.WRITE, 'bulk': False}, - 'DELETE /api/v1/host-groups/{group_id}': {'resource_type': ResourceType.HOST_GROUP, 'action': ActionType.DELETE, 'bulk': False}, - 'POST /api/v1/host-groups/{group_id}/scan': {'resource_type': ResourceType.HOST_GROUP, 'action': ActionType.SCAN, 'bulk': True}, - + "GET /api/v1/host-groups/{group_id}": { + "resource_type": ResourceType.HOST_GROUP, + "action": ActionType.READ, + "bulk": False, + }, + "PUT /api/v1/host-groups/{group_id}": { + "resource_type": ResourceType.HOST_GROUP, + "action": ActionType.WRITE, + "bulk": False, + }, + "DELETE /api/v1/host-groups/{group_id}": { + "resource_type": ResourceType.HOST_GROUP, + "action": ActionType.DELETE, + "bulk": False, + }, + "POST /api/v1/host-groups/{group_id}/scan": { + "resource_type": ResourceType.HOST_GROUP, + "action": ActionType.SCAN, + "bulk": True, + }, # Rule scanning operations - 'POST /api/v1/rule-scan': {'resource_type': ResourceType.HOST, 'action': ActionType.SCAN, 'bulk': True}, - + "POST /api/v1/rule-scan": { + "resource_type": ResourceType.HOST, + "action": ActionType.SCAN, + "bulk": True, + }, # Remediation operations - 'POST /api/v1/scans/{scan_id}/remediate': {'resource_type': ResourceType.SCAN, 'action': ActionType.EXECUTE, 'bulk': False}, - 'POST /api/v1/bulk-remediate': {'resource_type': ResourceType.HOST, 'action': ActionType.EXECUTE, 'bulk': True}, + "POST /api/v1/scans/{scan_id}/remediate": { + "resource_type": ResourceType.SCAN, + "action": ActionType.EXECUTE, + "bulk": False, + }, + "POST /api/v1/bulk-remediate": { + "resource_type": ResourceType.HOST, + "action": ActionType.EXECUTE, + "bulk": True, + }, } - - logger.info(f"Authorization middleware initialized with {len(self.protected_endpoints)} protected endpoints") - - async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response: + + logger.info( + f"Authorization middleware initialized with {len(self.protected_endpoints)} protected endpoints" + ) + + async def dispatch( + self, request: Request, call_next: Callable[[Request], Awaitable[Response]] + ) -> Response: """ Main middleware dispatch method - validates authorization for protected endpoints """ start_time = time.time() - + try: # Check if this endpoint requires authorization endpoint_pattern = self._get_endpoint_pattern(request) if not endpoint_pattern or endpoint_pattern not in self.protected_endpoints: # Not a protected endpoint, pass through return await call_next(request) - + logger.debug(f"Authorizing request: {request.method} {request.url.path}") - + # Get current user from token current_user = await self._extract_current_user(request) if not current_user: - logger.warning(f"Authorization failed: No authenticated user for {request.method} {request.url.path}") + logger.warning( + f"Authorization failed: No authenticated user for {request.method} {request.url.path}" + ) return self._create_error_response( - status.HTTP_401_UNAUTHORIZED, - "Authentication required", - request.url.path + status.HTTP_401_UNAUTHORIZED, "Authentication required", request.url.path ) - + # Get endpoint configuration endpoint_config = self.protected_endpoints[endpoint_pattern] - + # Extract resources from request resources = await self._extract_resources(request, endpoint_config, current_user) if not resources: - logger.warning(f"Authorization failed: Could not extract resources from request {request.method} {request.url.path}") + logger.warning( + f"Authorization failed: Could not extract resources from request {request.method} {request.url.path}" + ) return self._create_error_response( status.HTTP_400_BAD_REQUEST, "Could not determine resources for authorization", - request.url.path + request.url.path, ) - + # Create authorization context auth_context = await self._build_authorization_context(request, current_user) - + # Perform authorization check authorization_result = await self._perform_authorization_check( - current_user['id'], + current_user["id"], resources, - endpoint_config['action'], - endpoint_config['bulk'], - auth_context + endpoint_config["action"], + endpoint_config["bulk"], + auth_context, ) - + if authorization_result.overall_decision != AuthorizationDecision.ALLOW: - logger.warning(f"Authorization denied for user {current_user['id']} on {request.method} {request.url.path}: " - f"{len(authorization_result.denied_resources)} resources denied") - return self._create_authorization_error_response(authorization_result, request.url.path) - + logger.warning( + f"Authorization denied for user {current_user['id']} on {request.method} {request.url.path}: " + f"{len(authorization_result.denied_resources)} resources denied" + ) + return self._create_authorization_error_response( + authorization_result, request.url.path + ) + # Authorization successful - add context to request for downstream use request.state.authorization_result = authorization_result request.state.current_user = current_user - + # Process the request response = await call_next(request) - + # Log successful authorization processing_time = int((time.time() - start_time) * 1000) - logger.info(f"Authorization successful for user {current_user['id']} on {request.method} {request.url.path} " - f"({len(authorization_result.allowed_resources)} resources, {processing_time}ms)") - + logger.info( + f"Authorization successful for user {current_user['id']} on {request.method} {request.url.path} " + f"({len(authorization_result.allowed_resources)} resources, {processing_time}ms)" + ) + return response - + except Exception as e: logger.error(f"Authorization middleware error: {e}") - + # Fail securely - deny access on any error return self._create_error_response( status.HTTP_500_INTERNAL_SERVER_ERROR, "Authorization system error", - request.url.path if hasattr(request, 'url') else "unknown" + request.url.path if hasattr(request, "url") else "unknown", ) - + def _get_endpoint_pattern(self, request: Request) -> Optional[str]: """ Match request path to endpoint pattern for authorization """ method = request.method path = request.url.path - + # Direct match first direct_pattern = f"{method} {path}" if direct_pattern in self.protected_endpoints: return direct_pattern - + # Pattern matching with path parameters for pattern in self.protected_endpoints.keys(): if self._match_pattern(f"{method} {path}", pattern): return pattern - + return None - + def _match_pattern(self, request_path: str, pattern: str) -> bool: """ Match request path against pattern with path parameters """ - request_parts = request_path.split('/') - pattern_parts = pattern.split('/') - + request_parts = request_path.split("/") + pattern_parts = pattern.split("/") + if len(request_parts) != len(pattern_parts): return False - + for req_part, pat_part in zip(request_parts, pattern_parts): - if pat_part.startswith('{') and pat_part.endswith('}'): + if pat_part.startswith("{") and pat_part.endswith("}"): # Path parameter - matches anything continue elif req_part != pat_part: return False - + return True - - async def _extract_current_user(self, request: Request) -> Optional[Dict[str, Any]]: + + def _extract_current_user(self, request: Request) -> Optional[Dict[str, Any]]: """ Extract current user from request authentication """ try: # Check for Authorization header - auth_header = request.headers.get('Authorization') - if not auth_header or not auth_header.startswith('Bearer '): + auth_header = request.headers.get("Authorization") + if not auth_header or not auth_header.startswith("Bearer "): return None - - token = auth_header.split(' ', 1)[1] - + + token = auth_header.split(" ", 1)[1] + # Use existing auth system to validate token # This would integrate with your JWT validation logic from ..auth import decode_token - + payload = decode_token(token) if not payload: return None - + # Get user details from database db = next(get_db()) try: from sqlalchemy import text - result = db.execute(text(""" + + result = db.execute( + text( + """ SELECT id, username, email, role, is_active FROM users WHERE id = :user_id AND is_active = true - """), {"user_id": payload.get('sub')}) - + """ + ), + {"user_id": payload.get("sub")}, + ) + user_row = result.fetchone() if not user_row: return None - + return { - 'id': str(user_row.id), - 'username': user_row.username, - 'email': user_row.email, - 'role': user_row.role, - 'is_active': user_row.is_active + "id": str(user_row.id), + "username": user_row.username, + "email": user_row.email, + "role": user_row.role, + "is_active": user_row.is_active, } finally: db.close() - + except Exception as e: logger.error(f"Error extracting current user: {e}") return None - + async def _extract_resources( - self, - request: Request, - endpoint_config: Dict[str, Any], - current_user: Dict[str, Any] + self, request: Request, endpoint_config: Dict[str, Any], current_user: Dict[str, Any] ) -> List[ResourceIdentifier]: """ Extract resource identifiers from request based on endpoint configuration """ try: resources = [] - resource_type = endpoint_config['resource_type'] - is_bulk = endpoint_config['bulk'] - + resource_type = endpoint_config["resource_type"] + is_bulk = endpoint_config["bulk"] + # Extract path parameters path_params = self._extract_path_params(request) - + if not is_bulk: # Single resource operation if resource_type == ResourceType.HOST: - host_id = path_params.get('host_id') + host_id = path_params.get("host_id") if host_id: - resources.append(ResourceIdentifier( - resource_type=ResourceType.HOST, - resource_id=host_id - )) + resources.append( + ResourceIdentifier(resource_type=ResourceType.HOST, resource_id=host_id) + ) else: # Check request body for host_id body_host_id = await self._extract_host_id_from_body(request) if body_host_id: - resources.append(ResourceIdentifier( - resource_type=ResourceType.HOST, - resource_id=body_host_id - )) - + resources.append( + ResourceIdentifier( + resource_type=ResourceType.HOST, resource_id=body_host_id + ) + ) + elif resource_type == ResourceType.SCAN: - scan_id = path_params.get('scan_id') + scan_id = path_params.get("scan_id") if scan_id: # For scan operations, we need to get the associated host host_id = await self._get_host_id_from_scan_id(scan_id) if host_id: - resources.append(ResourceIdentifier( - resource_type=ResourceType.HOST, - resource_id=host_id - )) - + resources.append( + ResourceIdentifier( + resource_type=ResourceType.HOST, resource_id=host_id + ) + ) + elif resource_type == ResourceType.HOST_GROUP: - group_id = path_params.get('group_id') + group_id = path_params.get("group_id") if group_id: # For host group operations, get all hosts in the group host_ids = await self._get_host_ids_from_group_id(group_id) for host_id in host_ids: - resources.append(ResourceIdentifier( - resource_type=ResourceType.HOST, - resource_id=host_id - )) - + resources.append( + ResourceIdentifier( + resource_type=ResourceType.HOST, resource_id=host_id + ) + ) + else: # Bulk operation - extract multiple hosts host_ids = await self._extract_bulk_host_ids(request, endpoint_config) for host_id in host_ids: - resources.append(ResourceIdentifier( - resource_type=ResourceType.HOST, - resource_id=host_id - )) - + resources.append( + ResourceIdentifier(resource_type=ResourceType.HOST, resource_id=host_id) + ) + logger.debug(f"Extracted {len(resources)} resources from request") return resources - + except Exception as e: logger.error(f"Error extracting resources: {e}") return [] - + def _extract_path_params(self, request: Request) -> Dict[str, str]: """ Extract path parameters from request URL """ path_params = {} - + # Get path parameters from FastAPI's path_params if available - if hasattr(request, 'path_params'): + if hasattr(request, "path_params"): path_params.update(request.path_params) - + return path_params - + async def _extract_host_id_from_body(self, request: Request) -> Optional[str]: """ Extract host_id from request body """ try: - if request.headers.get('content-type', '').startswith('application/json'): + if request.headers.get("content-type", "").startswith("application/json"): body = await request.body() if body: data = json.loads(body) - return data.get('host_id') + return data.get("host_id") except Exception as e: logger.error(f"Error extracting host_id from body: {e}") - + return None - - async def _get_host_id_from_scan_id(self, scan_id: str) -> Optional[str]: + + def _get_host_id_from_scan_id(self, scan_id: str) -> Optional[str]: """ Get host_id associated with a scan_id """ @@ -361,10 +455,16 @@ async def _get_host_id_from_scan_id(self, scan_id: str) -> Optional[str]: db = next(get_db()) try: from sqlalchemy import text - result = db.execute(text(""" + + result = db.execute( + text( + """ SELECT host_id FROM scans WHERE id = :scan_id - """), {"scan_id": scan_id}) - + """ + ), + {"scan_id": scan_id}, + ) + row = result.fetchone() return str(row.host_id) if row else None finally: @@ -372,8 +472,8 @@ async def _get_host_id_from_scan_id(self, scan_id: str) -> Optional[str]: except Exception as e: logger.error(f"Error getting host_id from scan_id {scan_id}: {e}") return None - - async def _get_host_ids_from_group_id(self, group_id: str) -> List[str]: + + def _get_host_ids_from_group_id(self, group_id: str) -> List[str]: """ Get all host_ids in a host group """ @@ -381,53 +481,55 @@ async def _get_host_ids_from_group_id(self, group_id: str) -> List[str]: db = next(get_db()) try: from sqlalchemy import text - result = db.execute(text(""" + + result = db.execute( + text( + """ SELECT hgm.host_id FROM host_group_memberships hgm WHERE hgm.group_id = :group_id - """), {"group_id": group_id}) - + """ + ), + {"group_id": group_id}, + ) + return [str(row.host_id) for row in result] finally: db.close() except Exception as e: logger.error(f"Error getting host_ids from group_id {group_id}: {e}") return [] - + async def _extract_bulk_host_ids( - self, - request: Request, - endpoint_config: Dict[str, Any] + self, request: Request, endpoint_config: Dict[str, Any] ) -> List[str]: """ Extract host IDs for bulk operations from request body """ try: - if request.headers.get('content-type', '').startswith('application/json'): + if request.headers.get("content-type", "").startswith("application/json"): body = await request.body() if body: data = json.loads(body) - + # Different bulk operations have different request structures - if 'host_ids' in data: - return data['host_ids'] - elif 'hosts' in data: - return [host.get('id') for host in data['hosts'] if host.get('id')] - elif 'host_id' in data: + if "host_ids" in data: + return data["host_ids"] + elif "hosts" in data: + return [host.get("id") for host in data["hosts"] if host.get("id")] + elif "host_id" in data: # Single host in bulk format - return [data['host_id']] - elif 'target_hosts' in data: - return data['target_hosts'] - + return [data["host_id"]] + elif "target_hosts" in data: + return data["target_hosts"] + except Exception as e: logger.error(f"Error extracting bulk host IDs: {e}") - + return [] - - async def _build_authorization_context( - self, - request: Request, - current_user: Dict[str, Any] + + def _build_authorization_context( + self, request: Request, current_user: Dict[str, Any] ) -> AuthorizationContext: """ Build authorization context from request and user information @@ -437,7 +539,10 @@ async def _build_authorization_context( db = next(get_db()) try: from sqlalchemy import text - result = db.execute(text(""" + + result = db.execute( + text( + """ SELECT COALESCE( JSON_AGG(DISTINCT ug.name) FILTER (WHERE ug.name IS NOT NULL), '[]'::json @@ -447,56 +552,59 @@ async def _build_authorization_context( LEFT JOIN user_groups ug ON ugm.group_id = ug.id WHERE u.id = :user_id GROUP BY u.id - """), {"user_id": current_user['id']}) - + """ + ), + {"user_id": current_user["id"]}, + ) + row = result.fetchone() user_groups = json.loads(row.user_groups) if row and row.user_groups else [] finally: db.close() - + return AuthorizationContext( - user_id=current_user['id'], - user_roles=[current_user['role']] if current_user.get('role') else [], + user_id=current_user["id"], + user_roles=[current_user["role"]] if current_user.get("role") else [], user_groups=user_groups, ip_address=self._get_client_ip(request), - user_agent=request.headers.get('user-agent'), - session_id=request.headers.get('x-session-id') + user_agent=request.headers.get("user-agent"), + session_id=request.headers.get("x-session-id"), ) - + except Exception as e: logger.error(f"Error building authorization context: {e}") return AuthorizationContext( - user_id=current_user['id'], - user_roles=[current_user.get('role', 'guest')], - user_groups=[] + user_id=current_user["id"], + user_roles=[current_user.get("role", "guest")], + user_groups=[], ) - + def _get_client_ip(self, request: Request) -> str: """ Get client IP address from request """ # Check for forwarded headers first (behind proxy) - forwarded_for = request.headers.get('x-forwarded-for') + forwarded_for = request.headers.get("x-forwarded-for") if forwarded_for: - return forwarded_for.split(',')[0].strip() - - real_ip = request.headers.get('x-real-ip') + return forwarded_for.split(",")[0].strip() + + real_ip = request.headers.get("x-real-ip") if real_ip: return real_ip - + # Fallback to client IP - if hasattr(request, 'client') and request.client: + if hasattr(request, "client") and request.client: return request.client.host - + return "unknown" - + async def _perform_authorization_check( self, user_id: str, resources: List[ResourceIdentifier], action: ActionType, is_bulk: bool, - context: AuthorizationContext + context: AuthorizationContext, ): """ Perform the actual authorization check using the authorization service @@ -505,25 +613,34 @@ async def _perform_authorization_check( db = next(get_db()) try: auth_service = self.authorization_service_factory(db) - + if len(resources) == 1 and not is_bulk: # Single resource check result = await auth_service.check_permission( user_id, resources[0], action, context ) - + # Convert single result to bulk result format from ..models.authorization_models import BulkAuthorizationResult + return BulkAuthorizationResult( overall_decision=result.decision, individual_results=[result], - denied_resources=[result.resource] if result.decision == AuthorizationDecision.DENY else [], - allowed_resources=[result.resource] if result.decision == AuthorizationDecision.ALLOW else [], + denied_resources=( + [result.resource] + if result.decision == AuthorizationDecision.DENY + else [] + ), + allowed_resources=( + [result.resource] + if result.decision == AuthorizationDecision.ALLOW + else [] + ), total_evaluation_time_ms=result.evaluation_time_ms, cached_results=1 if result.cached else 0, - fresh_evaluations=0 if result.cached else 1 + fresh_evaluations=0 if result.cached else 1, ) - + else: # Bulk authorization check - CRITICAL SECURITY IMPLEMENTATION bulk_request = BulkAuthorizationRequest( @@ -532,19 +649,20 @@ async def _perform_authorization_check( action=action, context=context, fail_fast=True, # Stop on first denial for security - parallel_evaluation=True # Enable parallel processing for performance + parallel_evaluation=True, # Enable parallel processing for performance ) - + return await auth_service.check_bulk_permissions(bulk_request) - + finally: db.close() - + except Exception as e: logger.error(f"Authorization check failed: {e}") - + # Fail securely from ..models.authorization_models import BulkAuthorizationResult + return BulkAuthorizationResult( overall_decision=AuthorizationDecision.DENY, individual_results=[], @@ -552,9 +670,9 @@ async def _perform_authorization_check( allowed_resources=[], total_evaluation_time_ms=0, cached_results=0, - fresh_evaluations=0 + fresh_evaluations=0, ) - + def _create_error_response(self, status_code: int, message: str, path: str) -> JSONResponse: """ Create standardized error response @@ -565,10 +683,10 @@ def _create_error_response(self, status_code: int, message: str, path: str) -> J "error": message, "path": path, "timestamp": datetime.utcnow().isoformat(), - "type": "authorization_error" - } + "type": "authorization_error", + }, ) - + def _create_authorization_error_response(self, auth_result, path: str) -> JSONResponse: """ Create detailed authorization error response @@ -578,13 +696,17 @@ def _create_authorization_error_response(self, auth_result, path: str) -> JSONRe "resource_type": res.resource_type.value, "resource_id": res.resource_id, "reason": next( - (r.reason for r in auth_result.individual_results if r.resource.resource_id == res.resource_id), - "Access denied" - ) + ( + r.reason + for r in auth_result.individual_results + if r.resource.resource_id == res.resource_id + ), + "Access denied", + ), } for res in auth_result.denied_resources ] - + return JSONResponse( status_code=status.HTTP_403_FORBIDDEN, content={ @@ -593,8 +715,8 @@ def _create_authorization_error_response(self, auth_result, path: str) -> JSONRe "denied_resources": denied_resources, "path": path, "timestamp": datetime.utcnow().isoformat(), - "type": "authorization_denied" - } + "type": "authorization_denied", + }, ) @@ -603,4 +725,4 @@ def create_authorization_middleware(app, authorization_service_factory: Callable """ Factory function to create authorization middleware instance """ - return AuthorizationMiddleware(app, authorization_service_factory) \ No newline at end of file + return AuthorizationMiddleware(app, authorization_service_factory) diff --git a/backend/app/middleware/metrics.py b/backend/app/middleware/metrics.py index 8529c1ef..129115ad 100644 --- a/backend/app/middleware/metrics.py +++ b/backend/app/middleware/metrics.py @@ -3,6 +3,7 @@ Automatic metrics collection for HTTP requests and application events Author: Noah Chen - nc9010@hanalyx.com """ + import time import logging from fastapi import Request, Response @@ -17,47 +18,47 @@ class PrometheusMiddleware(BaseHTTPMiddleware): """ Middleware to automatically collect Prometheus metrics for HTTP requests """ - + def __init__(self, app, service_name: str = "openwatch"): super().__init__(app) self.service_name = service_name self.metrics = get_metrics_instance() - + # Set service as up self.metrics.set_service_up(service_name, True) - + async def dispatch(self, request: Request, call_next: Callable) -> Response: """Process request and collect metrics""" start_time = time.time() - + # Extract path and method path = request.url.path method = request.method - + # Normalize endpoint for metrics (remove dynamic parts) endpoint = self._normalize_endpoint(path) - + try: # Process request response = await call_next(request) - + # Calculate duration duration = time.time() - start_time - + # Record metrics self.metrics.record_http_request( method=method, endpoint=endpoint, status_code=response.status_code, duration=duration, - service=self.service_name + service=self.service_name, ) - + # Record specific application events await self._record_application_metrics(request, response, duration) - + return response - + except Exception as e: # Record error metrics duration = time.time() - start_time @@ -66,15 +67,15 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response: endpoint=endpoint, status_code=500, duration=duration, - service=self.service_name + service=self.service_name, ) - + # Record security event for unhandled exceptions self.metrics.record_security_event("HTTP_EXCEPTION", "high") - + # Re-raise the exception raise e - + def _normalize_endpoint(self, path: str) -> str: """ Normalize endpoint path for metrics collection @@ -83,64 +84,64 @@ def _normalize_endpoint(self, path: str) -> str: # Common patterns to normalize normalizations = [ # UUIDs and IDs - (r'/[a-f0-9-]{36}', '/{uuid}'), - (r'/\d+', '/{id}'), + (r"/[a-f0-9-]{36}", "/{uuid}"), + (r"/\d+", "/{id}"), # Remove query parameters - (r'\?.*', ''), + (r"\?.*", ""), ] - + normalized_path = path import re - + for pattern, replacement in normalizations: normalized_path = re.sub(pattern, replacement, normalized_path) - + # Limit path length and complexity if len(normalized_path) > 100: normalized_path = normalized_path[:97] + "..." - + return normalized_path - - async def _record_application_metrics(self, request: Request, response: Response, duration: float): + + def _record_application_metrics(self, request: Request, response: Response, duration: float): """Record application-specific metrics based on request/response""" path = request.url.path - + try: # Authentication metrics - if path.startswith('/api/auth'): + if path.startswith("/api/auth"): if response.status_code == 200: self.metrics.record_authentication_attempt("success") elif response.status_code in [401, 403]: self.metrics.record_authentication_attempt("failure") self.metrics.record_security_event("AUTH_FAILURE", "medium") - + # Scan operation metrics - elif path.startswith('/api/scans'): + elif path.startswith("/api/scans"): if request.method == "POST" and response.status_code == 201: # Scan initiated pass # Detailed metrics will be recorded by scan service elif request.method == "GET" and "results" in path: # Scan results accessed pass - + # Host management metrics - elif path.startswith('/api/hosts'): + elif path.startswith("/api/hosts"): if request.method == "POST" and response.status_code == 201: # New host added pass elif request.method == "DELETE" and response.status_code == 200: # Host removed pass - + # Integration metrics - elif path.startswith('/api/v1/webhooks') or 'integration' in path: + elif path.startswith("/api/v1/webhooks") or "integration" in path: self.metrics.record_integration_call( target="webhook", endpoint=path, status="success" if response.status_code < 400 else "error", - duration=duration + duration=duration, ) - + # Security events for suspicious activity if response.status_code == 401: self.metrics.record_security_event("UNAUTHORIZED_ACCESS", "medium") @@ -150,22 +151,23 @@ async def _record_application_metrics(self, request: Request, response: Response self.metrics.record_security_event("RATE_LIMIT_EXCEEDED", "low") elif response.status_code >= 500: self.metrics.record_security_event("SERVER_ERROR", "high") - + except Exception as e: logger.error(f"Error recording application metrics: {e}") class DatabaseMetricsCollector: """Collector for database-related metrics""" - + def __init__(self): self.metrics = get_metrics_instance() - - async def record_query_metrics(self, operation: str, duration: float): + + def record_query_metrics(self, operation: str, duration: float): """Record database query metrics""" from ..services.prometheus_metrics import database_query_duration_seconds + database_query_duration_seconds.labels(operation=operation).observe(duration) - + async def update_connection_metrics(self, db): """Update database connection metrics""" await self.metrics.update_database_metrics(db) @@ -173,77 +175,85 @@ async def update_connection_metrics(self, db): class BackgroundMetricsUpdater: """Background task to update system and application metrics""" - + def __init__(self): self.metrics = get_metrics_instance() self.is_running = False - + async def start_background_updates(self): """Start background metrics collection""" if self.is_running: return - + self.is_running = True import asyncio - + while self.is_running: try: # Update system metrics self.metrics.update_system_metrics() - + # Update application-specific metrics await self._update_application_metrics() - + # Wait 30 seconds before next update await asyncio.sleep(30) - + except Exception as e: logger.error(f"Error in background metrics update: {e}") await asyncio.sleep(60) # Wait longer on error - + async def _update_application_metrics(self): """Update application-specific metrics""" try: from ..database import get_db from sqlalchemy import text - + db = next(get_db()) - + try: # Update host counts - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT status, COUNT(*) as count FROM hosts WHERE is_active = true GROUP BY status - """)) - + """ + ) + ) + status_counts = {} for row in result: status_counts[row.status] = row.count - + self.metrics.update_host_counts(status_counts) - + # Update active scans count - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT COUNT(*) as active_scans FROM scans WHERE status IN ('running', 'pending') - """)) - + """ + ) + ) + row = result.fetchone() if row: self.metrics.set_active_scans(row.active_scans) - + # Update database metrics await self.metrics.update_database_metrics(db) - + finally: db.close() - + except Exception as e: logger.error(f"Error updating application metrics: {e}") - + def stop_background_updates(self): """Stop background metrics collection""" self.is_running = False @@ -251,4 +261,4 @@ def stop_background_updates(self): # Global instances db_metrics = DatabaseMetricsCollector() -background_updater = BackgroundMetricsUpdater() \ No newline at end of file +background_updater = BackgroundMetricsUpdater() diff --git a/backend/app/middleware/rate_limiting.py b/backend/app/middleware/rate_limiting.py index d8c88db8..cdbb6b48 100644 --- a/backend/app/middleware/rate_limiting.py +++ b/backend/app/middleware/rate_limiting.py @@ -2,6 +2,7 @@ OpenWatch Rate Limiting Middleware Implements industry-standard rate limiting with token bucket algorithm """ + import os import time import hashlib @@ -17,41 +18,43 @@ logger = logging.getLogger(__name__) audit_logger = get_security_audit_logger() + @dataclass class TokenBucket: """Token bucket implementation for smooth rate limiting""" + capacity: int # Maximum tokens tokens: float # Current tokens - rate: float # Tokens per second + rate: float # Tokens per second last_update: float # Last update timestamp - + def consume(self, tokens_requested: int = 1) -> bool: """Attempt to consume tokens from bucket""" now = time.time() - + # Add tokens based on elapsed time elapsed = now - self.last_update self.tokens = min(self.capacity, self.tokens + (elapsed * self.rate)) self.last_update = now - + # Check if we have enough tokens if self.tokens >= tokens_requested: self.tokens -= tokens_requested return True return False - + def time_until_available(self, tokens_needed: int = 1) -> float: """Calculate seconds until tokens are available""" if self.tokens >= tokens_needed: return 0.0 - + tokens_deficit = tokens_needed - self.tokens return tokens_deficit / self.rate class RateLimitStore: """In-memory rate limit store with token buckets and automatic cleanup""" - + def __init__(self): # Token buckets per client/endpoint self.buckets: Dict[str, TokenBucket] = {} @@ -59,7 +62,7 @@ def __init__(self): self.suspicious_activity: Dict[str, Dict[str, int]] = {} # Last cleanup time self.last_cleanup = time.time() - + def get_or_create_bucket(self, bucket_key: str, capacity: int, rate: float) -> TokenBucket: """Get or create token bucket for client""" if bucket_key not in self.buckets: @@ -67,71 +70,73 @@ def get_or_create_bucket(self, bucket_key: str, capacity: int, rate: float) -> T capacity=capacity, tokens=capacity, # Start with full bucket rate=rate, - last_update=time.time() + last_update=time.time(), ) return self.buckets[bucket_key] - + def track_suspicious_activity(self, client_id: str, activity_type: str): """Track suspicious activity patterns""" if client_id not in self.suspicious_activity: self.suspicious_activity[client_id] = {} - + current_minute = int(time.time() // 60) key = f"{activity_type}:{current_minute}" - + if key not in self.suspicious_activity[client_id]: self.suspicious_activity[client_id][key] = 0 - + self.suspicious_activity[client_id][key] += 1 - - def get_suspicious_activity_count(self, client_id: str, activity_type: str, minutes: int = 1) -> int: + + def get_suspicious_activity_count( + self, client_id: str, activity_type: str, minutes: int = 1 + ) -> int: """Get count of suspicious activities within time window""" if client_id not in self.suspicious_activity: return 0 - + current_minute = int(time.time() // 60) count = 0 - + for i in range(minutes): key = f"{activity_type}:{current_minute - i}" count += self.suspicious_activity[client_id].get(key, 0) - + return count - + def cleanup_old_entries(self): """Clean up old entries to prevent memory bloat""" if time.time() - self.last_cleanup < 300: # Every 5 minutes return - + now = time.time() cleanup_age = 3600 # Remove buckets unused for 1 hour - + # Clean up old token buckets buckets_to_remove = [ - key for key, bucket in self.buckets.items() - if now - bucket.last_update > cleanup_age + key for key, bucket in self.buckets.items() if now - bucket.last_update > cleanup_age ] - + for key in buckets_to_remove: del self.buckets[key] - + # Clean up old suspicious activity data current_minute = int(now // 60) cutoff_minute = current_minute - 120 # Keep 2 hours of data - + for client_id in list(self.suspicious_activity.keys()): activities = self.suspicious_activity[client_id] old_keys = [ - key for key in activities.keys() - if ':' in key and int(key.split(':')[1]) < cutoff_minute + key + for key in activities.keys() + if ":" in key and int(key.split(":")[1]) < cutoff_minute ] - + for key in old_keys: del activities[key] - + if not activities: del self.suspicious_activity[client_id] - + self.last_cleanup = now if buckets_to_remove: logger.debug(f"Rate limit cleanup: removed {len(buckets_to_remove)} unused buckets") @@ -139,128 +144,127 @@ def cleanup_old_entries(self): class RateLimitingMiddleware: """Industry-standard rate limiting middleware with token bucket algorithm""" - + def __init__(self): self.store = RateLimitStore() self.enabled = os.getenv("OPENWATCH_RATE_LIMITING", "true").lower() == "true" self.environment = os.getenv("OPENWATCH_ENVIRONMENT", "development").lower() self.limits_config = self._get_limits_configuration() - - logger.info(f"Rate limiting initialized - Environment: {self.environment}, Enabled: {self.enabled}") - + + logger.info( + f"Rate limiting initialized - Environment: {self.environment}, Enabled: {self.enabled}" + ) + def _get_limits_configuration(self) -> Dict: """Get rate limits following industry patterns""" base_config = { # Anonymous users (like GitHub's unauthenticated API) - 'anonymous': { - 'requests_per_minute': 60, # 1 per second average - 'burst_capacity': 20, # Allow short bursts - 'retry_after_seconds': 60 # 1 minute recovery + "anonymous": { + "requests_per_minute": 60, # 1 per second average + "burst_capacity": 20, # Allow short bursts + "retry_after_seconds": 60, # 1 minute recovery }, - # Authenticated users (like GitHub's authenticated API) - 'authenticated': { - 'requests_per_minute': 300, # 5 per second average - 'burst_capacity': 100, # Generous burst allowance - 'retry_after_seconds': 30 # 30 second recovery + "authenticated": { + "requests_per_minute": 300, # 5 per second average + "burst_capacity": 100, # Generous burst allowance + "retry_after_seconds": 30, # 30 second recovery }, - # System/health endpoints (like AWS health checks) - 'system': { - 'requests_per_minute': 600, # High limit for monitoring - 'burst_capacity': 200, # Large burst for health checks - 'retry_after_seconds': 10 # Quick recovery + "system": { + "requests_per_minute": 600, # High limit for monitoring + "burst_capacity": 200, # Large burst for health checks + "retry_after_seconds": 10, # Quick recovery }, - # Authentication endpoints (like Stripe's sensitive endpoints) - 'auth': { - 'requests_per_minute': 30, # More restrictive - 'burst_capacity': 10, # Small burst allowance - 'retry_after_seconds': 120 # 2 minute recovery for security + "auth": { + "requests_per_minute": 30, # More restrictive + "burst_capacity": 10, # Small burst allowance + "retry_after_seconds": 120, # 2 minute recovery for security }, - # Error-prone endpoints - 'error_endpoints': { - 'requests_per_minute': 50, - 'burst_capacity': 15, - 'retry_after_seconds': 60 + "error_endpoints": { + "requests_per_minute": 50, + "burst_capacity": 15, + "retry_after_seconds": 60, }, - # Validation endpoints - 'validation': { - 'requests_per_minute': 60, - 'burst_capacity': 20, - 'retry_after_seconds': 60 - } + "validation": { + "requests_per_minute": 60, + "burst_capacity": 20, + "retry_after_seconds": 60, + }, } - + # Environment-specific adjustments if self.environment == "development": # Much higher limits for development for category in base_config: - base_config[category]['requests_per_minute'] *= 10 - base_config[category]['burst_capacity'] *= 5 - base_config[category]['retry_after_seconds'] = min(30, base_config[category]['retry_after_seconds']) - + base_config[category]["requests_per_minute"] *= 10 + base_config[category]["burst_capacity"] *= 5 + base_config[category]["retry_after_seconds"] = min( + 30, base_config[category]["retry_after_seconds"] + ) + elif self.environment == "testing": # Lower limits for testing rate limiting for category in base_config: - base_config[category]['requests_per_minute'] //= 2 - base_config[category]['retry_after_seconds'] = 30 - + base_config[category]["requests_per_minute"] //= 2 + base_config[category]["retry_after_seconds"] = 30 + elif self.environment == "staging": # Slightly higher limits than production for category in base_config: - base_config[category]['requests_per_minute'] = int(base_config[category]['requests_per_minute'] * 1.2) - + base_config[category]["requests_per_minute"] = int( + base_config[category]["requests_per_minute"] * 1.2 + ) + return base_config - + # Suspicious behavior patterns SUSPICIOUS_PATTERNS = { - 'high_error_rate': {'threshold': 30, 'window_minutes': 1}, - 'validation_farming': {'threshold': 20, 'window_minutes': 1}, - 'auth_brute_force': {'threshold': 5, 'window_minutes': 1} + "high_error_rate": {"threshold": 30, "window_minutes": 1}, + "validation_farming": {"threshold": 20, "window_minutes": 1}, + "auth_brute_force": {"threshold": 5, "window_minutes": 1}, } - + async def __call__(self, request: Request, call_next) -> Response: """Main rate limiting middleware function""" # Skip if disabled if not self.enabled: return await call_next(request) - + # Periodic cleanup self.store.cleanup_old_entries() - + # Get client information client_id, client_type = self._get_client_identifier(request) endpoint = str(request.url.path) endpoint_category = self._get_endpoint_category(endpoint) - + # Skip excluded endpoints - if endpoint_category == 'excluded': + if endpoint_category == "excluded": return await call_next(request) - + # Get appropriate configuration config_key = endpoint_category if endpoint_category in self.limits_config else client_type - config = self.limits_config.get(config_key, self.limits_config['anonymous']) - + config = self.limits_config.get(config_key, self.limits_config["anonymous"]) + # Get or create token bucket bucket_key = f"{client_id}:{endpoint_category}" - rate_per_second = config['requests_per_minute'] / 60.0 + rate_per_second = config["requests_per_minute"] / 60.0 bucket = self.store.get_or_create_bucket( - bucket_key, - config['burst_capacity'], - rate_per_second + bucket_key, config["burst_capacity"], rate_per_second ) - + # Create headers for response headers = self._create_rate_limit_headers(bucket, config) - + # Try to consume token if bucket.consume(1): # Track suspicious patterns self._track_suspicious_patterns(client_id, endpoint, request) - + # Check for suspicious behavior suspicious_behavior = self._detect_suspicious_behavior(client_id) if suspicious_behavior: @@ -269,36 +273,37 @@ async def __call__(self, request: Request, call_next) -> Response: source_ip=client_ip, suspicious_patterns=suspicious_behavior, user_id=self._get_user_id(request), - session_id=self._get_session_id(request) + session_id=self._get_session_id(request), ) - + # Request allowed - proceed response = await call_next(request) - + # Add rate limit headers for header_name, header_value in headers.items(): response.headers[header_name] = header_value - + return response else: # Rate limit exceeded retry_after = min( - config['retry_after_seconds'], - int(bucket.time_until_available(1)) + 1 + config["retry_after_seconds"], int(bucket.time_until_available(1)) + 1 ) - + client_ip = self._get_client_ip(request) audit_logger.log_rate_limit_event( source_ip=client_ip, error_count=int(bucket.capacity - bucket.tokens), action_taken=f"rate_limited_retry_after_{retry_after}s", - user_id=self._get_user_id(request) + user_id=self._get_user_id(request), ) - - logger.warning(f"Rate limit exceeded for {client_id} on {endpoint} - retry after {retry_after}s") - + + logger.warning( + f"Rate limit exceeded for {client_id} on {endpoint} - retry after {retry_after}s" + ) + return self._create_rate_limit_response(retry_after, headers) - + def _get_client_identifier(self, request: Request) -> Tuple[str, str]: """Get client identifier and type (anonymous/authenticated)""" # Check for authentication @@ -307,124 +312,126 @@ def _get_client_identifier(self, request: Request) -> Tuple[str, str]: # Authenticated user - use token hash as identifier token_hash = hashlib.sha256(auth_header.encode()).hexdigest()[:16] return f"auth:{token_hash}", "authenticated" - + # Anonymous user - use IP address client_ip = self._get_client_ip(request) ip_hash = hashlib.sha256(f"{client_ip}:anonymous".encode()).hexdigest()[:16] return f"anon:{ip_hash}", "anonymous" - + def _get_client_ip(self, request: Request) -> str: """Extract client IP handling proxy headers""" forwarded_for = request.headers.get("x-forwarded-for") if forwarded_for: return forwarded_for.split(",")[0].strip() - + real_ip = request.headers.get("x-real-ip") if real_ip: return real_ip - + return request.client.host if request.client else "unknown" - + def _get_endpoint_category(self, path: str) -> str: """Categorize endpoint for appropriate rate limiting""" path_lower = path.lower() - + # Excluded endpoints - if any(path.startswith(p) for p in ['/health', '/metrics', '/docs', '/redoc', '/openapi.json', '/security-info']): - return 'excluded' - + if any( + path.startswith(p) + for p in ["/health", "/metrics", "/docs", "/redoc", "/openapi.json", "/security-info"] + ): + return "excluded" + # System endpoints - if any(p in path_lower for p in ['/health', '/metrics']): - return 'system' - + if any(p in path_lower for p in ["/health", "/metrics"]): + return "system" + # Authentication endpoints - if any(p in path_lower for p in ['/auth/', '/login', '/token', '/register', '/mfa']): - return 'auth' - + if any(p in path_lower for p in ["/auth/", "/login", "/token", "/register", "/mfa"]): + return "auth" + # Validation endpoints - if 'validate' in path_lower: - return 'validation' - + if "validate" in path_lower: + return "validation" + # Error/debug endpoints - if any(p in path_lower for p in ['/error', '/debug', '/classify']): - return 'error_endpoints' - + if any(p in path_lower for p in ["/error", "/debug", "/classify"]): + return "error_endpoints" + # Default to regular API - return 'api' - + return "api" + def _create_rate_limit_headers(self, bucket: TokenBucket, config: Dict) -> Dict[str, str]: """Create industry-standard rate limit headers""" current_minute = int(time.time() // 60) reset_time = (current_minute + 1) * 60 # Next minute boundary - + return { - "X-RateLimit-Limit": str(config['requests_per_minute']), + "X-RateLimit-Limit": str(config["requests_per_minute"]), "X-RateLimit-Remaining": str(max(0, int(bucket.tokens))), "X-RateLimit-Reset": str(reset_time), - "X-RateLimit-Burst": str(config['burst_capacity']), + "X-RateLimit-Burst": str(config["burst_capacity"]), } - - def _create_rate_limit_response(self, retry_after: int, headers: Dict[str, str]) -> JSONResponse: + + def _create_rate_limit_response( + self, retry_after: int, headers: Dict[str, str] + ) -> JSONResponse: """Create standardized rate limit exceeded response""" headers["Retry-After"] = str(retry_after) headers["X-RateLimit-Retry-After"] = str(retry_after) - - rate_limit_response = RateLimitResponse( - retry_after=retry_after - ) - - return JSONResponse( - status_code=429, - content=rate_limit_response.dict(), - headers=headers - ) - + + rate_limit_response = RateLimitResponse(retry_after=retry_after) + + return JSONResponse(status_code=429, content=rate_limit_response.dict(), headers=headers) + def _track_suspicious_patterns(self, client_id: str, endpoint: str, request: Request): """Track patterns that might indicate suspicious behavior""" endpoint_lower = endpoint.lower() - + # Track error endpoints - if any(p in endpoint_lower for p in ['/error', '/debug', '/classify']): - self.store.track_suspicious_activity(client_id, 'error_endpoints') - + if any(p in endpoint_lower for p in ["/error", "/debug", "/classify"]): + self.store.track_suspicious_activity(client_id, "error_endpoints") + # Track validation endpoints - if 'validate' in endpoint_lower: - self.store.track_suspicious_activity(client_id, 'validation_endpoints') - + if "validate" in endpoint_lower: + self.store.track_suspicious_activity(client_id, "validation_endpoints") + # Track auth failures (would need response status in real implementation) - if any(p in endpoint_lower for p in ['/auth/', '/login', '/token']): - self.store.track_suspicious_activity(client_id, 'auth_attempts') - + if any(p in endpoint_lower for p in ["/auth/", "/login", "/token"]): + self.store.track_suspicious_activity(client_id, "auth_attempts") + def _detect_suspicious_behavior(self, client_id: str) -> List[str]: """Detect suspicious behavior patterns""" suspicious = [] - + # Check high error rate error_count = self.store.get_suspicious_activity_count( - client_id, 'error_endpoints', - self.SUSPICIOUS_PATTERNS['high_error_rate']['window_minutes'] + client_id, + "error_endpoints", + self.SUSPICIOUS_PATTERNS["high_error_rate"]["window_minutes"], ) - if error_count > self.SUSPICIOUS_PATTERNS['high_error_rate']['threshold']: - suspicious.append('high_error_rate') - + if error_count > self.SUSPICIOUS_PATTERNS["high_error_rate"]["threshold"]: + suspicious.append("high_error_rate") + # Check validation farming validation_count = self.store.get_suspicious_activity_count( - client_id, 'validation_endpoints', - self.SUSPICIOUS_PATTERNS['validation_farming']['window_minutes'] + client_id, + "validation_endpoints", + self.SUSPICIOUS_PATTERNS["validation_farming"]["window_minutes"], ) - if validation_count > self.SUSPICIOUS_PATTERNS['validation_farming']['threshold']: - suspicious.append('validation_farming') - + if validation_count > self.SUSPICIOUS_PATTERNS["validation_farming"]["threshold"]: + suspicious.append("validation_farming") + # Check auth brute force auth_count = self.store.get_suspicious_activity_count( - client_id, 'auth_attempts', - self.SUSPICIOUS_PATTERNS['auth_brute_force']['window_minutes'] + client_id, + "auth_attempts", + self.SUSPICIOUS_PATTERNS["auth_brute_force"]["window_minutes"], ) - if auth_count > self.SUSPICIOUS_PATTERNS['auth_brute_force']['threshold']: - suspicious.append('auth_brute_force') - + if auth_count > self.SUSPICIOUS_PATTERNS["auth_brute_force"]["threshold"]: + suspicious.append("auth_brute_force") + return suspicious - + def _get_user_id(self, request: Request) -> Optional[str]: """Extract user ID from request if available""" auth_header = request.headers.get("authorization") @@ -432,7 +439,7 @@ def _get_user_id(self, request: Request) -> Optional[str]: # Simplified - in production decode JWT return "authenticated_user" return None - + def _get_session_id(self, request: Request) -> Optional[str]: """Extract session ID from request if available""" session_id = request.cookies.get("session_id") @@ -444,6 +451,7 @@ def _get_session_id(self, request: Request) -> Optional[str]: # Global instance for dependency injection _rate_limiting_middleware = None + def get_rate_limiting_middleware() -> RateLimitingMiddleware: """Get or create the global rate limiting middleware""" global _rate_limiting_middleware @@ -451,5 +459,6 @@ def get_rate_limiting_middleware() -> RateLimitingMiddleware: _rate_limiting_middleware = RateLimitingMiddleware() return _rate_limiting_middleware + # Alias for backward compatibility -get_industry_standard_rate_limiter = get_rate_limiting_middleware \ No newline at end of file +get_industry_standard_rate_limiter = get_rate_limiting_middleware diff --git a/backend/app/models/authorization_models.py b/backend/app/models/authorization_models.py index a0e6d2e8..dc98beec 100644 --- a/backend/app/models/authorization_models.py +++ b/backend/app/models/authorization_models.py @@ -8,6 +8,7 @@ Design by Emily (Security Engineer) - Implements ReBAC with audit trail """ + import uuid from enum import Enum from typing import Dict, List, Optional, Set, Any @@ -18,6 +19,7 @@ class ResourceType(str, Enum): """Types of resources that can be protected by authorization""" + HOST = "host" HOST_GROUP = "host_group" SCAN = "scan" @@ -27,31 +29,35 @@ class ResourceType(str, Enum): class ActionType(str, Enum): """Actions that can be performed on resources""" + READ = "read" WRITE = "write" EXECUTE = "execute" DELETE = "delete" MANAGE = "manage" # Administrative actions - SCAN = "scan" # Specific to scan operations + SCAN = "scan" # Specific to scan operations EXPORT = "export" # Data export operations class PermissionEffect(str, Enum): """Effect of a permission policy""" + ALLOW = "allow" DENY = "deny" class PermissionScope(str, Enum): """Scope of permission application""" - DIRECT = "direct" # Direct resource access + + DIRECT = "direct" # Direct resource access INHERITED = "inherited" # Inherited from parent resource - GROUP = "group" # Through group membership - ROLE = "role" # Through role assignment + GROUP = "group" # Through group membership + ROLE = "role" # Through role assignment class AuthorizationDecision(str, Enum): """Final authorization decision""" + ALLOW = "allow" DENY = "deny" NOT_APPLICABLE = "not_applicable" @@ -60,19 +66,21 @@ class AuthorizationDecision(str, Enum): @dataclass class ResourceIdentifier: """Identifies a specific resource for authorization""" + resource_type: ResourceType resource_id: str parent_resource_id: Optional[str] = None attributes: Dict[str, Any] = None - + def __post_init__(self): if self.attributes is None: self.attributes = {} -@dataclass +@dataclass class PermissionPolicy: """Defines a specific permission policy""" + subject_type: str # user, group, role subject_id: str resource_type: ResourceType @@ -87,7 +95,7 @@ class PermissionPolicy: expires_at: Optional[datetime] = None created_by: str = None is_active: bool = True - + def __post_init__(self): if self.conditions is None: self.conditions = {} @@ -96,6 +104,7 @@ def __post_init__(self): @dataclass class AuthorizationContext: """Context information for authorization decisions""" + user_id: str user_roles: List[str] user_groups: List[str] @@ -104,7 +113,7 @@ class AuthorizationContext: session_id: Optional[str] = None request_time: datetime = Field(default_factory=datetime.utcnow) additional_attributes: Dict[str, Any] = None - + def __post_init__(self): if self.additional_attributes is None: self.additional_attributes = {} @@ -113,6 +122,7 @@ def __post_init__(self): @dataclass class AuthorizationResult: """Result of an authorization check""" + decision: AuthorizationDecision resource: ResourceIdentifier action: ActionType @@ -128,6 +138,7 @@ class AuthorizationResult: class BulkAuthorizationRequest(BaseModel): """Request for bulk authorization checking""" + user_id: str resources: List[ResourceIdentifier] action: ActionType @@ -138,6 +149,7 @@ class BulkAuthorizationRequest(BaseModel): class BulkAuthorizationResult(BaseModel): """Result of bulk authorization check""" + overall_decision: AuthorizationDecision individual_results: List[AuthorizationResult] denied_resources: List[ResourceIdentifier] @@ -149,6 +161,7 @@ class BulkAuthorizationResult(BaseModel): class HostPermission(BaseModel): """Specific host permission model""" + id: str = Field(default_factory=lambda: str(uuid.uuid4())) user_id: Optional[str] = None group_id: Optional[str] = None @@ -165,6 +178,7 @@ class HostPermission(BaseModel): class HostGroupPermission(BaseModel): """Host group permission model""" + id: str = Field(default_factory=lambda: str(uuid.uuid4())) user_id: Optional[str] = None group_id: Optional[str] = None @@ -182,6 +196,7 @@ class HostGroupPermission(BaseModel): class AuthorizationAuditEvent(BaseModel): """Audit event for authorization decisions""" + id: str = Field(default_factory=lambda: str(uuid.uuid4())) event_type: str # permission_check, policy_created, access_granted, access_denied user_id: str @@ -202,14 +217,16 @@ class AuthorizationAuditEvent(BaseModel): class PolicyConflictResolution(str, Enum): """How to resolve conflicting policies""" - DENY_OVERRIDES = "deny_overrides" # Deny takes precedence - ALLOW_OVERRIDES = "allow_overrides" # Allow takes precedence - FIRST_MATCH = "first_match" # First matching policy wins - PRIORITY_ORDER = "priority_order" # Higher priority wins + + DENY_OVERRIDES = "deny_overrides" # Deny takes precedence + ALLOW_OVERRIDES = "allow_overrides" # Allow takes precedence + FIRST_MATCH = "first_match" # First matching policy wins + PRIORITY_ORDER = "priority_order" # Higher priority wins class AuthorizationConfiguration(BaseModel): """Authorization system configuration""" + default_decision: AuthorizationDecision = AuthorizationDecision.DENY conflict_resolution: PolicyConflictResolution = PolicyConflictResolution.DENY_OVERRIDES cache_ttl_seconds: int = 300 # 5 minutes @@ -225,6 +242,7 @@ class AuthorizationConfiguration(BaseModel): # Database Models for SQLAlchemy class AuthorizationPolicy(BaseModel): """Database model for authorization policies""" + id: str name: str description: Optional[str] @@ -246,55 +264,60 @@ class AuthorizationPolicy(BaseModel): class PermissionCache: """In-memory cache for permission decisions""" - + def __init__(self, ttl_seconds: int = 300, max_size: int = 10000): self.cache: Dict[str, Dict] = {} self.ttl_seconds = ttl_seconds self.max_size = max_size self.access_times: Dict[str, datetime] = {} - + def _generate_key(self, user_id: str, resource: ResourceIdentifier, action: ActionType) -> str: """Generate cache key for permission check""" return f"{user_id}:{resource.resource_type.value}:{resource.resource_id}:{action.value}" - - def get(self, user_id: str, resource: ResourceIdentifier, action: ActionType) -> Optional[AuthorizationResult]: + + def get( + self, user_id: str, resource: ResourceIdentifier, action: ActionType + ) -> Optional[AuthorizationResult]: """Get cached permission decision""" key = self._generate_key(user_id, resource, action) - + if key not in self.cache: return None - + cached_item = self.cache[key] - cached_time = cached_item.get('timestamp') - + cached_time = cached_item.get("timestamp") + if not cached_time or datetime.utcnow() - cached_time > timedelta(seconds=self.ttl_seconds): # Cache expired del self.cache[key] if key in self.access_times: del self.access_times[key] return None - + # Update access time self.access_times[key] = datetime.utcnow() - - result = cached_item.get('result') + + result = cached_item.get("result") if result: result.cached = True - + return result - - def put(self, user_id: str, resource: ResourceIdentifier, action: ActionType, result: AuthorizationResult): + + def put( + self, + user_id: str, + resource: ResourceIdentifier, + action: ActionType, + result: AuthorizationResult, + ): """Cache permission decision""" if len(self.cache) >= self.max_size: self._evict_least_recently_used() - + key = self._generate_key(user_id, resource, action) - self.cache[key] = { - 'result': result, - 'timestamp': datetime.utcnow() - } + self.cache[key] = {"result": result, "timestamp": datetime.utcnow()} self.access_times[key] = datetime.utcnow() - + def invalidate_user(self, user_id: str): """Invalidate all cached permissions for a user""" keys_to_remove = [k for k in self.cache.keys() if k.startswith(f"{user_id}:")] @@ -302,7 +325,7 @@ def invalidate_user(self, user_id: str): del self.cache[key] if key in self.access_times: del self.access_times[key] - + def invalidate_resource(self, resource: ResourceIdentifier): """Invalidate all cached permissions for a resource""" resource_prefix = f"{resource.resource_type.value}:{resource.resource_id}" @@ -311,22 +334,22 @@ def invalidate_resource(self, resource: ResourceIdentifier): del self.cache[key] if key in self.access_times: del self.access_times[key] - + def clear(self): """Clear all cached permissions""" self.cache.clear() self.access_times.clear() - + def _evict_least_recently_used(self): """Evict least recently used cache entries""" if not self.access_times: return - + # Remove 10% of cache entries (oldest first) remove_count = max(1, len(self.access_times) // 10) sorted_keys = sorted(self.access_times.items(), key=lambda x: x[1]) - + for key, _ in sorted_keys[:remove_count]: if key in self.cache: del self.cache[key] - del self.access_times[key] \ No newline at end of file + del self.access_times[key] diff --git a/backend/app/models/error_models.py b/backend/app/models/error_models.py index a80fa515..5b4c106e 100644 --- a/backend/app/models/error_models.py +++ b/backend/app/models/error_models.py @@ -2,15 +2,18 @@ OpenWatch Error Models Provides both internal (with technical details) and sanitized (user-safe) error models """ + from enum import Enum from typing import List, Dict, Any, Optional from datetime import datetime from pydantic import BaseModel, Field + class ErrorCategory(str, Enum): """Error category classification""" + NETWORK = "network" - AUTHENTICATION = "authentication" + AUTHENTICATION = "authentication" PRIVILEGE = "privilege" RESOURCE = "resource" DEPENDENCY = "dependency" @@ -19,15 +22,19 @@ class ErrorCategory(str, Enum): CONFIGURATION = "configuration" SECURITY = "security" # Added for rate limiting and security events + class ErrorSeverity(str, Enum): """Error severity levels""" + CRITICAL = "critical" ERROR = "error" WARNING = "warning" INFO = "info" + class AutomatedFixResponse(BaseModel): """Sanitized automated fix response for users""" + fix_id: str description: str requires_sudo: bool = False @@ -35,8 +42,10 @@ class AutomatedFixResponse(BaseModel): is_safe: bool = True # Removed command and rollback_command fields for security + class ScanErrorInternal(BaseModel): """Internal scan error with full technical details (server-side only)""" + error_code: str category: ErrorCategory severity: ErrorSeverity @@ -49,8 +58,10 @@ class ScanErrorInternal(BaseModel): documentation_url: str = "" timestamp: datetime = Field(default_factory=datetime.utcnow) + class ScanErrorResponse(BaseModel): """Sanitized scan error response for users (no sensitive data)""" + error_code: str category: ErrorCategory severity: ErrorSeverity @@ -62,8 +73,10 @@ class ScanErrorResponse(BaseModel): documentation_url: str = "" timestamp: datetime = Field(default_factory=datetime.utcnow) + class ValidationResultInternal(BaseModel): """Internal validation result with full technical details""" + can_proceed: bool errors: List[ScanErrorInternal] = Field(default_factory=list) warnings: List[ScanErrorInternal] = Field(default_factory=list) @@ -71,8 +84,10 @@ class ValidationResultInternal(BaseModel): system_info: Dict[str, Any] = Field(default_factory=dict) # Contains sensitive data validation_checks: Dict[str, bool] = Field(default_factory=dict) + class ValidationResultResponse(BaseModel): """Sanitized validation result response for users""" + can_proceed: bool errors: List[ScanErrorResponse] = Field(default_factory=list) warnings: List[ScanErrorResponse] = Field(default_factory=list) @@ -80,21 +95,27 @@ class ValidationResultResponse(BaseModel): system_info: Dict[str, Any] = Field(default_factory=dict) # Now sanitized via Security Fix 5 validation_checks: Dict[str, bool] = Field(default_factory=dict) + class ErrorClassificationResponse(BaseModel): """Response from error classification endpoint""" + error: ScanErrorResponse request_id: str rate_limit_info: Dict[str, Any] = Field(default_factory=dict) + class RateLimitResponse(BaseModel): """Response when rate limit is exceeded""" + error_code: str = "RATE_LIMIT" message: str = "Request rate limit exceeded" retry_after: int # Seconds to wait documentation_url: str = "https://docs.openwatch.dev/security/rate-limits" + class SecurityAuditLog(BaseModel): """Security audit log entry (server-side only)""" + timestamp: datetime = Field(default_factory=datetime.utcnow) event_type: str error_code: str @@ -107,11 +128,13 @@ class SecurityAuditLog(BaseModel): request_path: Optional[str] = None user_agent: Optional[str] = None + class ErrorStatistics(BaseModel): """Error statistics for monitoring (sanitized)""" + total_errors: int = 0 errors_by_category: Dict[str, int] = Field(default_factory=dict) errors_by_severity: Dict[str, int] = Field(default_factory=dict) top_error_codes: List[Dict[str, Any]] = Field(default_factory=list) time_window: str = "1h" - # No IP addresses or user IDs exposed \ No newline at end of file + # No IP addresses or user IDs exposed diff --git a/backend/app/models/scan_models.py b/backend/app/models/scan_models.py index 2804f216..efdc799e 100644 --- a/backend/app/models/scan_models.py +++ b/backend/app/models/scan_models.py @@ -1,6 +1,7 @@ """ Group scan session models and data structures """ + from pydantic import BaseModel from typing import List, Optional, Dict, Any from datetime import datetime @@ -9,6 +10,7 @@ class ScanSessionStatus(str, Enum): """Status values for group scan sessions""" + PENDING = "pending" IN_PROGRESS = "in_progress" COMPLETED = "completed" @@ -18,6 +20,7 @@ class ScanSessionStatus(str, Enum): class HostScanStatus(str, Enum): """Status values for individual host scans within a group scan""" + PENDING = "pending" SCANNING = "scanning" COMPLETED = "completed" @@ -27,6 +30,7 @@ class HostScanStatus(str, Enum): class GroupScanConfig(BaseModel): """Configuration for group scan initiation""" + content_id: Optional[int] = None profile_id: Optional[str] = None scan_options: Optional[Dict[str, Any]] = {} @@ -38,6 +42,7 @@ class GroupScanConfig(BaseModel): class HostScanDetail(BaseModel): """Detailed status of a host within a group scan""" + host_id: str host_name: str hostname: str @@ -53,6 +58,7 @@ class HostScanDetail(BaseModel): class GroupScanSession(BaseModel): """Group scan session tracking information""" + session_id: str group_id: int group_name: str @@ -63,15 +69,16 @@ class GroupScanSession(BaseModel): actual_completion: Optional[datetime] = None status: ScanSessionStatus hosts_scanning: List[str] = [] # Host IDs currently being scanned - hosts_pending: List[str] = [] # Host IDs waiting to be scanned + hosts_pending: List[str] = [] # Host IDs waiting to be scanned hosts_completed: List[str] = [] # Host IDs with completed scans - hosts_failed: List[str] = [] # Host IDs with failed scans + hosts_failed: List[str] = [] # Host IDs with failed scans scan_config: Optional[GroupScanConfig] = None metadata: Optional[Dict[str, Any]] = {} class GroupScanProgress(BaseModel): """Real-time progress information for a group scan""" + session_id: str group_id: int group_name: str @@ -90,6 +97,7 @@ class GroupScanProgress(BaseModel): class GroupScanSummary(BaseModel): """Summary results for a completed group scan""" + session_id: str group_id: int group_name: str @@ -107,6 +115,7 @@ class GroupScanSummary(BaseModel): class ActiveScanSession(BaseModel): """Active scan session information for listing""" + session_id: str group_id: int group_name: str @@ -116,4 +125,4 @@ class ActiveScanSession(BaseModel): total_hosts: int started_at: datetime estimated_completion: Optional[datetime] = None - initiated_by: int \ No newline at end of file + initiated_by: int diff --git a/backend/app/models/system_models.py b/backend/app/models/system_models.py index cfef8810..3fec9b80 100644 --- a/backend/app/models/system_models.py +++ b/backend/app/models/system_models.py @@ -1,7 +1,7 @@ """ System Information Models for Security Fix 5: System Information Sanitization -Provides safe models for exposing only necessary system information while +Provides safe models for exposing only necessary system information while preventing reconnaissance attacks through detailed technical information exposure. """ @@ -13,27 +13,28 @@ class SystemInfoLevel(str, Enum): """Levels of system information exposure""" - BASIC = "basic" # Minimal info for compliance only + + BASIC = "basic" # Minimal info for compliance only COMPLIANCE = "compliance" # Info needed for compliance reporting - OPERATIONAL = "operational" # Info for system operations - ADMIN = "admin" # Full technical details (admin only) + OPERATIONAL = "operational" # Info for system operations + ADMIN = "admin" # Full technical details (admin only) class ComplianceSystemInfo(BaseModel): """Safe system information for compliance reporting""" + os_family: Optional[str] = None # e.g., "linux", "windows" (generic) compliance_relevant_info: Dict[str, Any] = Field(default_factory=dict) last_updated: datetime = Field(default_factory=datetime.utcnow) info_level: SystemInfoLevel = SystemInfoLevel.COMPLIANCE - + class Config: - json_encoders = { - datetime: lambda v: v.isoformat() - } + json_encoders = {datetime: lambda v: v.isoformat()} class OperationalSystemInfo(ComplianceSystemInfo): """System information for operational purposes""" + kernel_family: Optional[str] = None # e.g., "Linux", "NT" (generic) service_status: Dict[str, str] = Field(default_factory=dict) # sanitized service info resource_availability: Dict[str, Any] = Field(default_factory=dict) @@ -42,9 +43,10 @@ class OperationalSystemInfo(ComplianceSystemInfo): class AdminSystemInfo(ComplianceSystemInfo): """Full system information for administrators only""" + detailed_os_info: Optional[str] = None # Full os-release content kernel_version: Optional[str] = None - installed_packages: List[str] = Field(default_factory=list) + installed_packages: List[str] = Field(default_factory=list) network_configuration: Dict[str, Any] = Field(default_factory=dict) running_services: List[Dict[str, Any]] = Field(default_factory=list) system_details: Optional[str] = None # Full uname output @@ -53,18 +55,20 @@ class AdminSystemInfo(ComplianceSystemInfo): class SystemInfoSanitizationContext(BaseModel): """Context for system information sanitization""" + user_id: Optional[str] = None user_role: Optional[str] = None source_ip: Optional[str] = None access_level: SystemInfoLevel = SystemInfoLevel.BASIC is_admin: bool = False compliance_only: bool = True - - + + class SystemInfoFilter(BaseModel): """Configuration for filtering system information""" + allow_os_version: bool = False - allow_kernel_info: bool = False + allow_kernel_info: bool = False allow_package_info: bool = False allow_network_config: bool = False allow_service_info: bool = False @@ -74,6 +78,7 @@ class SystemInfoFilter(BaseModel): class SystemInfoMetadata(BaseModel): """Metadata about system information collection""" + collection_timestamp: datetime = Field(default_factory=datetime.utcnow) collection_method: str = "ssh_command" sanitization_applied: bool = True @@ -84,6 +89,7 @@ class SystemInfoMetadata(BaseModel): class SanitizedSystemValidation(BaseModel): """Sanitized validation result containing safe system information""" + can_proceed: bool system_compatible: bool = True compliance_info: ComplianceSystemInfo @@ -94,6 +100,7 @@ class SanitizedSystemValidation(BaseModel): class SystemReconnaissancePattern(BaseModel): """Pattern for detecting reconnaissance attempts""" + pattern_id: str description: str regex_pattern: str @@ -103,6 +110,7 @@ class SystemReconnaissancePattern(BaseModel): class SystemInfoAuditEvent(BaseModel): """Audit event for system information access""" + event_id: str timestamp: datetime = Field(default_factory=datetime.utcnow) user_id: Optional[str] = None @@ -112,4 +120,4 @@ class SystemInfoAuditEvent(BaseModel): admin_access: bool = False reconnaissance_detected: bool = False patterns_triggered: List[str] = Field(default_factory=list) - sanitization_applied: bool = True \ No newline at end of file + sanitization_applied: bool = True diff --git a/backend/app/plugins/__init__.py b/backend/app/plugins/__init__.py index 8e42c8c4..762c5d29 100644 --- a/backend/app/plugins/__init__.py +++ b/backend/app/plugins/__init__.py @@ -1,4 +1,4 @@ """ OpenWatch Plugin System Plugin architecture foundation for extensible SCAP scanning capabilities -""" \ No newline at end of file +""" diff --git a/backend/app/plugins/interface.py b/backend/app/plugins/interface.py index 9c26f545..cff2e46b 100644 --- a/backend/app/plugins/interface.py +++ b/backend/app/plugins/interface.py @@ -2,6 +2,7 @@ OpenWatch Plugin Interface Specification Defines the core plugin architecture for extensible SCAP scanning functionality """ + from abc import ABC, abstractmethod from typing import Dict, List, Optional, Any, AsyncGenerator from dataclasses import dataclass @@ -13,18 +14,20 @@ class PluginType(Enum): """Plugin types supported by OpenWatch""" - SCANNER = "scanner" # Custom scanning engines - REPORTER = "reporter" # Custom report generation - REMEDIATION = "remediation" # Automated remediation providers - INTEGRATION = "integration" # External system integrations - CONTENT = "content" # SCAP content providers - AUTH = "auth" # Authentication providers - NOTIFICATION = "notification" # Notification services + + SCANNER = "scanner" # Custom scanning engines + REPORTER = "reporter" # Custom report generation + REMEDIATION = "remediation" # Automated remediation providers + INTEGRATION = "integration" # External system integrations + CONTENT = "content" # SCAP content providers + AUTH = "auth" # Authentication providers + NOTIFICATION = "notification" # Notification services @dataclass class PluginMetadata: """Plugin metadata information""" + name: str version: str description: str @@ -33,7 +36,7 @@ class PluginMetadata: supported_api_version: str = "1.0.0" dependencies: List[str] = None config_schema: Dict = None - + def __post_init__(self): if self.dependencies is None: self.dependencies = [] @@ -44,6 +47,7 @@ def __post_init__(self): @dataclass class ScanContext: """Context information for scan operations""" + scan_id: str profile_id: str content_path: str @@ -52,7 +56,7 @@ class ScanContext: rule_id: Optional[str] = None user_id: Optional[str] = None scan_parameters: Dict = None - + def __post_init__(self): if self.scan_parameters is None: self.scan_parameters = {} @@ -61,6 +65,7 @@ def __post_init__(self): @dataclass class ScanResult: """Standardized scan result format""" + scan_id: str hostname: str status: str # 'completed', 'failed', 'error' @@ -73,7 +78,7 @@ class ScanResult: failed_rules: List[Dict] = None rule_details: List[Dict] = None metadata: Dict = None - + def __post_init__(self): if self.failed_rules is None: self.failed_rules = [] @@ -85,40 +90,40 @@ def __post_init__(self): class PluginInterface(ABC): """Base interface for all OpenWatch plugins""" - + def __init__(self, config: Dict = None): self.config = config or {} self.metadata: Optional[PluginMetadata] = None self.enabled = True self.logger = logging.getLogger(f"plugin.{self.__class__.__name__}") - + @abstractmethod def get_metadata(self) -> PluginMetadata: """Return plugin metadata""" pass - + @abstractmethod async def initialize(self) -> bool: """Initialize the plugin. Return True if successful.""" pass - + @abstractmethod async def cleanup(self) -> bool: """Cleanup plugin resources. Return True if successful.""" pass - - async def health_check(self) -> Dict: + + def health_check(self) -> Dict: """Perform plugin health check""" return { "status": "healthy" if self.enabled else "disabled", "plugin": self.get_metadata().name, - "version": self.get_metadata().version + "version": self.get_metadata().version, } - + def is_enabled(self) -> bool: """Check if plugin is enabled""" return self.enabled - + def set_enabled(self, enabled: bool): """Enable or disable the plugin""" self.enabled = enabled @@ -126,147 +131,149 @@ def set_enabled(self, enabled: bool): class ScannerPlugin(PluginInterface): """Interface for custom scanning engine plugins""" - + @abstractmethod async def can_scan_host(self, host_config: Dict) -> bool: """Check if this plugin can scan the specified host""" pass - + @abstractmethod async def execute_scan(self, context: ScanContext) -> ScanResult: """Execute a scan using this plugin""" pass - + @abstractmethod async def validate_content(self, content_path: str) -> bool: """Validate SCAP content compatibility with this scanner""" pass - - async def get_supported_profiles(self, content_path: str) -> List[Dict]: + + def get_supported_profiles(self, content_path: str) -> List[Dict]: """Get profiles supported by this scanner""" return [] class ReporterPlugin(PluginInterface): """Interface for custom report generation plugins""" - + @abstractmethod - async def generate_report(self, scan_results: List[ScanResult], - format_type: str = "html") -> bytes: + async def generate_report( + self, scan_results: List[ScanResult], format_type: str = "html" + ) -> bytes: """Generate a report from scan results""" pass - + @abstractmethod def get_supported_formats(self) -> List[str]: """Get list of supported report formats""" pass - - async def get_report_template(self, format_type: str) -> Optional[str]: + + def get_report_template(self, format_type: str) -> Optional[str]: """Get report template for the specified format""" return None class RemediationPlugin(PluginInterface): """Interface for automated remediation plugins""" - + @abstractmethod async def can_remediate_rule(self, rule_id: str, host_config: Dict) -> bool: """Check if this plugin can remediate the specified rule""" pass - + @abstractmethod - async def execute_remediation(self, rule_id: str, host_config: Dict, - scan_result: ScanResult) -> Dict: + async def execute_remediation( + self, rule_id: str, host_config: Dict, scan_result: ScanResult + ) -> Dict: """Execute remediation for a failed rule""" pass - + @abstractmethod - async def get_remediation_plan(self, failed_rules: List[str], - host_config: Dict) -> Dict: + async def get_remediation_plan(self, failed_rules: List[str], host_config: Dict) -> Dict: """Get remediation plan for multiple failed rules""" pass - - async def validate_remediation(self, rule_id: str, host_config: Dict) -> Dict: + + def validate_remediation(self, rule_id: str, host_config: Dict) -> Dict: """Validate that remediation was successful""" return {"status": "unknown", "validated": False} class IntegrationPlugin(PluginInterface): """Interface for external system integration plugins""" - + @abstractmethod - async def export_results(self, scan_results: List[ScanResult], - destination_config: Dict) -> bool: + async def export_results( + self, scan_results: List[ScanResult], destination_config: Dict + ) -> bool: """Export scan results to external system""" pass - + @abstractmethod async def import_content(self, source_config: Dict) -> Optional[str]: """Import SCAP content from external source""" pass - - async def sync_hosts(self, source_config: Dict) -> List[Dict]: + + def sync_hosts(self, source_config: Dict) -> List[Dict]: """Synchronize host inventory from external system""" return [] class ContentPlugin(PluginInterface): """Interface for SCAP content provider plugins""" - + @abstractmethod async def fetch_content(self, content_id: str, version: str = "latest") -> str: """Fetch SCAP content by identifier""" pass - + @abstractmethod async def list_available_content(self) -> List[Dict]: """List available SCAP content from this provider""" pass - + @abstractmethod async def validate_content_integrity(self, content_path: str) -> bool: """Validate content integrity and authenticity""" pass - - async def get_content_metadata(self, content_id: str) -> Dict: + + def get_content_metadata(self, content_id: str) -> Dict: """Get metadata for specific content""" return {} class AuthenticationPlugin(PluginInterface): """Interface for authentication provider plugins""" - + @abstractmethod async def authenticate_user(self, credentials: Dict) -> Optional[Dict]: """Authenticate user and return user info if successful""" pass - + @abstractmethod - async def authorize_action(self, user_info: Dict, action: str, - resource: str) -> bool: + async def authorize_action(self, user_info: Dict, action: str, resource: str) -> bool: """Check if user is authorized for specific action on resource""" pass - - async def get_user_groups(self, user_info: Dict) -> List[str]: + + def get_user_groups(self, user_info: Dict) -> List[str]: """Get list of groups for authenticated user""" return [] class NotificationPlugin(PluginInterface): """Interface for notification service plugins""" - + @abstractmethod - async def send_notification(self, message: str, recipients: List[str], - notification_type: str = "info") -> bool: + async def send_notification( + self, message: str, recipients: List[str], notification_type: str = "info" + ) -> bool: """Send notification message""" pass - + @abstractmethod def get_supported_types(self) -> List[str]: """Get supported notification types""" pass - - async def validate_recipients(self, recipients: List[str]) -> List[str]: + + def validate_recipients(self, recipients: List[str]) -> List[str]: """Validate and return valid recipients""" return recipients @@ -274,25 +281,25 @@ async def validate_recipients(self, recipients: List[str]) -> List[str]: # Plugin Hook Definitions class PluginHooks: """Defines available plugin hooks in the OpenWatch system""" - + # Scan lifecycle hooks BEFORE_SCAN = "before_scan" AFTER_SCAN = "after_scan" SCAN_FAILED = "scan_failed" - + # Report generation hooks BEFORE_REPORT = "before_report" AFTER_REPORT = "after_report" - + # Host management hooks HOST_ADDED = "host_added" HOST_REMOVED = "host_removed" HOST_UPDATED = "host_updated" - + # System hooks SYSTEM_STARTUP = "system_startup" SYSTEM_SHUTDOWN = "system_shutdown" - + # Security hooks LOGIN_SUCCESS = "login_success" LOGIN_FAILED = "login_failed" @@ -302,12 +309,13 @@ class PluginHooks: @dataclass class PluginHookContext: """Context passed to plugin hooks""" + hook_name: str timestamp: str data: Dict user_id: Optional[str] = None session_id: Optional[str] = None - + def __post_init__(self): if self.data is None: self.data = {} @@ -315,30 +323,30 @@ def __post_init__(self): class HookablePlugin(PluginInterface): """Base class for plugins that can register hooks""" - + def __init__(self, config: Dict = None): super().__init__(config) self.registered_hooks: List[str] = [] - + @abstractmethod async def handle_hook(self, context: PluginHookContext) -> Optional[Dict]: """Handle a plugin hook""" pass - + def register_hook(self, hook_name: str): """Register interest in a specific hook""" if hook_name not in self.registered_hooks: self.registered_hooks.append(hook_name) - + def get_registered_hooks(self) -> List[str]: """Get list of registered hooks""" return self.registered_hooks.copy() # Utility functions for plugin development -def create_plugin_metadata(name: str, version: str, description: str, - author: str, plugin_type: PluginType, - **kwargs) -> PluginMetadata: +def create_plugin_metadata( + name: str, version: str, description: str, author: str, plugin_type: PluginType, **kwargs +) -> PluginMetadata: """Utility function to create plugin metadata""" return PluginMetadata( name=name, @@ -346,12 +354,13 @@ def create_plugin_metadata(name: str, version: str, description: str, description=description, author=author, plugin_type=plugin_type, - **kwargs + **kwargs, ) -def create_scan_context(scan_id: str, profile_id: str, content_path: str, - target_host: str, scan_type: str, **kwargs) -> ScanContext: +def create_scan_context( + scan_id: str, profile_id: str, content_path: str, target_host: str, scan_type: str, **kwargs +) -> ScanContext: """Utility function to create scan context""" return ScanContext( scan_id=scan_id, @@ -359,17 +368,14 @@ def create_scan_context(scan_id: str, profile_id: str, content_path: str, content_path=content_path, target_host=target_host, scan_type=scan_type, - **kwargs + **kwargs, ) -def create_scan_result(scan_id: str, hostname: str, status: str, - timestamp: str, **kwargs) -> ScanResult: +def create_scan_result( + scan_id: str, hostname: str, status: str, timestamp: str, **kwargs +) -> ScanResult: """Utility function to create scan result""" return ScanResult( - scan_id=scan_id, - hostname=hostname, - status=status, - timestamp=timestamp, - **kwargs - ) \ No newline at end of file + scan_id=scan_id, hostname=hostname, status=status, timestamp=timestamp, **kwargs + ) diff --git a/backend/app/plugins/manager.py b/backend/app/plugins/manager.py index f9c7e71c..aa5842d6 100644 --- a/backend/app/plugins/manager.py +++ b/backend/app/plugins/manager.py @@ -2,6 +2,7 @@ OpenWatch Plugin Manager Handles plugin discovery, loading, lifecycle management, and hook execution """ + import os import sys import asyncio @@ -14,10 +15,19 @@ from datetime import datetime from .interface import ( - PluginInterface, PluginMetadata, PluginType, PluginHooks, - PluginHookContext, HookablePlugin, - ScannerPlugin, ReporterPlugin, RemediationPlugin, - IntegrationPlugin, ContentPlugin, AuthenticationPlugin, NotificationPlugin + PluginInterface, + PluginMetadata, + PluginType, + PluginHooks, + PluginHookContext, + HookablePlugin, + ScannerPlugin, + ReporterPlugin, + RemediationPlugin, + IntegrationPlugin, + ContentPlugin, + AuthenticationPlugin, + NotificationPlugin, ) logger = logging.getLogger(__name__) @@ -25,6 +35,7 @@ class PluginLoadError(Exception): """Exception raised when plugin loading fails""" + pass @@ -33,20 +44,19 @@ class PluginManager: Central plugin manager for OpenWatch Handles plugin discovery, loading, configuration, and execution """ - - def __init__(self, plugins_dir: str = "/app/plugins", - config_dir: str = "/app/config/plugins"): + + def __init__(self, plugins_dir: str = "/app/plugins", config_dir: str = "/app/config/plugins"): self.plugins_dir = Path(plugins_dir) self.config_dir = Path(config_dir) self.loaded_plugins: Dict[str, PluginInterface] = {} self.plugin_configs: Dict[str, Dict] = {} self.hook_registry: Dict[str, List[HookablePlugin]] = {} self.plugin_dependencies: Dict[str, List[str]] = {} - + # Ensure directories exist self.plugins_dir.mkdir(parents=True, exist_ok=True) self.config_dir.mkdir(parents=True, exist_ok=True) - + # Plugin type mapping self.plugin_type_map = { PluginType.SCANNER: ScannerPlugin, @@ -55,41 +65,41 @@ def __init__(self, plugins_dir: str = "/app/plugins", PluginType.INTEGRATION: IntegrationPlugin, PluginType.CONTENT: ContentPlugin, PluginType.AUTH: AuthenticationPlugin, - PluginType.NOTIFICATION: NotificationPlugin + PluginType.NOTIFICATION: NotificationPlugin, } - + async def initialize(self) -> bool: """Initialize the plugin manager and load all plugins""" try: logger.info("Initializing OpenWatch Plugin Manager") - + # Load plugin configurations await self._load_plugin_configs() - + # Discover and load plugins await self._discover_plugins() - + # Initialize all loaded plugins await self._initialize_plugins() - + # Register plugin hooks await self._register_plugin_hooks() - + logger.info(f"Plugin manager initialized with {len(self.loaded_plugins)} plugins") return True - + except Exception as e: logger.error(f"Failed to initialize plugin manager: {e}") return False - + async def shutdown(self) -> bool: """Shutdown the plugin manager and cleanup all plugins""" try: logger.info("Shutting down plugin manager") - + # Execute system shutdown hooks await self.execute_hook(PluginHooks.SYSTEM_SHUTDOWN, {}) - + # Cleanup all plugins for plugin_name, plugin in self.loaded_plugins.items(): try: @@ -97,70 +107,70 @@ async def shutdown(self) -> bool: logger.debug(f"Cleaned up plugin: {plugin_name}") except Exception as e: logger.error(f"Error cleaning up plugin {plugin_name}: {e}") - + self.loaded_plugins.clear() self.hook_registry.clear() - + logger.info("Plugin manager shutdown complete") return True - + except Exception as e: logger.error(f"Error during plugin manager shutdown: {e}") return False - + async def load_plugin(self, plugin_path: str, plugin_name: str = None) -> bool: """Load a single plugin from the specified path""" try: if not plugin_name: plugin_name = Path(plugin_path).stem - + logger.info(f"Loading plugin: {plugin_name} from {plugin_path}") - + # Load plugin module spec = importlib.util.spec_from_file_location(plugin_name, plugin_path) if not spec or not spec.loader: raise PluginLoadError(f"Cannot load plugin spec from {plugin_path}") - + module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) - + # Find plugin class plugin_class = self._find_plugin_class(module) if not plugin_class: raise PluginLoadError(f"No valid plugin class found in {plugin_path}") - + # Get plugin configuration plugin_config = self.plugin_configs.get(plugin_name, {}) - + # Instantiate plugin plugin = plugin_class(plugin_config) - + # Validate plugin if not await self._validate_plugin(plugin): raise PluginLoadError(f"Plugin validation failed: {plugin_name}") - + # Initialize plugin if not await plugin.initialize(): raise PluginLoadError(f"Plugin initialization failed: {plugin_name}") - + # Store plugin self.loaded_plugins[plugin_name] = plugin - + # Register hooks if applicable if isinstance(plugin, HookablePlugin): await self._register_plugin_hooks_for(plugin) - + logger.info(f"Successfully loaded plugin: {plugin_name}") return True - + except Exception as e: logger.error(f"Failed to load plugin {plugin_name}: {e}") return False - + def get_plugin(self, plugin_name: str) -> Optional[PluginInterface]: """Get a loaded plugin by name""" return self.loaded_plugins.get(plugin_name) - + def get_plugins_by_type(self, plugin_type: PluginType) -> List[PluginInterface]: """Get all loaded plugins of the specified type""" plugins = [] @@ -168,7 +178,7 @@ def get_plugins_by_type(self, plugin_type: PluginType) -> List[PluginInterface]: if plugin.get_metadata().plugin_type == plugin_type: plugins.append(plugin) return plugins - + def list_plugins(self) -> Dict[str, Dict]: """List all loaded plugins with their metadata""" plugin_list = {} @@ -180,11 +190,11 @@ def list_plugins(self) -> Dict[str, Dict]: "description": metadata.description, "author": metadata.author, "type": metadata.plugin_type.value, - "enabled": plugin.is_enabled() + "enabled": plugin.is_enabled(), } return plugin_list - - async def enable_plugin(self, plugin_name: str) -> bool: + + def enable_plugin(self, plugin_name: str) -> bool: """Enable a plugin""" plugin = self.get_plugin(plugin_name) if plugin: @@ -192,8 +202,8 @@ async def enable_plugin(self, plugin_name: str) -> bool: logger.info(f"Enabled plugin: {plugin_name}") return True return False - - async def disable_plugin(self, plugin_name: str) -> bool: + + def disable_plugin(self, plugin_name: str) -> bool: """Disable a plugin""" plugin = self.get_plugin(plugin_name) if plugin: @@ -201,43 +211,38 @@ async def disable_plugin(self, plugin_name: str) -> bool: logger.info(f"Disabled plugin: {plugin_name}") return True return False - - async def execute_hook(self, hook_name: str, data: Dict, - user_id: str = None, session_id: str = None) -> List[Dict]: + + async def execute_hook( + self, hook_name: str, data: Dict, user_id: str = None, session_id: str = None + ) -> List[Dict]: """Execute all registered hooks for the specified event""" results = [] - + if hook_name not in self.hook_registry: return results - + hook_context = PluginHookContext( hook_name=hook_name, timestamp=datetime.now().isoformat(), data=data, user_id=user_id, - session_id=session_id + session_id=session_id, ) - + for plugin in self.hook_registry[hook_name]: if not plugin.is_enabled(): continue - + try: result = await plugin.handle_hook(hook_context) if result: - results.append({ - "plugin": plugin.get_metadata().name, - "result": result - }) + results.append({"plugin": plugin.get_metadata().name, "result": result}) except Exception as e: logger.error(f"Hook execution failed for plugin {plugin.get_metadata().name}: {e}") - results.append({ - "plugin": plugin.get_metadata().name, - "error": str(e) - }) - + results.append({"plugin": plugin.get_metadata().name, "error": str(e)}) + return results - + async def health_check(self) -> Dict: """Perform health check on all plugins""" health_status = { @@ -245,123 +250,126 @@ async def health_check(self) -> Dict: "total_plugins": len(self.loaded_plugins), "enabled_plugins": 0, "disabled_plugins": 0, - "plugin_health": {} + "plugin_health": {}, } - + for name, plugin in self.loaded_plugins.items(): try: plugin_health = await plugin.health_check() health_status["plugin_health"][name] = plugin_health - + if plugin.is_enabled(): health_status["enabled_plugins"] += 1 else: health_status["disabled_plugins"] += 1 - + except Exception as e: - health_status["plugin_health"][name] = { - "status": "error", - "error": str(e) - } - + health_status["plugin_health"][name] = {"status": "error", "error": str(e)} + return health_status - + # Scanner Plugin Helpers async def find_compatible_scanner(self, host_config: Dict) -> Optional[ScannerPlugin]: """Find a scanner plugin that can handle the specified host""" scanners = self.get_plugins_by_type(PluginType.SCANNER) - + for scanner in scanners: if scanner.is_enabled() and await scanner.can_scan_host(host_config): return scanner - + return None - + # Reporter Plugin Helpers - async def generate_report(self, scan_results: List, format_type: str = "html") -> Optional[bytes]: + async def generate_report( + self, scan_results: List, format_type: str = "html" + ) -> Optional[bytes]: """Generate a report using available reporter plugins""" reporters = self.get_plugins_by_type(PluginType.REPORTER) - + for reporter in reporters: - if (reporter.is_enabled() and - format_type in reporter.get_supported_formats()): + if reporter.is_enabled() and format_type in reporter.get_supported_formats(): try: return await reporter.generate_report(scan_results, format_type) except Exception as e: - logger.error(f"Report generation failed with plugin {reporter.get_metadata().name}: {e}") - + logger.error( + f"Report generation failed with plugin {reporter.get_metadata().name}: {e}" + ) + return None - + # Remediation Plugin Helpers - async def find_remediation_plugins(self, rule_id: str, host_config: Dict) -> List[RemediationPlugin]: + async def find_remediation_plugins( + self, rule_id: str, host_config: Dict + ) -> List[RemediationPlugin]: """Find remediation plugins that can handle the specified rule""" remediation_plugins = self.get_plugins_by_type(PluginType.REMEDIATION) compatible_plugins = [] - + for plugin in remediation_plugins: - if (plugin.is_enabled() and - await plugin.can_remediate_rule(rule_id, host_config)): + if plugin.is_enabled() and await plugin.can_remediate_rule(rule_id, host_config): compatible_plugins.append(plugin) - + return compatible_plugins - + # Private methods async def _discover_plugins(self): """Discover plugins in the plugins directory""" logger.info(f"Discovering plugins in: {self.plugins_dir}") - + for plugin_dir in self.plugins_dir.iterdir(): - if plugin_dir.is_dir() and not plugin_dir.name.startswith('.'): + if plugin_dir.is_dir() and not plugin_dir.name.startswith("."): plugin_file = plugin_dir / "plugin.py" if plugin_file.exists(): await self.load_plugin(str(plugin_file), plugin_dir.name) - - async def _load_plugin_configs(self): + + def _load_plugin_configs(self): """Load plugin configurations from config directory""" for config_file in self.config_dir.glob("*.json"): try: - with open(config_file, 'r') as f: + with open(config_file, "r") as f: config = json.load(f) plugin_name = config_file.stem self.plugin_configs[plugin_name] = config logger.debug(f"Loaded config for plugin: {plugin_name}") except Exception as e: logger.error(f"Failed to load config for {config_file}: {e}") - + def _find_plugin_class(self, module) -> Optional[Type[PluginInterface]]: """Find the plugin class in the loaded module""" for attr_name in dir(module): attr = getattr(module, attr_name) - if (isinstance(attr, type) and - issubclass(attr, PluginInterface) and - attr != PluginInterface): + if ( + isinstance(attr, type) + and issubclass(attr, PluginInterface) + and attr != PluginInterface + ): return attr return None - - async def _validate_plugin(self, plugin: PluginInterface) -> bool: + + def _validate_plugin(self, plugin: PluginInterface) -> bool: """Validate a plugin meets requirements""" try: metadata = plugin.get_metadata() - + # Basic validation if not metadata.name or not metadata.version: return False - + # Check plugin type if metadata.plugin_type not in self.plugin_type_map: return False - + # Check if plugin implements required interface required_interface = self.plugin_type_map[metadata.plugin_type] if not isinstance(plugin, required_interface): return False - + return True - + except Exception as e: logger.error(f"Plugin validation error: {e}") return False - + async def _initialize_plugins(self): """Initialize all loaded plugins""" # Sort plugins by dependencies (simplified for now) @@ -371,14 +379,14 @@ async def _initialize_plugins(self): logger.error(f"Failed to initialize plugin: {plugin_name}") except Exception as e: logger.error(f"Error initializing plugin {plugin_name}: {e}") - + async def _register_plugin_hooks(self): """Register hooks for all hookable plugins""" for plugin in self.loaded_plugins.values(): if isinstance(plugin, HookablePlugin): await self._register_plugin_hooks_for(plugin) - - async def _register_plugin_hooks_for(self, plugin: HookablePlugin): + + def _register_plugin_hooks_for(self, plugin: HookablePlugin): """Register hooks for a specific plugin""" for hook_name in plugin.get_registered_hooks(): if hook_name not in self.hook_registry: @@ -408,4 +416,4 @@ async def initialize_plugin_system() -> bool: async def shutdown_plugin_system() -> bool: """Shutdown the global plugin system""" manager = get_plugin_manager() - return await manager.shutdown() \ No newline at end of file + return await manager.shutdown() diff --git a/backend/app/rbac.py b/backend/app/rbac.py index 7038d05b..5e1318cd 100644 --- a/backend/app/rbac.py +++ b/backend/app/rbac.py @@ -2,6 +2,7 @@ Role-Based Access Control (RBAC) System for OpenWatch Defines permissions, roles, and access control logic """ + from enum import Enum from typing import List, Dict, Set, Optional, Any from functools import wraps @@ -13,26 +14,27 @@ class Permission(str, Enum): """System permissions""" + # User Management USER_CREATE = "user:create" - USER_READ = "user:read" + USER_READ = "user:read" USER_UPDATE = "user:update" USER_DELETE = "user:delete" USER_MANAGE_ROLES = "user:manage_roles" - + # Host Management HOST_CREATE = "host:create" HOST_READ = "host:read" HOST_UPDATE = "host:update" HOST_DELETE = "host:delete" HOST_MANAGE_ACCESS = "host:manage_access" - + # SCAP Content Management CONTENT_CREATE = "content:create" CONTENT_READ = "content:read" CONTENT_UPDATE = "content:update" CONTENT_DELETE = "content:delete" - + # Scan Operations SCAN_CREATE = "scan:create" SCAN_READ = "scan:read" @@ -42,19 +44,19 @@ class Permission(str, Enum): SCAN_WRITE = "scan:write" SCAN_APPROVE = "scan:approve" SCAN_ROLLBACK = "scan:rollback" - + # Results and Reports RESULTS_READ = "results:read" RESULTS_READ_ALL = "results:read_all" REPORTS_GENERATE = "reports:generate" REPORTS_EXPORT = "reports:export" - + # System Administration SYSTEM_CONFIG = "system:config" SYSTEM_CREDENTIALS = "system:credentials" SYSTEM_LOGS = "system:logs" SYSTEM_MAINTENANCE = "system:maintenance" - + # Audit and Compliance AUDIT_READ = "audit:read" COMPLIANCE_VIEW = "compliance:view" @@ -63,6 +65,7 @@ class Permission(str, Enum): class UserRole(str, Enum): """User roles in the system""" + SUPER_ADMIN = "super_admin" SECURITY_ADMIN = "security_admin" SECURITY_ANALYST = "security_analyst" @@ -75,99 +78,143 @@ class UserRole(str, Enum): ROLE_PERMISSIONS: Dict[UserRole, List[Permission]] = { UserRole.SUPER_ADMIN: [ # All permissions - super admin has full access - Permission.USER_CREATE, Permission.USER_READ, Permission.USER_UPDATE, - Permission.USER_DELETE, Permission.USER_MANAGE_ROLES, - Permission.HOST_CREATE, Permission.HOST_READ, Permission.HOST_UPDATE, - Permission.HOST_DELETE, Permission.HOST_MANAGE_ACCESS, - Permission.CONTENT_CREATE, Permission.CONTENT_READ, Permission.CONTENT_UPDATE, + Permission.USER_CREATE, + Permission.USER_READ, + Permission.USER_UPDATE, + Permission.USER_DELETE, + Permission.USER_MANAGE_ROLES, + Permission.HOST_CREATE, + Permission.HOST_READ, + Permission.HOST_UPDATE, + Permission.HOST_DELETE, + Permission.HOST_MANAGE_ACCESS, + Permission.CONTENT_CREATE, + Permission.CONTENT_READ, + Permission.CONTENT_UPDATE, Permission.CONTENT_DELETE, - Permission.SCAN_CREATE, Permission.SCAN_READ, Permission.SCAN_UPDATE, - Permission.SCAN_DELETE, Permission.SCAN_EXECUTE, Permission.SCAN_WRITE, - Permission.SCAN_APPROVE, Permission.SCAN_ROLLBACK, - Permission.RESULTS_READ, Permission.RESULTS_READ_ALL, Permission.REPORTS_GENERATE, + Permission.SCAN_CREATE, + Permission.SCAN_READ, + Permission.SCAN_UPDATE, + Permission.SCAN_DELETE, + Permission.SCAN_EXECUTE, + Permission.SCAN_WRITE, + Permission.SCAN_APPROVE, + Permission.SCAN_ROLLBACK, + Permission.RESULTS_READ, + Permission.RESULTS_READ_ALL, + Permission.REPORTS_GENERATE, Permission.REPORTS_EXPORT, - Permission.SYSTEM_CONFIG, Permission.SYSTEM_CREDENTIALS, Permission.SYSTEM_LOGS, + Permission.SYSTEM_CONFIG, + Permission.SYSTEM_CREDENTIALS, + Permission.SYSTEM_LOGS, Permission.SYSTEM_MAINTENANCE, - Permission.AUDIT_READ, Permission.COMPLIANCE_VIEW, Permission.COMPLIANCE_EXPORT + Permission.AUDIT_READ, + Permission.COMPLIANCE_VIEW, + Permission.COMPLIANCE_EXPORT, ], - UserRole.SECURITY_ADMIN: [ # Security-focused administration Permission.USER_READ, # Can view users but not create/delete - Permission.HOST_CREATE, Permission.HOST_READ, Permission.HOST_UPDATE, - Permission.HOST_DELETE, Permission.HOST_MANAGE_ACCESS, - Permission.CONTENT_CREATE, Permission.CONTENT_READ, Permission.CONTENT_UPDATE, + Permission.HOST_CREATE, + Permission.HOST_READ, + Permission.HOST_UPDATE, + Permission.HOST_DELETE, + Permission.HOST_MANAGE_ACCESS, + Permission.CONTENT_CREATE, + Permission.CONTENT_READ, + Permission.CONTENT_UPDATE, Permission.CONTENT_DELETE, - Permission.SCAN_CREATE, Permission.SCAN_READ, Permission.SCAN_UPDATE, - Permission.SCAN_DELETE, Permission.SCAN_EXECUTE, Permission.SCAN_WRITE, - Permission.SCAN_APPROVE, Permission.SCAN_ROLLBACK, - Permission.RESULTS_READ, Permission.RESULTS_READ_ALL, Permission.REPORTS_GENERATE, + Permission.SCAN_CREATE, + Permission.SCAN_READ, + Permission.SCAN_UPDATE, + Permission.SCAN_DELETE, + Permission.SCAN_EXECUTE, + Permission.SCAN_WRITE, + Permission.SCAN_APPROVE, + Permission.SCAN_ROLLBACK, + Permission.RESULTS_READ, + Permission.RESULTS_READ_ALL, + Permission.REPORTS_GENERATE, Permission.REPORTS_EXPORT, Permission.SYSTEM_LOGS, # Can view system logs - Permission.AUDIT_READ, Permission.COMPLIANCE_VIEW, Permission.COMPLIANCE_EXPORT + Permission.AUDIT_READ, + Permission.COMPLIANCE_VIEW, + Permission.COMPLIANCE_EXPORT, ], - UserRole.SECURITY_ANALYST: [ # Day-to-day security operations - Permission.HOST_READ, Permission.HOST_UPDATE, # Can manage assigned hosts + Permission.HOST_READ, + Permission.HOST_UPDATE, # Can manage assigned hosts Permission.CONTENT_READ, # Read-only SCAP content - Permission.SCAN_CREATE, Permission.SCAN_READ, Permission.SCAN_EXECUTE, Permission.SCAN_WRITE, - Permission.RESULTS_READ, Permission.REPORTS_GENERATE, Permission.REPORTS_EXPORT, - Permission.COMPLIANCE_VIEW + Permission.SCAN_CREATE, + Permission.SCAN_READ, + Permission.SCAN_EXECUTE, + Permission.SCAN_WRITE, + Permission.RESULTS_READ, + Permission.REPORTS_GENERATE, + Permission.REPORTS_EXPORT, + Permission.COMPLIANCE_VIEW, ], - UserRole.COMPLIANCE_OFFICER: [ # Compliance and reporting focus Permission.HOST_READ, # Read-only host access Permission.CONTENT_READ, # Read-only SCAP content Permission.SCAN_READ, # Read-only scan access - Permission.RESULTS_READ, Permission.RESULTS_READ_ALL, Permission.REPORTS_GENERATE, + Permission.RESULTS_READ, + Permission.RESULTS_READ_ALL, + Permission.REPORTS_GENERATE, Permission.REPORTS_EXPORT, - Permission.AUDIT_READ, Permission.COMPLIANCE_VIEW, Permission.COMPLIANCE_EXPORT + Permission.AUDIT_READ, + Permission.COMPLIANCE_VIEW, + Permission.COMPLIANCE_EXPORT, ], - UserRole.AUDITOR: [ # External audit support - Permission.HOST_READ, Permission.CONTENT_READ, Permission.SCAN_READ, - Permission.RESULTS_READ, Permission.RESULTS_READ_ALL, Permission.REPORTS_EXPORT, - Permission.AUDIT_READ, Permission.COMPLIANCE_VIEW, Permission.COMPLIANCE_EXPORT + Permission.HOST_READ, + Permission.CONTENT_READ, + Permission.SCAN_READ, + Permission.RESULTS_READ, + Permission.RESULTS_READ_ALL, + Permission.REPORTS_EXPORT, + Permission.AUDIT_READ, + Permission.COMPLIANCE_VIEW, + Permission.COMPLIANCE_EXPORT, ], - UserRole.GUEST: [ # Very limited access Permission.HOST_READ, # Read-only access to assigned hosts Permission.RESULTS_READ, # Read-only access to assigned results - Permission.COMPLIANCE_VIEW # Basic compliance viewing - ] + Permission.COMPLIANCE_VIEW, # Basic compliance viewing + ], } class RBACManager: """Role-Based Access Control Manager""" - + @staticmethod def get_role_permissions(role: UserRole) -> Set[Permission]: """Get all permissions for a role""" return set(ROLE_PERMISSIONS.get(role, [])) - + @staticmethod def has_permission(user_role: UserRole, required_permission: Permission) -> bool: """Check if a role has a specific permission""" role_permissions = RBACManager.get_role_permissions(user_role) return required_permission in role_permissions - + @staticmethod def has_any_permission(user_role: UserRole, required_permissions: List[Permission]) -> bool: """Check if a role has any of the required permissions""" role_permissions = RBACManager.get_role_permissions(user_role) return any(perm in role_permissions for perm in required_permissions) - + @staticmethod def has_all_permissions(user_role: UserRole, required_permissions: List[Permission]) -> bool: """Check if a role has all required permissions""" role_permissions = RBACManager.get_role_permissions(user_role) return all(perm in role_permissions for perm in required_permissions) - + @staticmethod def can_access_resource(user_role: UserRole, resource_type: str, action: str) -> bool: """Check if a role can perform an action on a resource type""" @@ -176,112 +223,122 @@ def can_access_resource(user_role: UserRole, resource_type: str, action: str) -> "create": Permission.USER_CREATE, "read": Permission.USER_READ, "update": Permission.USER_UPDATE, - "delete": Permission.USER_DELETE + "delete": Permission.USER_DELETE, }, "host": { "create": Permission.HOST_CREATE, "read": Permission.HOST_READ, "update": Permission.HOST_UPDATE, - "delete": Permission.HOST_DELETE + "delete": Permission.HOST_DELETE, }, "scan": { "create": Permission.SCAN_CREATE, "read": Permission.SCAN_READ, "update": Permission.SCAN_UPDATE, "delete": Permission.SCAN_DELETE, - "execute": Permission.SCAN_EXECUTE + "execute": Permission.SCAN_EXECUTE, }, "content": { "create": Permission.CONTENT_CREATE, "read": Permission.CONTENT_READ, "update": Permission.CONTENT_UPDATE, - "delete": Permission.CONTENT_DELETE + "delete": Permission.CONTENT_DELETE, }, - "audit": { - "read": Permission.AUDIT_READ - } + "audit": {"read": Permission.AUDIT_READ}, } - + if resource_type not in permission_map or action not in permission_map[resource_type]: return False - + required_permission = permission_map[resource_type][action] return RBACManager.has_permission(user_role, required_permission) def require_permission(permission: Permission): """Decorator to require a specific permission""" + def decorator(func): @wraps(func) async def wrapper(*args, **kwargs): # Get current user from kwargs (injected by get_current_user dependency) - current_user = kwargs.get('current_user') + current_user = kwargs.get("current_user") if not current_user: raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Authentication required" + status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication required" ) - - user_role = UserRole(current_user.get('role', 'guest')) + + user_role = UserRole(current_user.get("role", "guest")) if not RBACManager.has_permission(user_role, permission): - logger.warning(f"User {current_user.get('username')} with role {user_role} attempted to access {permission}") + logger.warning( + f"User {current_user.get('username')} with role {user_role} attempted to access {permission}" + ) raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail=f"Insufficient permissions. Required: {permission.value}" + detail=f"Insufficient permissions. Required: {permission.value}", ) - + return await func(*args, **kwargs) + return wrapper + return decorator def require_any_permission(permissions: List[Permission]): """Decorator to require any of the specified permissions""" + def decorator(func): @wraps(func) async def wrapper(*args, **kwargs): - current_user = kwargs.get('current_user') + current_user = kwargs.get("current_user") if not current_user: raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Authentication required" + status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication required" ) - - user_role = UserRole(current_user.get('role', 'guest')) + + user_role = UserRole(current_user.get("role", "guest")) if not RBACManager.has_any_permission(user_role, permissions): - logger.warning(f"User {current_user.get('username')} with role {user_role} attempted to access {[p.value for p in permissions]}") + logger.warning( + f"User {current_user.get('username')} with role {user_role} attempted to access {[p.value for p in permissions]}" + ) raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail=f"Insufficient permissions. Required one of: {[p.value for p in permissions]}" + detail=f"Insufficient permissions. Required one of: {[p.value for p in permissions]}", ) - + return await func(*args, **kwargs) + return wrapper + return decorator def require_role(required_roles: List[UserRole]): """Decorator to require specific roles""" + def decorator(func): @wraps(func) async def wrapper(*args, **kwargs): - current_user = kwargs.get('current_user') + current_user = kwargs.get("current_user") if not current_user: raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Authentication required" + status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication required" ) - - user_role = UserRole(current_user.get('role', 'guest')) + + user_role = UserRole(current_user.get("role", "guest")) if user_role not in required_roles: - logger.warning(f"User {current_user.get('username')} with role {user_role} attempted to access endpoint requiring {[r.value for r in required_roles]}") + logger.warning( + f"User {current_user.get('username')} with role {user_role} attempted to access endpoint requiring {[r.value for r in required_roles]}" + ) raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail=f"Insufficient role. Required one of: {[r.value for r in required_roles]}" + detail=f"Insufficient role. Required one of: {[r.value for r in required_roles]}", ) - + return await func(*args, **kwargs) + return wrapper + return decorator @@ -298,16 +355,12 @@ def require_super_admin(): def require_analyst_or_above(): """Require analyst level or above""" - return require_role([ - UserRole.SUPER_ADMIN, - UserRole.SECURITY_ADMIN, - UserRole.SECURITY_ANALYST - ]) + return require_role([UserRole.SUPER_ADMIN, UserRole.SECURITY_ADMIN, UserRole.SECURITY_ANALYST]) def check_permission(user_role: str, resource_type: str, action: str): """Check if a user role has permission to perform an action on a resource. - + For API keys, we'll allow super_admin and security_admin to manage them. """ # Special handling for API keys @@ -316,28 +369,27 @@ def check_permission(user_role: str, resource_type: str, action: str): if UserRole(user_role) not in allowed_roles: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail=f"Only administrators can manage API keys" + detail=f"Only administrators can manage API keys", ) return - + # Use existing permission check for other resources if not RBACManager.can_access_resource(UserRole(user_role), resource_type, action): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail=f"Insufficient permissions to {action} {resource_type}" + detail=f"Insufficient permissions to {action} {resource_type}", ) -async def check_permission_async(current_user: dict, required_permission: Permission, db: Any = None): +def check_permission_async(current_user: dict, required_permission: Permission, db: Any = None): """Async permission check for specific permissions""" if not current_user: raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Authentication required" + status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication required" ) - + user_roles = current_user.get("roles", [current_user.get("role", "guest")]) - + # Check if any of the user's roles has the required permission for role_name in user_roles: try: @@ -347,10 +399,12 @@ async def check_permission_async(current_user: dict, required_permission: Permis except ValueError: logger.warning(f"Unknown role: {role_name}") continue - + # If no role has permission, log and raise error - logger.warning(f"User {current_user.get('username')} attempted to access {required_permission.value}") + logger.warning( + f"User {current_user.get('username')} attempted to access {required_permission.value}" + ) raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail=f"Insufficient permissions. Required: {required_permission.value}" - ) \ No newline at end of file + detail=f"Insufficient permissions. Required: {required_permission.value}", + ) diff --git a/backend/app/routes/__init__.py b/backend/app/routes/__init__.py index 54581e31..b24e4822 100644 --- a/backend/app/routes/__init__.py +++ b/backend/app/routes/__init__.py @@ -3,4 +3,4 @@ """ # Import compliance module to make it available -from . import compliance \ No newline at end of file +from . import compliance diff --git a/backend/app/routes/audit.py b/backend/app/routes/audit.py index 316c102b..421d2194 100644 --- a/backend/app/routes/audit.py +++ b/backend/app/routes/audit.py @@ -1,6 +1,7 @@ """ Audit Log API Routes for OView Dashboard """ + from fastapi import APIRouter, Depends, HTTPException, Query from sqlalchemy.orm import Session from sqlalchemy import text, desc, and_, or_ @@ -18,6 +19,7 @@ from pydantic import BaseModel + class AuditEventResponse(BaseModel): id: int user_id: Optional[int] @@ -31,12 +33,14 @@ class AuditEventResponse(BaseModel): timestamp: str severity: str + class AuditEventsResponse(BaseModel): events: List[AuditEventResponse] total: int page: int limit: int + class AuditStatsResponse(BaseModel): total_events: int login_attempts: int @@ -47,6 +51,7 @@ class AuditStatsResponse(BaseModel): unique_users: int unique_ips: int + @router.get("/events", response_model=AuditEventsResponse) async def get_audit_events( page: int = Query(1, ge=1), @@ -59,17 +64,19 @@ async def get_audit_events( date_from: Optional[datetime] = Query(None), date_to: Optional[datetime] = Query(None), db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """ Get audit events with filtering and pagination """ try: # Check permissions - user_role = UserRole(current_user.get('role', 'guest')) - if not RBACManager.can_access_resource(user_role, 'audit', 'read'): - raise HTTPException(status_code=403, detail="Insufficient permissions to view audit logs") - + user_role = UserRole(current_user.get("role", "guest")) + if not RBACManager.can_access_resource(user_role, "audit", "read"): + raise HTTPException( + status_code=403, detail="Insufficient permissions to view audit logs" + ) + # Build base query query = """ SELECT al.*, u.username @@ -77,107 +84,109 @@ async def get_audit_events( LEFT JOIN users u ON al.user_id = u.id WHERE 1=1 """ - + params = {} - + # Add filters if search: query += " AND (al.action ILIKE :search OR al.details ILIKE :search OR al.ip_address ILIKE :search OR u.username ILIKE :search)" - params['search'] = f"%{search}%" - + params["search"] = f"%{search}%" + if action: query += " AND al.action ILIKE :action" - params['action'] = f"%{action}%" - + params["action"] = f"%{action}%" + if resource_type: query += " AND al.resource_type = :resource_type" - params['resource_type'] = resource_type - + params["resource_type"] = resource_type + if user: query += " AND u.username ILIKE :user" - params['user'] = f"%{user}%" - + params["user"] = f"%{user}%" + if date_from: query += " AND al.timestamp >= :date_from" - params['date_from'] = date_from - + params["date_from"] = date_from + if date_to: query += " AND al.timestamp <= :date_to" - params['date_to'] = date_to - + params["date_to"] = date_to + # Get total count count_query = f"SELECT COUNT(*) as total FROM ({query}) as subquery" count_result = db.execute(text(count_query), params) total = count_result.fetchone().total - + # Add ordering and pagination query += " ORDER BY al.timestamp DESC" query += " LIMIT :limit OFFSET :offset" - params['limit'] = limit - params['offset'] = (page - 1) * limit - + params["limit"] = limit + params["offset"] = (page - 1) * limit + # Execute query result = db.execute(text(query), params) - + events = [] for row in result: # Determine severity based on action - severity = 'info' - if 'FAILED' in row.action or 'ERROR' in row.action: - severity = 'error' - elif 'SECURITY' in row.action or 'UNAUTHORIZED' in row.action: - severity = 'warning' - elif 'ADMIN' in row.action or 'DELETE' in row.action: - severity = 'warning' - - events.append(AuditEventResponse( - id=row.id, - user_id=row.user_id, - username=row.username, - action=row.action, - resource_type=row.resource_type, - resource_id=row.resource_id, - ip_address=row.ip_address, - user_agent=row.user_agent, - details=row.details, - timestamp=row.timestamp.isoformat() if row.timestamp else None, - severity=severity - )) - - return AuditEventsResponse( - events=events, - total=total, - page=page, - limit=limit - ) - + severity = "info" + if "FAILED" in row.action or "ERROR" in row.action: + severity = "error" + elif "SECURITY" in row.action or "UNAUTHORIZED" in row.action: + severity = "warning" + elif "ADMIN" in row.action or "DELETE" in row.action: + severity = "warning" + + events.append( + AuditEventResponse( + id=row.id, + user_id=row.user_id, + username=row.username, + action=row.action, + resource_type=row.resource_type, + resource_id=row.resource_id, + ip_address=row.ip_address, + user_agent=row.user_agent, + details=row.details, + timestamp=row.timestamp.isoformat() if row.timestamp else None, + severity=severity, + ) + ) + + return AuditEventsResponse(events=events, total=total, page=page, limit=limit) + except HTTPException: raise except Exception as e: logger.error(f"Error retrieving audit events: {e}") raise HTTPException(status_code=500, detail="Failed to retrieve audit events") + @router.get("/stats", response_model=AuditStatsResponse) async def get_audit_stats( days: int = Query(30, ge=1, le=365), db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """ Get audit statistics for the dashboard """ try: # Check permissions - user_role = UserRole(current_user.get('role', 'guest')) - if not RBACManager.can_access_resource(user_role, 'audit', 'read'): - raise HTTPException(status_code=403, detail="Insufficient permissions to view audit logs") - + user_role = UserRole(current_user.get("role", "guest")) + if not RBACManager.can_access_resource(user_role, "audit", "read"): + raise HTTPException( + status_code=403, detail="Insufficient permissions to view audit logs" + ) + # Calculate date range from datetime import datetime, timedelta + date_from = datetime.utcnow() - timedelta(days=days) - + # Get statistics - stats_query = text(""" + stats_query = text( + """ SELECT COUNT(*) as total_events, COUNT(CASE WHEN action LIKE '%LOGIN%' THEN 1 END) as login_attempts, @@ -189,11 +198,12 @@ async def get_audit_stats( COUNT(DISTINCT ip_address) as unique_ips FROM audit_logs WHERE timestamp >= :date_from - """) - + """ + ) + result = db.execute(stats_query, {"date_from": date_from}) row = result.fetchone() - + return AuditStatsResponse( total_events=row.total_events or 0, login_attempts=row.login_attempts or 0, @@ -202,15 +212,16 @@ async def get_audit_stats( admin_actions=row.admin_actions or 0, security_events=row.security_events or 0, unique_users=row.unique_users or 0, - unique_ips=row.unique_ips or 0 + unique_ips=row.unique_ips or 0, ) - + except HTTPException: raise except Exception as e: logger.error(f"Error retrieving audit stats: {e}") raise HTTPException(status_code=500, detail="Failed to retrieve audit statistics") + @router.post("/log") async def create_audit_log( action: str, @@ -218,41 +229,47 @@ async def create_audit_log( resource_id: Optional[str] = None, details: Optional[str] = None, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """ Create a new audit log entry (for internal use) """ try: from datetime import datetime - + # This would typically be called internally by the system # For now, we'll create a simple log entry - insert_query = text(""" + insert_query = text( + """ INSERT INTO audit_logs (user_id, action, resource_type, resource_id, ip_address, details, timestamp) VALUES (:user_id, :action, :resource_type, :resource_id, :ip_address, :details, :timestamp) - """) - - db.execute(insert_query, { - "user_id": current_user.get('id'), - "action": action, - "resource_type": resource_type, - "resource_id": resource_id, - "ip_address": "127.0.0.1", # This should come from request - "details": details, - "timestamp": datetime.utcnow() - }) - + """ + ) + + db.execute( + insert_query, + { + "user_id": current_user.get("id"), + "action": action, + "resource_type": resource_type, + "resource_id": resource_id, + "ip_address": "127.0.0.1", # This should come from request + "details": details, + "timestamp": datetime.utcnow(), + }, + ) + db.commit() return {"message": "Audit log created successfully"} - + except Exception as e: logger.error(f"Error creating audit log: {e}") db.rollback() raise HTTPException(status_code=500, detail="Failed to create audit log") + # Helper function to create audit logs from middleware -async def log_audit_event( +def log_audit_event( db: Session, user_id: Optional[int], action: str, @@ -260,30 +277,35 @@ async def log_audit_event( resource_id: Optional[str], ip_address: str, user_agent: Optional[str], - details: Optional[str] + details: Optional[str], ): """ Helper function to create audit log entries from middleware """ try: - insert_query = text(""" + insert_query = text( + """ INSERT INTO audit_logs (user_id, action, resource_type, resource_id, ip_address, user_agent, details, timestamp) VALUES (:user_id, :action, :resource_type, :resource_id, :ip_address, :user_agent, :details, :timestamp) - """) - - db.execute(insert_query, { - "user_id": user_id, - "action": action, - "resource_type": resource_type, - "resource_id": resource_id, - "ip_address": ip_address, - "user_agent": user_agent, - "details": details, - "timestamp": datetime.utcnow() - }) - + """ + ) + + db.execute( + insert_query, + { + "user_id": user_id, + "action": action, + "resource_type": resource_type, + "resource_id": resource_id, + "ip_address": ip_address, + "user_agent": user_agent, + "details": details, + "timestamp": datetime.utcnow(), + }, + ) + db.commit() - + except Exception as e: logger.error(f"Error creating audit log entry: {e}") - db.rollback() \ No newline at end of file + db.rollback() diff --git a/backend/app/routes/auth.py b/backend/app/routes/auth.py index dc080e3f..79727ace 100644 --- a/backend/app/routes/auth.py +++ b/backend/app/routes/auth.py @@ -1,6 +1,7 @@ """ Authentication Routes - FIPS Compliant """ + from fastapi import APIRouter, HTTPException, Depends, status, Request from fastapi.security import HTTPBearer from pydantic import BaseModel, EmailStr @@ -60,23 +61,28 @@ async def login(request: LoginRequest, http_request: Request, db: Session = Depe """Authenticate user with username/password and optional MFA""" client_ip = get_client_ip(http_request) user_agent = http_request.headers.get("user-agent") - + try: # Get user from database - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT id, username, email, hashed_password, role, is_active, failed_login_attempts, locked_until, last_login FROM users WHERE username = :username - """), {"username": request.username}) - + """ + ), + {"username": request.username}, + ) + user = result.fetchone() if not user: # Log to file and database audit_logger.log_security_event( - "AUTH_FAILURE", - f"Login attempt with non-existent username: {request.username}", - client_ip + "AUTH_FAILURE", + f"Login attempt with non-existent username: {request.username}", + client_ip, ) await log_login_event( db=db, @@ -85,19 +91,18 @@ async def login(request: LoginRequest, http_request: Request, db: Session = Depe success=False, ip_address=client_ip, user_agent=user_agent, - failure_reason="Non-existent username" + failure_reason="Non-existent username", ) raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid username or password" + status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid username or password" ) - + # Check if user is active if not user.is_active: audit_logger.log_security_event( - "AUTH_FAILURE", - f"Login attempt with inactive account: {request.username}", - client_ip + "AUTH_FAILURE", + f"Login attempt with inactive account: {request.username}", + client_ip, ) await log_login_event( db=db, @@ -106,19 +111,16 @@ async def login(request: LoginRequest, http_request: Request, db: Session = Depe success=False, ip_address=client_ip, user_agent=user_agent, - failure_reason="Account deactivated" + failure_reason="Account deactivated", ) raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Account is deactivated" + status_code=status.HTTP_401_UNAUTHORIZED, detail="Account is deactivated" ) - + # Check if account is locked if user.locked_until and user.locked_until > datetime.utcnow(): audit_logger.log_security_event( - "AUTH_FAILURE", - f"Login attempt with locked account: {request.username}", - client_ip + "AUTH_FAILURE", f"Login attempt with locked account: {request.username}", client_ip ) await log_login_event( db=db, @@ -127,38 +129,38 @@ async def login(request: LoginRequest, http_request: Request, db: Session = Depe success=False, ip_address=client_ip, user_agent=user_agent, - failure_reason="Account locked" + failure_reason="Account locked", ) raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Account is temporarily locked" + status_code=status.HTTP_401_UNAUTHORIZED, detail="Account is temporarily locked" ) - + # Verify password if not pwd_context.verify(request.password, user.hashed_password): # Increment failed login attempts failed_attempts = user.failed_login_attempts + 1 locked_until = None - + # Lock account after 5 failed attempts for 30 minutes if failed_attempts >= 5: locked_until = datetime.utcnow() + timedelta(minutes=30) - - db.execute(text(""" + + db.execute( + text( + """ UPDATE users SET failed_login_attempts = :attempts, locked_until = :locked_until WHERE id = :user_id - """), { - "attempts": failed_attempts, - "locked_until": locked_until, - "user_id": user.id - }) + """ + ), + {"attempts": failed_attempts, "locked_until": locked_until, "user_id": user.id}, + ) db.commit() - + audit_logger.log_security_event( - "AUTH_FAILURE", - f"Invalid password for user: {request.username} (attempt {failed_attempts})", - client_ip + "AUTH_FAILURE", + f"Invalid password for user: {request.username} (attempt {failed_attempts})", + client_ip, ) await log_login_event( db=db, @@ -167,41 +169,43 @@ async def login(request: LoginRequest, http_request: Request, db: Session = Depe success=False, ip_address=client_ip, user_agent=user_agent, - failure_reason=f"Invalid password (attempt {failed_attempts})" + failure_reason=f"Invalid password (attempt {failed_attempts})", ) raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid username or password" + status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid username or password" ) - + # Skip MFA for now since columns don't exist in current schema # TODO: Add MFA support after running proper migrations - + # Reset failed login attempts and update last login - db.execute(text(""" + db.execute( + text( + """ UPDATE users SET failed_login_attempts = 0, locked_until = NULL, last_login = CURRENT_TIMESTAMP WHERE id = :user_id - """), {"user_id": user.id}) + """ + ), + {"user_id": user.id}, + ) db.commit() - + user_data = { "sub": user.username, # Standard JWT subject field "id": user.id, "username": user.username, "email": user.email, "role": user.role, - "mfa_enabled": False # MFA not available in current schema + "mfa_enabled": False, # MFA not available in current schema } - + # Generate tokens access_token = jwt_manager.create_access_token(user_data) refresh_token = jwt_manager.create_refresh_token(user_data) - + audit_logger.log_security_event( - "LOGIN_SUCCESS", - f"User {request.username} logged in successfully", - client_ip + "LOGIN_SUCCESS", f"User {request.username} logged in successfully", client_ip ) await log_login_event( db=db, @@ -209,16 +213,16 @@ async def login(request: LoginRequest, http_request: Request, db: Session = Depe user_id=user.id, success=True, ip_address=client_ip, - user_agent=user_agent + user_agent=user_agent, ) - + return LoginResponse( access_token=access_token, refresh_token=refresh_token, expires_in=settings.access_token_expire_minutes * 60, - user=user_data + user=user_data, ) - + except HTTPException: # Re-raise HTTP exceptions (already logged above) raise @@ -227,7 +231,7 @@ async def login(request: LoginRequest, http_request: Request, db: Session = Depe audit_logger.log_security_event( "LOGIN_FAILURE", f"System error during login for {request.username}: {str(e)}", - client_ip + client_ip, ) await log_login_event( db=db, @@ -236,12 +240,9 @@ async def login(request: LoginRequest, http_request: Request, db: Session = Depe success=False, ip_address=client_ip, user_agent=user_agent, - failure_reason=f"System error: {str(e)}" - ) - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid credentials" + failure_reason=f"System error: {str(e)}", ) + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials") @router.post("/register", response_model=LoginResponse) @@ -249,68 +250,74 @@ async def register(request: RegisterRequest, db: Session = Depends(get_db)): """Register a new user (guest role by default)""" try: # Check if username or email already exists - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT id FROM users WHERE username = :username OR email = :email - """), {"username": request.username, "email": request.email}) - + """ + ), + {"username": request.username, "email": request.email}, + ) + if result.fetchone(): raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Username or email already exists" + status_code=status.HTTP_400_BAD_REQUEST, detail="Username or email already exists" ) - + # Hash password hashed_password = pwd_context.hash(request.password) - + # Create user with guest role (or specified role if admin is creating) - result = db.execute(text(""" + result = db.execute( + text( + """ INSERT INTO users (username, email, hashed_password, role, is_active, created_at, failed_login_attempts) VALUES (:username, :email, :password, :role, true, CURRENT_TIMESTAMP, 0) RETURNING id - """), { - "username": request.username, - "email": request.email, - "password": hashed_password, - "role": request.role.value - }) - + """ + ), + { + "username": request.username, + "email": request.email, + "password": hashed_password, + "role": request.role.value, + }, + ) + user_id = result.fetchone().id db.commit() - + user_data = { "sub": request.username, # Standard JWT subject field "id": user_id, "username": request.username, "email": request.email, "role": request.role.value, - "mfa_enabled": False + "mfa_enabled": False, } - + # Generate tokens for immediate login access_token = jwt_manager.create_access_token(user_data) refresh_token = jwt_manager.create_refresh_token(user_data) - + audit_logger.log_security_event( - "USER_REGISTER", - f"New user registered: {request.username}", - "127.0.0.1" + "USER_REGISTER", f"New user registered: {request.username}", "127.0.0.1" ) - + return LoginResponse( access_token=access_token, refresh_token=refresh_token, expires_in=settings.access_token_expire_minutes * 60, - user=user_data + user=user_data, ) - + except HTTPException: raise except Exception as e: logger.error(f"Registration failed for {request.username}: {e}") db.rollback() raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Registration failed" + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Registration failed" ) @@ -320,29 +327,32 @@ async def refresh_token(request: RefreshRequest, db: Session = Depends(get_db)): try: # Validate refresh token and get user user_data = jwt_manager.validate_refresh_token(request.refresh_token) - + # Get fresh user data from database to ensure we have latest info username = user_data.get("sub") or user_data.get("username") if not username: raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid token data" + status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token data" ) - + # Get updated user info from database - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT id, username, email, role, is_active, mfa_enabled FROM users WHERE username = :username - """), {"username": username}) - + """ + ), + {"username": username}, + ) + user = result.fetchone() if not user or not user.is_active: raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="User not found or inactive" + status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found or inactive" ) - + # Create fresh user data for new token fresh_user_data = { "sub": user.username, @@ -350,32 +360,29 @@ async def refresh_token(request: RefreshRequest, db: Session = Depends(get_db)): "username": user.username, "email": user.email, "role": user.role, - "mfa_enabled": bool(user.mfa_enabled) + "mfa_enabled": bool(user.mfa_enabled), } - + # Generate new access token with fresh data access_token = jwt_manager.create_access_token(fresh_user_data) - + # Log the refresh event audit_logger.log_security_event( - "TOKEN_REFRESH", - f"Token refreshed for user {username}", - "system" + "TOKEN_REFRESH", f"Token refreshed for user {username}", "system" ) - + return { "access_token": access_token, "token_type": "bearer", - "expires_in": settings.access_token_expire_minutes * 60 + "expires_in": settings.access_token_expire_minutes * 60, } - + except HTTPException: raise except Exception as e: logger.error(f"Token refresh failed: {e}") raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid refresh token" + status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token" ) @@ -384,19 +391,14 @@ async def logout(token: str = Depends(security)): """Logout user and invalidate tokens""" try: # In production, add token to blacklist - audit_logger.log_security_event( - "LOGOUT", - "User logged out", - "127.0.0.1" - ) - + audit_logger.log_security_event("LOGOUT", "User logged out", "127.0.0.1") + return {"message": "Successfully logged out"} - + except Exception as e: logger.error(f"Logout failed: {e}") raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Logout failed" + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Logout failed" ) @@ -406,19 +408,14 @@ async def get_current_user(token: str = Depends(security)): try: # Validate token and get user data user_data = jwt_manager.validate_access_token(token.credentials) - + from ..rbac import RBACManager, UserRole - user_role = UserRole(user_data.get("role", "guest")) + + user_role = UserRole(user_data.get("role", "guest")) permissions = [p.value for p in RBACManager.get_role_permissions(user_role)] - - return { - "user": user_data, - "permissions": permissions - } - + + return {"user": user_data, "permissions": permissions} + except Exception as e: logger.error(f"Failed to get current user: {e}") - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid token" - ) \ No newline at end of file + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token") diff --git a/backend/app/routes/authorization.py b/backend/app/routes/authorization.py index 110621ad..26e4f408 100644 --- a/backend/app/routes/authorization.py +++ b/backend/app/routes/authorization.py @@ -9,6 +9,7 @@ Design by Emily (Security Engineer) & Implementation by Daniel (Backend Engineer) """ + import logging from typing import Dict, List, Optional, Set from datetime import datetime, timedelta @@ -22,8 +23,13 @@ from ..rbac import require_permission, require_admin, Permission from ..services.authorization_service import get_authorization_service, AuthorizationService from ..models.authorization_models import ( - ResourceType, ActionType, PermissionEffect, ResourceIdentifier, - AuthorizationContext, HostPermission, HostGroupPermission + ResourceType, + ActionType, + PermissionEffect, + ResourceIdentifier, + AuthorizationContext, + HostPermission, + HostGroupPermission, ) logger = logging.getLogger(__name__) @@ -33,14 +39,18 @@ # Request/Response Models + class PermissionGrantRequest(BaseModel): """Request to grant permission to a user/group/role""" + user_id: Optional[str] = None group_id: Optional[str] = None role_name: Optional[str] = None host_id: Optional[str] = None host_group_id: Optional[str] = None - actions: Set[str] = Field(..., description="List of actions: read, write, execute, delete, manage, scan") + actions: Set[str] = Field( + ..., description="List of actions: read, write, execute, delete, manage, scan" + ) effect: str = Field(default="allow", description="Permission effect: allow or deny") expires_at: Optional[datetime] = None conditions: Dict = Field(default_factory=dict) @@ -48,6 +58,7 @@ class PermissionGrantRequest(BaseModel): class PermissionResponse(BaseModel): """Response containing permission details""" + id: str user_id: Optional[str] group_id: Optional[str] @@ -64,6 +75,7 @@ class PermissionResponse(BaseModel): class PermissionCheckRequest(BaseModel): """Request to check permissions for resources""" + user_id: Optional[str] = None # If not provided, uses current user resource_type: str = Field(..., description="Resource type: host, host_group, scan, etc.") resource_id: str = Field(..., description="Resource identifier") @@ -72,6 +84,7 @@ class PermissionCheckRequest(BaseModel): class PermissionCheckResponse(BaseModel): """Response for permission check""" + allowed: bool decision: str reason: str @@ -82,6 +95,7 @@ class PermissionCheckResponse(BaseModel): class BulkPermissionCheckRequest(BaseModel): """Request for bulk permission checking""" + user_id: Optional[str] = None resources: List[Dict] = Field(..., description="List of {resource_type, resource_id} dicts") action: str @@ -90,6 +104,7 @@ class BulkPermissionCheckRequest(BaseModel): class BulkPermissionCheckResponse(BaseModel): """Response for bulk permission check""" + overall_allowed: bool allowed_resources: List[Dict] denied_resources: List[Dict] @@ -99,16 +114,17 @@ class BulkPermissionCheckResponse(BaseModel): # Permission Management Endpoints + @router.post("/permissions/host", response_model=Dict) @require_permission(Permission.HOST_MANAGE_ACCESS) async def grant_host_permission( request: PermissionGrantRequest, current_user: dict = Depends(get_current_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """ Grant permission for a specific host - + SECURITY REQUIREMENT: Only users with HOST_MANAGE_ACCESS permission can grant host-level permissions to other users. """ @@ -116,26 +132,26 @@ async def grant_host_permission( if not request.host_id: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="host_id is required for host permissions" + detail="host_id is required for host permissions", ) - + # Validate that exactly one subject is specified subject_count = sum(1 for x in [request.user_id, request.group_id, request.role_name] if x) if subject_count != 1: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Exactly one of user_id, group_id, or role_name must be specified" + detail="Exactly one of user_id, group_id, or role_name must be specified", ) - + # Validate actions valid_actions = {"read", "write", "execute", "delete", "manage", "scan", "export"} invalid_actions = set(request.actions) - valid_actions if invalid_actions: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Invalid actions: {invalid_actions}. Valid actions: {valid_actions}" + detail=f"Invalid actions: {invalid_actions}. Valid actions: {valid_actions}", ) - + # Convert string actions to ActionType enums action_types = set() action_map = { @@ -145,12 +161,12 @@ async def grant_host_permission( "delete": ActionType.DELETE, "manage": ActionType.MANAGE, "scan": ActionType.SCAN, - "export": ActionType.EXPORT + "export": ActionType.EXPORT, } - + for action in request.actions: action_types.add(action_map[action]) - + # Grant permission using authorization service auth_service = get_authorization_service(db) permission_id = await auth_service.grant_host_permission( @@ -161,26 +177,28 @@ async def grant_host_permission( actions=action_types, granted_by=current_user["id"], expires_at=request.expires_at, - conditions=request.conditions + conditions=request.conditions, + ) + + logger.info( + f"Host permission granted by {current_user['username']}: {permission_id} for host {request.host_id}" ) - - logger.info(f"Host permission granted by {current_user['username']}: {permission_id} for host {request.host_id}") - + return { "success": True, "permission_id": permission_id, "message": f"Permission granted for host {request.host_id}", "granted_by": current_user["username"], - "granted_at": datetime.utcnow().isoformat() + "granted_at": datetime.utcnow().isoformat(), } - + except HTTPException: raise except Exception as e: logger.error(f"Error granting host permission: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to grant permission: {str(e)}" + detail=f"Failed to grant permission: {str(e)}", ) @@ -189,7 +207,7 @@ async def grant_host_permission( async def revoke_permission( permission_id: str, current_user: dict = Depends(get_current_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """ Revoke a specific permission @@ -197,46 +215,46 @@ async def revoke_permission( try: auth_service = get_authorization_service(db) success = await auth_service.revoke_permission(permission_id) - + if not success: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail=f"Permission {permission_id} not found" + detail=f"Permission {permission_id} not found", ) - + logger.info(f"Permission {permission_id} revoked by {current_user['username']}") - + return { "success": True, "message": f"Permission {permission_id} revoked", "revoked_by": current_user["username"], - "revoked_at": datetime.utcnow().isoformat() + "revoked_at": datetime.utcnow().isoformat(), } - + except HTTPException: raise except Exception as e: logger.error(f"Error revoking permission: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to revoke permission: {str(e)}" + detail=f"Failed to revoke permission: {str(e)}", ) @router.get("/permissions/host/{host_id}") @require_permission(Permission.HOST_READ) async def get_host_permissions( - host_id: str, - current_user: dict = Depends(get_current_user), - db: Session = Depends(get_db) + host_id: str, current_user: dict = Depends(get_current_user), db: Session = Depends(get_db) ): """ Get all permissions for a specific host """ try: from sqlalchemy import text - - result = db.execute(text(""" + + result = db.execute( + text( + """ SELECT hp.id, hp.user_id, hp.group_id, hp.role_name, hp.host_id, hp.actions, hp.effect, hp.conditions, hp.granted_by, hp.granted_at, hp.expires_at, hp.is_active, @@ -249,52 +267,59 @@ async def get_host_permissions( LEFT JOIN user_groups ug ON hp.group_id = ug.id WHERE hp.host_id = :host_id AND hp.is_active = true ORDER BY hp.granted_at DESC - """), {"host_id": host_id}) - + """ + ), + {"host_id": host_id}, + ) + permissions = [] for row in result: import json + actions = json.loads(row.actions) if isinstance(row.actions, str) else row.actions - - permissions.append({ - "id": row.id, - "user_id": row.user_id, - "username": row.target_username, - "group_id": row.group_id, - "group_name": row.target_group_name, - "role_name": row.role_name, - "host_id": row.host_id, - "actions": actions, - "effect": row.effect, - "conditions": row.conditions, - "granted_by": row.granted_by, - "granted_by_username": row.granted_by_username, - "granted_at": row.granted_at.isoformat() if row.granted_at else None, - "expires_at": row.expires_at.isoformat() if row.expires_at else None, - "is_active": row.is_active - }) - + + permissions.append( + { + "id": row.id, + "user_id": row.user_id, + "username": row.target_username, + "group_id": row.group_id, + "group_name": row.target_group_name, + "role_name": row.role_name, + "host_id": row.host_id, + "actions": actions, + "effect": row.effect, + "conditions": row.conditions, + "granted_by": row.granted_by, + "granted_by_username": row.granted_by_username, + "granted_at": row.granted_at.isoformat() if row.granted_at else None, + "expires_at": row.expires_at.isoformat() if row.expires_at else None, + "is_active": row.is_active, + } + ) + return { "host_id": host_id, "permissions": permissions, - "total_permissions": len(permissions) + "total_permissions": len(permissions), } - + except Exception as e: logger.error(f"Error getting host permissions: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to get host permissions: {str(e)}" + detail=f"Failed to get host permissions: {str(e)}", ) # Permission Checking Endpoints + @router.post("/check", response_model=PermissionCheckResponse) async def check_permission( request: PermissionCheckRequest, current_user: dict = Depends(get_current_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """ Check if a user has permission to perform an action on a resource @@ -302,7 +327,7 @@ async def check_permission( try: # Use current user if no user_id specified user_id = request.user_id or current_user["id"] - + # Validate resource type and action try: resource_type = ResourceType(request.resource_type) @@ -310,35 +335,32 @@ async def check_permission( except ValueError as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Invalid resource_type or action: {str(e)}" + detail=f"Invalid resource_type or action: {str(e)}", ) - + # Create resource identifier - resource = ResourceIdentifier( - resource_type=resource_type, - resource_id=request.resource_id - ) - + resource = ResourceIdentifier(resource_type=resource_type, resource_id=request.resource_id) + # Perform authorization check auth_service = get_authorization_service(db) result = await auth_service.check_permission(user_id, resource, action) - + return PermissionCheckResponse( allowed=(result.decision.value == "allow"), decision=result.decision.value, reason=result.reason, - evaluated_policies=[p.get('id', 'unknown') for p in result.applied_policies], + evaluated_policies=[p.get("id", "unknown") for p in result.applied_policies], evaluation_time_ms=result.evaluation_time_ms, - timestamp=result.timestamp + timestamp=result.timestamp, ) - + except HTTPException: raise except Exception as e: logger.error(f"Error checking permission: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Permission check failed: {str(e)}" + detail=f"Permission check failed: {str(e)}", ) @@ -346,11 +368,11 @@ async def check_permission( async def check_bulk_permissions( request: BulkPermissionCheckRequest, current_user: dict = Depends(get_current_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """ Check permissions for multiple resources in bulk - + CRITICAL SECURITY IMPLEMENTATION: This endpoint demonstrates the fixed bulk authorization logic that prevents users from accessing resources they don't have permissions for. @@ -358,68 +380,69 @@ async def check_bulk_permissions( try: # Use current user if no user_id specified user_id = request.user_id or current_user["id"] - + # Validate and convert resources resources = [] for res_data in request.resources: try: resource = ResourceIdentifier( resource_type=ResourceType(res_data["resource_type"]), - resource_id=res_data["resource_id"] + resource_id=res_data["resource_id"], ) resources.append(resource) except (KeyError, ValueError) as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Invalid resource data: {res_data}. Error: {str(e)}" + detail=f"Invalid resource data: {res_data}. Error: {str(e)}", ) - + try: action = ActionType(request.action) except ValueError: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Invalid action: {request.action}" + status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid action: {request.action}" ) - + # Build authorization context auth_service = get_authorization_service(db) auth_context = await auth_service._build_user_context(user_id) - + # Perform bulk authorization check from ..models.authorization_models import BulkAuthorizationRequest + bulk_request = BulkAuthorizationRequest( user_id=user_id, resources=resources, action=action, context=auth_context, fail_fast=request.fail_fast, - parallel_evaluation=True + parallel_evaluation=True, ) - + result = await auth_service.check_bulk_permissions(bulk_request) - + # Format response allowed_resources = [ - { - "resource_type": res.resource_type.value, - "resource_id": res.resource_id - } + {"resource_type": res.resource_type.value, "resource_id": res.resource_id} for res in result.allowed_resources ] - + denied_resources = [ { "resource_type": res.resource_type.value, "resource_id": res.resource_id, "reason": next( - (r.reason for r in result.individual_results if r.resource.resource_id == res.resource_id), - "Access denied" - ) + ( + r.reason + for r in result.individual_results + if r.resource.resource_id == res.resource_id + ), + "Access denied", + ), } for res in result.denied_resources ] - + return BulkPermissionCheckResponse( overall_allowed=(result.overall_decision.value == "allow"), allowed_resources=allowed_resources, @@ -430,22 +453,23 @@ async def check_bulk_permissions( "allowed_count": len(allowed_resources), "denied_count": len(denied_resources), "cached_results": result.cached_results, - "fresh_evaluations": result.fresh_evaluations - } + "fresh_evaluations": result.fresh_evaluations, + }, ) - + except HTTPException: raise except Exception as e: logger.error(f"Error in bulk permission check: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Bulk permission check failed: {str(e)}" + detail=f"Bulk permission check failed: {str(e)}", ) # Administrative Endpoints + @router.get("/audit") @require_permission(Permission.AUDIT_READ) async def get_authorization_audit_log( @@ -457,42 +481,44 @@ async def get_authorization_audit_log( start_date: Optional[datetime] = Query(None, description="Filter from date"), end_date: Optional[datetime] = Query(None, description="Filter to date"), current_user: dict = Depends(get_current_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """ Get authorization audit log for security monitoring """ try: from sqlalchemy import text - + # Build WHERE clause conditions = ["1=1"] # Always true base condition params = {"limit": limit, "offset": offset} - + if user_id: conditions.append("user_id = :user_id") params["user_id"] = user_id - + if resource_type: conditions.append("resource_type = :resource_type") params["resource_type"] = resource_type - + if decision: conditions.append("decision = :decision") params["decision"] = decision - + if start_date: conditions.append("timestamp >= :start_date") params["start_date"] = start_date - + if end_date: conditions.append("timestamp <= :end_date") params["end_date"] = end_date - + where_clause = " AND ".join(conditions) - + # Get audit log entries - result = db.execute(text(f""" + result = db.execute( + text( + f""" SELECT id, event_type, user_id, resource_type, resource_id, action, decision, policies_evaluated, context, ip_address, user_agent, session_id, evaluation_time_ms, reason, risk_score, timestamp @@ -500,70 +526,83 @@ async def get_authorization_audit_log( WHERE {where_clause} ORDER BY timestamp DESC LIMIT :limit OFFSET :offset - """), params) - + """ + ), + params, + ) + audit_entries = [] for row in result: - audit_entries.append({ - "id": row.id, - "event_type": row.event_type, - "user_id": row.user_id, - "resource_type": row.resource_type, - "resource_id": row.resource_id, - "action": row.action, - "decision": row.decision, - "policies_evaluated": row.policies_evaluated.split(',') if row.policies_evaluated else [], - "context": row.context, - "ip_address": row.ip_address, - "user_agent": row.user_agent, - "session_id": row.session_id, - "evaluation_time_ms": row.evaluation_time_ms, - "reason": row.reason, - "risk_score": row.risk_score, - "timestamp": row.timestamp.isoformat() if row.timestamp else None - }) - + audit_entries.append( + { + "id": row.id, + "event_type": row.event_type, + "user_id": row.user_id, + "resource_type": row.resource_type, + "resource_id": row.resource_id, + "action": row.action, + "decision": row.decision, + "policies_evaluated": ( + row.policies_evaluated.split(",") if row.policies_evaluated else [] + ), + "context": row.context, + "ip_address": row.ip_address, + "user_agent": row.user_agent, + "session_id": row.session_id, + "evaluation_time_ms": row.evaluation_time_ms, + "reason": row.reason, + "risk_score": row.risk_score, + "timestamp": row.timestamp.isoformat() if row.timestamp else None, + } + ) + # Get total count - count_result = db.execute(text(f""" + count_result = db.execute( + text( + f""" SELECT COUNT(*) as total FROM authorization_audit_log WHERE {where_clause} - """), params) - + """ + ), + params, + ) + total_count = count_result.fetchone().total - + return { "audit_entries": audit_entries, "pagination": { "total": total_count, "limit": limit, "offset": offset, - "has_more": (offset + len(audit_entries)) < total_count - } + "has_more": (offset + len(audit_entries)) < total_count, + }, } - + except Exception as e: logger.error(f"Error getting authorization audit log: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to get audit log: {str(e)}" + detail=f"Failed to get audit log: {str(e)}", ) @router.get("/summary") @require_permission(Permission.SYSTEM_CONFIG) async def get_authorization_summary( - current_user: dict = Depends(get_current_user), - db: Session = Depends(get_db) + current_user: dict = Depends(get_current_user), db: Session = Depends(get_db) ): """ Get authorization system summary and statistics """ try: from sqlalchemy import text - + # Get permission statistics - perm_stats = db.execute(text(""" + perm_stats = db.execute( + text( + """ SELECT COUNT(*) as total_permissions, COUNT(CASE WHEN user_id IS NOT NULL THEN 1 END) as user_permissions, @@ -573,10 +612,14 @@ async def get_authorization_summary( COUNT(CASE WHEN effect = 'deny' THEN 1 END) as deny_permissions FROM host_permissions WHERE is_active = true - """)).fetchone() - + """ + ) + ).fetchone() + # Get recent audit statistics - audit_stats = db.execute(text(""" + audit_stats = db.execute( + text( + """ SELECT COUNT(*) as total_checks, COUNT(CASE WHEN decision = 'allow' THEN 1 END) as allowed_checks, @@ -585,10 +628,14 @@ async def get_authorization_summary( AVG(risk_score) as avg_risk_score FROM authorization_audit_log WHERE timestamp > NOW() - INTERVAL '24 hours' - """)).fetchone() - + """ + ) + ).fetchone() + # Get most active users - active_users = db.execute(text(""" + active_users = db.execute( + text( + """ SELECT u.username, COUNT(*) as check_count FROM authorization_audit_log aal JOIN users u ON aal.user_id = u.id @@ -596,8 +643,10 @@ async def get_authorization_summary( GROUP BY u.username ORDER BY check_count DESC LIMIT 10 - """)).fetchall() - + """ + ) + ).fetchall() + return { "permission_statistics": { "total_permissions": perm_stats.total_permissions or 0, @@ -605,24 +654,27 @@ async def get_authorization_summary( "group_permissions": perm_stats.group_permissions or 0, "role_permissions": perm_stats.role_permissions or 0, "temporary_permissions": perm_stats.temporary_permissions or 0, - "deny_permissions": perm_stats.deny_permissions or 0 + "deny_permissions": perm_stats.deny_permissions or 0, }, "recent_activity": { "total_checks_24h": audit_stats.total_checks or 0, "allowed_checks_24h": audit_stats.allowed_checks or 0, "denied_checks_24h": audit_stats.denied_checks or 0, - "avg_evaluation_time_ms": float(audit_stats.avg_evaluation_time) if audit_stats.avg_evaluation_time else 0, - "avg_risk_score": float(audit_stats.avg_risk_score) if audit_stats.avg_risk_score else 0 + "avg_evaluation_time_ms": ( + float(audit_stats.avg_evaluation_time) if audit_stats.avg_evaluation_time else 0 + ), + "avg_risk_score": ( + float(audit_stats.avg_risk_score) if audit_stats.avg_risk_score else 0 + ), }, "most_active_users": [ - {"username": row.username, "check_count": row.check_count} - for row in active_users - ] + {"username": row.username, "check_count": row.check_count} for row in active_users + ], } - + except Exception as e: logger.error(f"Error getting authorization summary: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to get authorization summary: {str(e)}" - ) \ No newline at end of file + detail=f"Failed to get authorization summary: {str(e)}", + ) diff --git a/backend/app/routes/automated_fixes.py b/backend/app/routes/automated_fixes.py index 4ab55b88..f18acc7f 100644 --- a/backend/app/routes/automated_fixes.py +++ b/backend/app/routes/automated_fixes.py @@ -37,6 +37,7 @@ class FixEvaluationRequest(BaseModel): """Request to evaluate automated fix options""" + legacy_fixes: List[Dict[str, Any]] target_host: str context: Optional[Dict[str, Any]] = Field(default_factory=dict) @@ -44,6 +45,7 @@ class FixEvaluationRequest(BaseModel): class FixExecutionRequest(BaseModel): """Request to execute an automated fix""" + fix_id: str secure_command_id: str parameters: Dict[str, Any] = Field(default_factory=dict) @@ -53,11 +55,13 @@ class FixExecutionRequest(BaseModel): class FixApprovalRequest(BaseModel): """Request to approve a pending fix""" + approval_reason: str = Field(min_length=10, max_length=500) class FixRollbackRequest(BaseModel): """Request to rollback a fix""" + rollback_reason: str = Field(min_length=10, max_length=500) @@ -65,17 +69,17 @@ class FixRollbackRequest(BaseModel): async def evaluate_fix_options( request: FixEvaluationRequest, current_user: dict = Depends(get_current_user), - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_async_db), ): """ Evaluate legacy automated fixes and convert to secure options - + Requires: scan:read permission """ try: # Check permissions await check_permission_async(current_user, Permission.SCAN_READ, db) - + # Convert legacy fixes to AutomatedFix objects legacy_fixes = [] for fix_data in request.legacy_fixes: @@ -86,33 +90,36 @@ async def evaluate_fix_options( estimated_time=fix_data.get("estimated_time", 30), command=fix_data.get("command"), is_safe=fix_data.get("is_safe", True), - rollback_command=fix_data.get("rollback_command") + rollback_command=fix_data.get("rollback_command"), ) legacy_fixes.append(legacy_fix) - + # Evaluate secure options secure_options = await secure_fix_executor.evaluate_fix_options( - legacy_fixes=legacy_fixes, - target_host=request.target_host + legacy_fixes=legacy_fixes, target_host=request.target_host ) - - logger.info(f"Evaluated {len(secure_options)} fix options for {request.target_host} by {current_user.get('username')}") - + + logger.info( + f"Evaluated {len(secure_options)} fix options for {request.target_host} by {current_user.get('username')}" + ) + return { "secure_options": secure_options, "total_options": len(secure_options), "safe_options": len([opt for opt in secure_options if opt.get("is_safe", False)]), - "blocked_options": len([opt for opt in secure_options if opt.get("security_level") == "blocked"]), - "evaluation_timestamp": datetime.utcnow().isoformat() + "blocked_options": len( + [opt for opt in secure_options if opt.get("security_level") == "blocked"] + ), + "evaluation_timestamp": datetime.utcnow().isoformat(), } - + except HTTPException: raise except Exception as e: logger.error(f"Failed to evaluate fix options: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to evaluate fix options: {str(e)}" + detail=f"Failed to evaluate fix options: {str(e)}", ) @@ -120,17 +127,17 @@ async def evaluate_fix_options( async def request_fix_execution( request: FixExecutionRequest, current_user: dict = Depends(get_current_user), - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_async_db), ): """ Request execution of a secure automated fix - + Requires: scan:write permission """ try: # Check permissions await check_permission_async(current_user, Permission.SCAN_WRITE, db) - + # Request fix execution result = await secure_fix_executor.request_fix_execution( fix_id=request.fix_id, @@ -138,20 +145,20 @@ async def request_fix_execution( parameters=request.parameters, target_host=request.target_host, requested_by=current_user.get("username", "unknown"), - justification=request.justification + justification=request.justification, ) - + logger.info(f"Fix execution requested: {request.fix_id} by {current_user.get('username')}") - + return result - + except HTTPException: raise except Exception as e: logger.error(f"Failed to request fix execution: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to request fix execution: {str(e)}" + detail=f"Failed to request fix execution: {str(e)}", ) @@ -160,11 +167,11 @@ async def approve_fix_request( request_id: str, approval_request: FixApprovalRequest, current_user: dict = Depends(get_current_user), - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_async_db), ): """ Approve a pending fix execution request - + Requires: admin role or scan:approve permission """ try: @@ -172,28 +179,28 @@ async def approve_fix_request( user_roles = current_user.get("roles", []) if "admin" not in user_roles: await check_permission_async(current_user, Permission.SCAN_APPROVE, db) - + # Approve the request result = await secure_fix_executor.approve_fix_request( request_id=request_id, approved_by=current_user.get("username", "unknown"), - approval_reason=approval_request.approval_reason + approval_reason=approval_request.approval_reason, ) - + if result["success"]: logger.info(f"Fix request approved: {request_id} by {current_user.get('username')}") else: logger.warning(f"Fix approval failed: {request_id} - {result['message']}") - + return result - + except HTTPException: raise except Exception as e: logger.error(f"Failed to approve fix request: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to approve fix request: {str(e)}" + detail=f"Failed to approve fix request: {str(e)}", ) @@ -201,34 +208,34 @@ async def approve_fix_request( async def execute_approved_fix( request_id: str, current_user: dict = Depends(get_current_user), - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_async_db), ): """ Execute an approved automated fix - + Requires: scan:write permission """ try: # Check permissions await check_permission_async(current_user, Permission.SCAN_WRITE, db) - + # Execute the fix result = await secure_fix_executor.execute_approved_fix(request_id) - + if result["success"]: logger.info(f"Fix executed successfully: {request_id}") else: logger.warning(f"Fix execution failed: {request_id} - {result['message']}") - + return result - + except HTTPException: raise except Exception as e: logger.error(f"Failed to execute fix: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to execute fix: {str(e)}" + detail=f"Failed to execute fix: {str(e)}", ) @@ -237,11 +244,11 @@ async def rollback_fix( request_id: str, rollback_request: FixRollbackRequest, current_user: dict = Depends(get_current_user), - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_async_db), ): """ Rollback a previously executed fix - + Requires: admin role or scan:rollback permission """ try: @@ -249,27 +256,28 @@ async def rollback_fix( user_roles = current_user.get("roles", []) if "admin" not in user_roles: await check_permission_async(current_user, Permission.SCAN_ROLLBACK, db) - + # Rollback the fix result = await secure_fix_executor.rollback_fix( - request_id=request_id, - rollback_by=current_user.get("username", "unknown") + request_id=request_id, rollback_by=current_user.get("username", "unknown") ) - + if result["success"]: - logger.info(f"Fix rolled back successfully: {request_id} by {current_user.get('username')}") + logger.info( + f"Fix rolled back successfully: {request_id} by {current_user.get('username')}" + ) else: logger.warning(f"Fix rollback failed: {request_id} - {result['message']}") - + return result - + except HTTPException: raise except Exception as e: logger.error(f"Failed to rollback fix: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to rollback fix: {str(e)}" + detail=f"Failed to rollback fix: {str(e)}", ) @@ -277,46 +285,44 @@ async def rollback_fix( async def get_fix_status( request_id: str, current_user: dict = Depends(get_current_user), - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_async_db), ): """ Get status of a fix execution request - + Requires: scan:read permission """ try: # Check permissions await check_permission_async(current_user, Permission.SCAN_READ, db) - + # Get fix status status_info = await secure_fix_executor.get_fix_status(request_id) - + if not status_info: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Fix request not found" + status_code=status.HTTP_404_NOT_FOUND, detail="Fix request not found" ) - + return status_info - + except HTTPException: raise except Exception as e: logger.error(f"Failed to get fix status: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to get fix status: {str(e)}" + detail=f"Failed to get fix status: {str(e)}", ) @router.get("/pending-approvals") async def list_pending_approvals( - current_user: dict = Depends(get_current_user), - db: AsyncSession = Depends(get_async_db) + current_user: dict = Depends(get_current_user), db: AsyncSession = Depends(get_async_db) ): """ List all fixes pending approval - + Requires: admin role or scan:approve permission """ try: @@ -324,58 +330,59 @@ async def list_pending_approvals( user_roles = current_user.get("roles", []) if "admin" not in user_roles: await check_permission_async(current_user, Permission.SCAN_APPROVE, db) - + # Get pending approvals pending_fixes = await secure_fix_executor.list_pending_approvals() - + return { "pending_approvals": pending_fixes, "total_pending": len(pending_fixes), - "retrieved_at": datetime.utcnow().isoformat() + "retrieved_at": datetime.utcnow().isoformat(), } - + except HTTPException: raise except Exception as e: logger.error(f"Failed to list pending approvals: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to list pending approvals: {str(e)}" + detail=f"Failed to list pending approvals: {str(e)}", ) @router.get("/secure-commands") async def get_secure_command_catalog( - current_user: dict = Depends(get_current_user), - db: AsyncSession = Depends(get_async_db) + current_user: dict = Depends(get_current_user), db: AsyncSession = Depends(get_async_db) ): """ Get catalog of available secure commands - + Requires: scan:read permission """ try: # Check permissions await check_permission_async(current_user, Permission.SCAN_READ, db) - + # Get command catalog commands = await secure_fix_executor.get_secure_command_catalog() - + return { "secure_commands": commands, "total_commands": len(commands), "safe_commands": len([cmd for cmd in commands if cmd["security_level"] == "safe"]), - "privileged_commands": len([cmd for cmd in commands if cmd["security_level"] == "privileged"]), - "catalog_timestamp": datetime.utcnow().isoformat() + "privileged_commands": len( + [cmd for cmd in commands if cmd["security_level"] == "privileged"] + ), + "catalog_timestamp": datetime.utcnow().isoformat(), } - + except HTTPException: raise except Exception as e: logger.error(f"Failed to get secure command catalog: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to get secure command catalog: {str(e)}" + detail=f"Failed to get secure command catalog: {str(e)}", ) @@ -383,11 +390,11 @@ async def get_secure_command_catalog( async def cleanup_old_requests( max_age_days: int = 30, current_user: dict = Depends(get_current_user), - db: AsyncSession = Depends(get_async_db) + db: AsyncSession = Depends(get_async_db), ): """ Clean up old execution requests - + Requires: admin role """ try: @@ -395,28 +402,29 @@ async def cleanup_old_requests( user_roles = current_user.get("roles", []) if "admin" not in user_roles: raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Admin access required" + status_code=status.HTTP_403_FORBIDDEN, detail="Admin access required" ) - + # Clean up old requests await secure_fix_executor.cleanup_old_requests(max_age_days=max_age_days) - - logger.info(f"Cleaned up old fix requests (max_age_days={max_age_days}) by {current_user.get('username')}") - + + logger.info( + f"Cleaned up old fix requests (max_age_days={max_age_days}) by {current_user.get('username')}" + ) + return { "success": True, "message": f"Cleaned up old requests older than {max_age_days} days", - "cleanup_timestamp": datetime.utcnow().isoformat() + "cleanup_timestamp": datetime.utcnow().isoformat(), } - + except HTTPException: raise except Exception as e: logger.error(f"Failed to cleanup old requests: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to cleanup old requests: {str(e)}" + detail=f"Failed to cleanup old requests: {str(e)}", ) @@ -426,14 +434,14 @@ async def health_check(): try: # Basic health checks sandbox_service_status = "healthy" # Could add more detailed checks - + return { "status": "healthy", "service": "secure-automated-fixes", "sandbox_service": sandbox_service_status, - "timestamp": datetime.utcnow().isoformat() + "timestamp": datetime.utcnow().isoformat(), } - + except Exception as e: logger.error(f"Health check failed: {e}") return JSONResponse( @@ -442,6 +450,6 @@ async def health_check(): "status": "unhealthy", "service": "secure-automated-fixes", "error": str(e), - "timestamp": datetime.utcnow().isoformat() - } - ) \ No newline at end of file + "timestamp": datetime.utcnow().isoformat(), + }, + ) diff --git a/backend/app/routes/bulk_operations.py b/backend/app/routes/bulk_operations.py index 5e9e4a4e..819a41d4 100644 --- a/backend/app/routes/bulk_operations.py +++ b/backend/app/routes/bulk_operations.py @@ -2,6 +2,7 @@ Bulk Operations API Routes Handles bulk import/export operations for hosts and other entities """ + import csv import io import json @@ -19,8 +20,10 @@ router = APIRouter() + class BulkHostImport(BaseModel): """Single host entry for bulk import""" + hostname: str = Field(..., min_length=1, max_length=255) ip_address: str = Field(..., min_length=7, max_length=45) display_name: Optional[str] = Field(None, max_length=255) @@ -31,8 +34,8 @@ class BulkHostImport(BaseModel): environment: Optional[str] = Field("production", max_length=50) tags: Optional[str] = Field(None, max_length=500) # Comma-separated tags owner: Optional[str] = Field(None, max_length=100) - - @validator('ip_address') + + @validator("ip_address") def validate_ip(cls, v): try: ipaddress.ip_address(v) @@ -40,14 +43,18 @@ def validate_ip(cls, v): except ValueError: raise ValueError(f"Invalid IP address: {v}") + class BulkImportRequest(BaseModel): """Request body for JSON-based bulk import""" + hosts: List[BulkHostImport] update_existing: bool = Field(False, description="Update existing hosts instead of skipping") dry_run: bool = Field(False, description="Validate without importing") + class BulkImportResult(BaseModel): """Result of bulk import operation""" + total_processed: int successful_imports: int failed_imports: int @@ -59,6 +66,7 @@ class BulkImportResult(BaseModel): # Enhanced CSV Import Models class FieldAnalysisResponse(BaseModel): """Field analysis response model""" + column_name: str detected_type: str confidence: float @@ -70,6 +78,7 @@ class FieldAnalysisResponse(BaseModel): class CSVAnalysisResponse(BaseModel): """CSV analysis response model""" + total_rows: int total_columns: int headers: List[str] @@ -80,6 +89,7 @@ class CSVAnalysisResponse(BaseModel): class FieldMapping(BaseModel): """Field mapping configuration""" + source_column: str target_field: str transform_function: Optional[str] = None # For future use @@ -87,6 +97,7 @@ class FieldMapping(BaseModel): class EnhancedImportRequest(BaseModel): """Enhanced import request with field mappings""" + csv_data: str field_mappings: List[FieldMapping] update_existing: bool = Field(False, description="Update existing hosts instead of skipping") @@ -99,11 +110,11 @@ class EnhancedImportRequest(BaseModel): async def bulk_import_hosts( request: BulkImportRequest, db: Session = Depends(get_db), - current_user: Dict[str, Any] = Depends(get_current_user) + current_user: Dict[str, Any] = Depends(get_current_user), ): """ Bulk import hosts from JSON payload - + Supports: - Creating multiple hosts at once - Updating existing hosts (with update_existing flag) @@ -116,18 +127,22 @@ async def bulk_import_hosts( failed_imports=0, skipped_duplicates=0, errors=[], - imported_hosts=[] + imported_hosts=[], ) - + # Process each host for idx, host_data in enumerate(request.hosts): try: # Check if host already exists - existing_host = db.query(Host).filter( - (Host.hostname == host_data.hostname) | - (Host.ip_address == host_data.ip_address) - ).first() - + existing_host = ( + db.query(Host) + .filter( + (Host.hostname == host_data.hostname) + | (Host.ip_address == host_data.ip_address) + ) + .first() + ) + if existing_host: if request.update_existing and not request.dry_run: # Update existing host @@ -135,21 +150,25 @@ async def bulk_import_hosts( setattr(existing_host, field, value) db.commit() result.successful_imports += 1 - result.imported_hosts.append({ - "hostname": existing_host.hostname, - "ip_address": existing_host.ip_address, - "action": "updated" - }) + result.imported_hosts.append( + { + "hostname": existing_host.hostname, + "ip_address": existing_host.ip_address, + "action": "updated", + } + ) else: result.skipped_duplicates += 1 - result.errors.append({ - "row": idx + 1, - "hostname": host_data.hostname, - "error": "Host already exists", - "action": "skipped" - }) + result.errors.append( + { + "row": idx + 1, + "hostname": host_data.hostname, + "error": "Host already exists", + "action": "skipped", + } + ) continue - + if not request.dry_run: # Create new host new_host = Host( @@ -164,48 +183,52 @@ async def bulk_import_hosts( tags=host_data.tags, owner=host_data.owner, is_active=True, - created_by=current_user["id"] + created_by=current_user["id"], ) db.add(new_host) db.commit() - + result.successful_imports += 1 - result.imported_hosts.append({ - "hostname": new_host.hostname, - "ip_address": new_host.ip_address, - "action": "created" - }) + result.imported_hosts.append( + { + "hostname": new_host.hostname, + "ip_address": new_host.ip_address, + "action": "created", + } + ) else: # Dry run - just validate result.successful_imports += 1 - result.imported_hosts.append({ - "hostname": host_data.hostname, - "ip_address": host_data.ip_address, - "action": "would_create" - }) - + result.imported_hosts.append( + { + "hostname": host_data.hostname, + "ip_address": host_data.ip_address, + "action": "would_create", + } + ) + except Exception as e: result.failed_imports += 1 - result.errors.append({ - "row": idx + 1, - "hostname": host_data.hostname if hasattr(host_data, 'hostname') else "unknown", - "error": str(e) - }) + result.errors.append( + { + "row": idx + 1, + "hostname": host_data.hostname if hasattr(host_data, "hostname") else "unknown", + "error": str(e), + } + ) # Continue processing other hosts continue - + # Log the bulk import operation await log_security_event( db=db, event_type="BULK_HOST_IMPORT", user_id=current_user["id"], ip_address=current_user.get("ip_address", "unknown"), - details=f"Imported {result.successful_imports} hosts, {result.failed_imports} failed, {result.skipped_duplicates} skipped" + details=f"Imported {result.successful_imports} hosts, {result.failed_imports} failed, {result.skipped_duplicates} skipped", ) - - return result - + return result @router.get("/hosts/import-template") @@ -216,85 +239,133 @@ async def download_import_template(): # Create CSV content csv_content = io.StringIO() writer = csv.writer(csv_content) - + # Write headers headers = [ - "hostname", "ip_address", "display_name", "operating_system", - "port", "username", "auth_method", "environment", "tags", "owner" + "hostname", + "ip_address", + "display_name", + "operating_system", + "port", + "username", + "auth_method", + "environment", + "tags", + "owner", ] writer.writerow(headers) - + # Write example rows examples = [ - ["web-server-01", "192.168.1.10", "Web Server 01", "RHEL 9", "22", "admin", "ssh_key", "production", "web,frontend", "john.doe"], - ["db-server-01", "192.168.1.20", "Database Server", "RHEL 8", "22", "admin", "password", "production", "database,backend", "jane.smith"], - ["app-server-01", "192.168.1.30", "", "RHEL 9", "22", "", "system_default", "staging", "application", ""], + [ + "web-server-01", + "192.168.1.10", + "Web Server 01", + "RHEL 9", + "22", + "admin", + "ssh_key", + "production", + "web,frontend", + "john.doe", + ], + [ + "db-server-01", + "192.168.1.20", + "Database Server", + "RHEL 8", + "22", + "admin", + "password", + "production", + "database,backend", + "jane.smith", + ], + [ + "app-server-01", + "192.168.1.30", + "", + "RHEL 9", + "22", + "", + "system_default", + "staging", + "application", + "", + ], ] writer.writerows(examples) - + # Return as downloadable file return Response( content=csv_content.getvalue(), media_type="text/csv", - headers={ - "Content-Disposition": "attachment; filename=host_import_template.csv" - } + headers={"Content-Disposition": "attachment; filename=host_import_template.csv"}, ) @router.get("/hosts/export-csv") -@require_role([UserRole.SUPER_ADMIN.value, UserRole.SECURITY_ADMIN.value, UserRole.SECURITY_ANALYST.value]) +@require_role( + [UserRole.SUPER_ADMIN.value, UserRole.SECURITY_ADMIN.value, UserRole.SECURITY_ANALYST.value] +) async def export_hosts_csv( - db: Session = Depends(get_db), - current_user: Dict[str, Any] = Depends(get_current_user) + db: Session = Depends(get_db), current_user: Dict[str, Any] = Depends(get_current_user) ): """ Export all hosts to CSV format Useful for backing up host configurations or as a template """ hosts = db.query(Host).filter(Host.is_active == True).all() - + # Create CSV content csv_content = io.StringIO() writer = csv.writer(csv_content) - + # Write headers headers = [ - "hostname", "ip_address", "display_name", "operating_system", - "port", "username", "auth_method", "environment", "tags", "owner" + "hostname", + "ip_address", + "display_name", + "operating_system", + "port", + "username", + "auth_method", + "environment", + "tags", + "owner", ] writer.writerow(headers) - + # Write host data for host in hosts: - writer.writerow([ - host.hostname, - host.ip_address, - host.display_name or "", - host.operating_system or "", - host.ssh_port or 22, - host.ssh_username or "", - host.auth_method or "password", - host.environment or "production", - host.tags or "", - host.owner or "" - ]) - + writer.writerow( + [ + host.hostname, + host.ip_address, + host.display_name or "", + host.operating_system or "", + host.ssh_port or 22, + host.ssh_username or "", + host.auth_method or "password", + host.environment or "production", + host.tags or "", + host.owner or "", + ] + ) + # Log export operation await log_security_event( db=db, event_type="HOST_EXPORT", user_id=current_user["id"], ip_address=current_user.get("ip_address", "unknown"), - details=f"Exported {len(hosts)} hosts to CSV" + details=f"Exported {len(hosts)} hosts to CSV", ) - + return Response( content=csv_content.getvalue(), media_type="text/csv", - headers={ - "Content-Disposition": "attachment; filename=hosts_export.csv" - } + headers={"Content-Disposition": "attachment; filename=hosts_export.csv"}, ) @@ -302,30 +373,29 @@ async def export_hosts_csv( @router.post("/hosts/analyze-csv", response_model=CSVAnalysisResponse) @require_role([UserRole.SUPER_ADMIN.value, UserRole.SECURITY_ADMIN.value]) async def analyze_csv( - file: UploadFile = File(...), - current_user: Dict[str, Any] = Depends(get_current_user) + file: UploadFile = File(...), current_user: Dict[str, Any] = Depends(get_current_user) ): """ Analyze uploaded CSV file and provide intelligent field mapping suggestions - + This endpoint accepts any CSV format and returns: - Column analysis with detected field types - Confidence scores for each detection - Auto-mapping suggestions - Template matches for known formats """ - if not file.filename.endswith('.csv'): + if not file.filename.endswith(".csv"): raise HTTPException(status_code=400, detail="File must be a CSV") - + try: # Read CSV content contents = await file.read() - csv_text = contents.decode('utf-8-sig') # Handle BOM if present - + csv_text = contents.decode("utf-8-sig") # Handle BOM if present + # Analyze with CSV analyzer analyzer = CSVAnalyzer() analysis = analyzer.analyze_csv(csv_text) - + # Convert to response model field_analyses = [ FieldAnalysisResponse( @@ -335,20 +405,20 @@ async def analyze_csv( sample_values=fa.sample_values, unique_count=fa.unique_count, null_count=fa.null_count, - suggestions=fa.suggestions + suggestions=fa.suggestions, ) for fa in analysis.field_analyses ] - + return CSVAnalysisResponse( total_rows=analysis.total_rows, total_columns=analysis.total_columns, headers=analysis.headers, field_analyses=field_analyses, auto_mappings=analysis.auto_mappings, - template_matches=analysis.template_matches + template_matches=analysis.template_matches, ) - + except Exception as e: raise HTTPException(status_code=400, detail=f"CSV analysis failed: {str(e)}") @@ -358,11 +428,11 @@ async def analyze_csv( async def import_with_mapping( request: EnhancedImportRequest, db: Session = Depends(get_db), - current_user: Dict[str, Any] = Depends(get_current_user) + current_user: Dict[str, Any] = Depends(get_current_user), ): """ Import hosts using custom field mappings - + This endpoint allows importing CSV data with flexible field mapping, supporting any CSV format with user-defined column mappings. """ @@ -370,71 +440,75 @@ async def import_with_mapping( # Parse CSV data csv_reader = csv.DictReader(io.StringIO(request.csv_data)) rows = list(csv_reader) - + if not rows: raise HTTPException(status_code=400, detail="CSV contains no data rows") - + # Create field mapping dictionary field_map = {fm.source_column: fm.target_field for fm in request.field_mappings} - + # Prepare default values defaults = request.default_values or {} - defaults.setdefault('environment', 'production') - defaults.setdefault('port', 22) - defaults.setdefault('auth_method', 'ssh_key') - + defaults.setdefault("environment", "production") + defaults.setdefault("port", 22) + defaults.setdefault("auth_method", "ssh_key") + result = BulkImportResult( total_processed=len(rows), successful_imports=0, failed_imports=0, skipped_duplicates=0, errors=[], - imported_hosts=[] + imported_hosts=[], ) - + # Process each row for idx, row in enumerate(rows): try: # Map fields according to user configuration mapped_data = {} - + # Apply field mappings for source_col, target_field in field_map.items(): if source_col in row and row[source_col]: value = row[source_col].strip() - + # Apply type conversions - if target_field == 'port' and value: + if target_field == "port" and value: try: mapped_data[target_field] = int(value) except ValueError: raise ValueError(f"Invalid port number: {value}") else: mapped_data[target_field] = value - + # Apply default values for missing required fields for field, default_value in defaults.items(): if field not in mapped_data: mapped_data[field] = default_value - + # Ensure required fields are present - if 'hostname' not in mapped_data or not mapped_data['hostname']: + if "hostname" not in mapped_data or not mapped_data["hostname"]: raise ValueError("Hostname is required") - if 'ip_address' not in mapped_data or not mapped_data['ip_address']: + if "ip_address" not in mapped_data or not mapped_data["ip_address"]: raise ValueError("IP address is required") - + # Validate IP address try: - ipaddress.ip_address(mapped_data['ip_address']) + ipaddress.ip_address(mapped_data["ip_address"]) except ValueError: raise ValueError(f"Invalid IP address: {mapped_data['ip_address']}") - + # Check for existing host - existing_host = db.query(Host).filter( - (Host.hostname == mapped_data['hostname']) | - (Host.ip_address == mapped_data['ip_address']) - ).first() - + existing_host = ( + db.query(Host) + .filter( + (Host.hostname == mapped_data["hostname"]) + | (Host.ip_address == mapped_data["ip_address"]) + ) + .first() + ) + if existing_host: if request.update_existing and not request.dry_run: # Update existing host @@ -443,74 +517,84 @@ async def import_with_mapping( setattr(existing_host, field, value) db.commit() result.successful_imports += 1 - result.imported_hosts.append({ - "hostname": existing_host.hostname, - "ip_address": existing_host.ip_address, - "action": "updated" - }) + result.imported_hosts.append( + { + "hostname": existing_host.hostname, + "ip_address": existing_host.ip_address, + "action": "updated", + } + ) else: result.skipped_duplicates += 1 - result.errors.append({ - "row": idx + 1, - "hostname": mapped_data['hostname'], - "error": "Host already exists", - "action": "skipped" - }) + result.errors.append( + { + "row": idx + 1, + "hostname": mapped_data["hostname"], + "error": "Host already exists", + "action": "skipped", + } + ) continue - + if not request.dry_run: # Create new host new_host = Host( - hostname=mapped_data['hostname'], - ip_address=mapped_data['ip_address'], - display_name=mapped_data.get('display_name') or mapped_data['hostname'], - operating_system=mapped_data.get('operating_system') or "RHEL", - port=mapped_data.get('port', 22), - username=mapped_data.get('username'), - auth_method=mapped_data.get('auth_method', 'ssh_key'), - environment=mapped_data.get('environment', 'production'), - tags=mapped_data.get('tags'), - owner=mapped_data.get('owner'), + hostname=mapped_data["hostname"], + ip_address=mapped_data["ip_address"], + display_name=mapped_data.get("display_name") or mapped_data["hostname"], + operating_system=mapped_data.get("operating_system") or "RHEL", + port=mapped_data.get("port", 22), + username=mapped_data.get("username"), + auth_method=mapped_data.get("auth_method", "ssh_key"), + environment=mapped_data.get("environment", "production"), + tags=mapped_data.get("tags"), + owner=mapped_data.get("owner"), is_active=True, - created_by=current_user["id"] + created_by=current_user["id"], ) db.add(new_host) db.commit() - + result.successful_imports += 1 - result.imported_hosts.append({ - "hostname": new_host.hostname, - "ip_address": new_host.ip_address, - "action": "created" - }) + result.imported_hosts.append( + { + "hostname": new_host.hostname, + "ip_address": new_host.ip_address, + "action": "created", + } + ) else: # Dry run - just validate result.successful_imports += 1 - result.imported_hosts.append({ - "hostname": mapped_data['hostname'], - "ip_address": mapped_data['ip_address'], - "action": "would_create" - }) - + result.imported_hosts.append( + { + "hostname": mapped_data["hostname"], + "ip_address": mapped_data["ip_address"], + "action": "would_create", + } + ) + except Exception as e: result.failed_imports += 1 - result.errors.append({ - "row": idx + 1, - "hostname": row.get(field_map.get('hostname', ''), 'unknown'), - "error": str(e) - }) + result.errors.append( + { + "row": idx + 1, + "hostname": row.get(field_map.get("hostname", ""), "unknown"), + "error": str(e), + } + ) continue - + # Log the import operation await log_security_event( db=db, event_type="ENHANCED_BULK_IMPORT", user_id=current_user["id"], ip_address=current_user.get("ip_address", "unknown"), - details=f"Enhanced import: {result.successful_imports} hosts, {result.failed_imports} failed, {result.skipped_duplicates} skipped" + details=f"Enhanced import: {result.successful_imports} hosts, {result.failed_imports} failed, {result.skipped_duplicates} skipped", ) - + return result - + except Exception as e: - raise HTTPException(status_code=400, detail=f"Import failed: {str(e)}") \ No newline at end of file + raise HTTPException(status_code=400, detail=f"Import failed: {str(e)}") diff --git a/backend/app/routes/capabilities.py b/backend/app/routes/capabilities.py index 917ae178..442171d9 100644 --- a/backend/app/routes/capabilities.py +++ b/backend/app/routes/capabilities.py @@ -2,6 +2,7 @@ OpenWatch Capabilities API Provides feature discovery and capability-based routing for OSS/Enterprise features """ + import logging from typing import Dict, Any, List from fastapi import APIRouter, Depends, HTTPException, status @@ -21,6 +22,7 @@ class FeatureFlags(BaseModel): """Feature flags for conditional functionality""" + scanning: bool = True reporting: bool = True host_management: bool = True @@ -28,7 +30,7 @@ class FeatureFlags(BaseModel): audit_logging: bool = True mfa: bool = True plugin_system: bool = True - + # Enterprise features (license-dependent) remediation: bool = False ai_assistance: bool = False @@ -40,6 +42,7 @@ class FeatureFlags(BaseModel): class SystemLimits(BaseModel): """System limits and constraints""" + max_hosts: int = 50 concurrent_scans: int = 5 max_users: int = 10 @@ -50,12 +53,13 @@ class SystemLimits(BaseModel): class IntegrationStatus(BaseModel): """Status of external integrations""" + aegis_available: bool = False aegis_version: str = None ldap_enabled: bool = False smtp_configured: bool = False prometheus_enabled: bool = False - + # Container runtime detection container_runtime: str = "unknown" kubernetes_available: bool = False @@ -63,6 +67,7 @@ class IntegrationStatus(BaseModel): class CapabilitiesResponse(BaseModel): """Complete capabilities response""" + version: str build_info: Dict[str, Any] features: FeatureFlags @@ -74,48 +79,47 @@ class CapabilitiesResponse(BaseModel): @router.get("/capabilities", response_model=CapabilitiesResponse) async def get_capabilities( - current_user: dict = Depends(get_current_user), - db: Session = Depends(get_db) + current_user: dict = Depends(get_current_user), db: Session = Depends(get_db) ) -> CapabilitiesResponse: """ Get system capabilities and feature flags - + Returns comprehensive information about: - Available features and their status - - System limits and constraints + - System limits and constraints - Integration status with external systems - License information and limitations - + This endpoint enables frontend conditional rendering and API consumers to discover available functionality. """ try: settings = get_settings() - + # Detect license type and enterprise features license_info = await _detect_license_info() - + # Check integration status integrations = await _check_integrations() - + # Determine feature flags based on license and configuration features = await _determine_feature_flags(license_info, settings) - + # Calculate system limits limits = await _calculate_system_limits(license_info, settings) - + # Get system information system_info = await _get_system_info() - + # Build version info build_info = { "version": "1.0.0", "build_date": "2025-08-20", "git_commit": "d84d2a3", "api_version": "v1", - "environment": getattr(settings, 'environment', 'production') + "environment": getattr(settings, "environment", "production"), } - + response = CapabilitiesResponse( version="1.0.0", build_info=build_info, @@ -123,28 +127,26 @@ async def get_capabilities( limits=limits, integrations=integrations, license_info=license_info, - system_info=system_info + system_info=system_info, ) - + logger.info(f"Capabilities requested by user {current_user.get('user_id', 'unknown')}") - + return response - + except Exception as e: logger.error(f"Error getting capabilities: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to retrieve system capabilities" + detail="Failed to retrieve system capabilities", ) @router.get("/features", response_model=FeatureFlags) -async def get_feature_flags( - current_user: dict = Depends(get_current_user) -) -> FeatureFlags: +async def get_feature_flags(current_user: dict = Depends(get_current_user)) -> FeatureFlags: """ Get just the feature flags (lightweight endpoint) - + Returns only the feature availability flags without detailed system information. Useful for frequent polling by frontend applications. @@ -153,71 +155,68 @@ async def get_feature_flags( settings = get_settings() license_info = await _detect_license_info() features = await _determine_feature_flags(license_info, settings) - + logger.debug(f"Feature flags requested by user {current_user.get('user_id', 'unknown')}") - + return features - + except Exception as e: logger.error(f"Error getting feature flags: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to retrieve feature flags" + detail="Failed to retrieve feature flags", ) @router.get("/health/integrations", response_model=IntegrationStatus) async def get_integration_status( - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ) -> IntegrationStatus: """ Get status of external integrations - + Returns the current status of all external system integrations including AEGIS, LDAP, SMTP, and container runtime information. """ try: integrations = await _check_integrations() - - logger.debug(f"Integration status requested by user {current_user.get('user_id', 'unknown')}") - + + logger.debug( + f"Integration status requested by user {current_user.get('user_id', 'unknown')}" + ) + return integrations - + except Exception as e: logger.error(f"Error getting integration status: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to retrieve integration status" + detail="Failed to retrieve integration status", ) # Helper functions -async def _detect_license_info() -> Dict[str, Any]: +def _detect_license_info() -> Dict[str, Any]: """Detect license type and enterprise features availability""" # For OSS version, return basic license info # In enterprise version, this would check actual license files - + return { "type": "oss", "tier": "community", "expires": None, - "features_enabled": [ - "scanning", - "reporting", - "host_management", - "plugin_system" - ], + "features_enabled": ["scanning", "reporting", "host_management", "plugin_system"], "enterprise_available": False, - "upgrade_url": "https://hanalyx.com/openwatch/enterprise" + "upgrade_url": "https://hanalyx.com/openwatch/enterprise", } async def _determine_feature_flags(license_info: Dict, settings) -> FeatureFlags: """Determine which features are available based on license and config""" - + # Base OSS features (always available) features = FeatureFlags() - + # Check for enterprise license if license_info.get("type") == "enterprise": features.remediation = True @@ -226,23 +225,23 @@ async def _determine_feature_flags(license_info: Dict, settings) -> FeatureFlags features.siem_integration = True features.compliance_frameworks = True features.enterprise_auth = True - + # Check configuration-dependent features - features.mfa = getattr(settings, 'mfa_enabled', True) - + features.mfa = getattr(settings, "mfa_enabled", True) + # Check if AEGIS is available (affects remediation) if await _check_aegis_availability(): # Even in OSS, basic remediation might be available if AEGIS is configured features.remediation = license_info.get("type") == "enterprise" - + return features -async def _calculate_system_limits(license_info: Dict, settings) -> SystemLimits: +def _calculate_system_limits(license_info: Dict, settings) -> SystemLimits: """Calculate system limits based on license and configuration""" - + limits = SystemLimits() - + # Adjust limits based on license type if license_info.get("type") == "enterprise": limits.max_hosts = 1000 @@ -258,52 +257,52 @@ async def _calculate_system_limits(license_info: Dict, settings) -> SystemLimits limits.storage_limit_gb = 500 limits.api_rate_limit = 5000 limits.plugin_limit = 50 - + # OSS defaults are already set in the model - + return limits async def _check_integrations() -> IntegrationStatus: """Check status of external integrations""" - + integrations = IntegrationStatus() - + # Check AEGIS availability integrations.aegis_available = await _check_aegis_availability() if integrations.aegis_available: integrations.aegis_version = await _get_aegis_version() - + # Check LDAP configuration integrations.ldap_enabled = _check_ldap_config() - + # Check SMTP configuration integrations.smtp_configured = _check_smtp_config() - + # Check Prometheus integrations.prometheus_enabled = _check_prometheus_config() - + # Detect container runtime integrations.container_runtime = await _detect_container_runtime() - + # Check Kubernetes availability integrations.kubernetes_available = await _check_kubernetes_availability() - + return integrations -async def _check_aegis_availability() -> bool: +def _check_aegis_availability() -> bool: """Check if AEGIS remediation service is available""" try: # In a real implementation, this would check AEGIS connectivity # For now, check if AEGIS configuration exists - aegis_url = os.environ.get('AEGIS_URL') + aegis_url = os.environ.get("AEGIS_URL") return aegis_url is not None except: return False -async def _get_aegis_version() -> str: +def _get_aegis_version() -> str: """Get AEGIS version if available""" try: # In a real implementation, this would query AEGIS API @@ -314,17 +313,17 @@ async def _get_aegis_version() -> str: def _check_ldap_config() -> bool: """Check if LDAP is configured""" - return bool(os.environ.get('LDAP_SERVER')) + return bool(os.environ.get("LDAP_SERVER")) def _check_smtp_config() -> bool: """Check if SMTP is configured""" - return bool(os.environ.get('SMTP_SERVER')) + return bool(os.environ.get("SMTP_SERVER")) def _check_prometheus_config() -> bool: """Check if Prometheus monitoring is enabled""" - return bool(os.environ.get('PROMETHEUS_ENABLED', '').lower() == 'true') + return bool(os.environ.get("PROMETHEUS_ENABLED", "").lower() == "true") async def _detect_container_runtime() -> str: @@ -332,27 +331,23 @@ async def _detect_container_runtime() -> str: try: # Check for Podman result = await asyncio.create_subprocess_exec( - 'podman', '--version', - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + "podman", "--version", stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE ) if result.returncode == 0: return "podman" except: pass - + try: # Check for Docker result = await asyncio.create_subprocess_exec( - 'docker', '--version', - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + "docker", "--version", stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE ) if result.returncode == 0: return "docker" except: pass - + return "unknown" @@ -361,20 +356,22 @@ async def _check_kubernetes_availability() -> bool: try: # Check for kubectl result = await asyncio.create_subprocess_exec( - 'kubectl', 'version', '--client', + "kubectl", + "version", + "--client", stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + stderr=asyncio.subprocess.PIPE, ) return result.returncode == 0 except: return False -async def _get_system_info() -> Dict[str, Any]: +def _get_system_info() -> Dict[str, Any]: """Get basic system information""" import platform import psutil - + try: return { "platform": platform.system(), @@ -383,9 +380,9 @@ async def _get_system_info() -> Dict[str, Any]: "python_version": platform.python_version(), "cpu_count": psutil.cpu_count(), "memory_total": psutil.virtual_memory().total, - "disk_usage": dict(psutil.disk_usage('/')._asdict()), - "uptime": psutil.boot_time() + "disk_usage": dict(psutil.disk_usage("/")._asdict()), + "uptime": psutil.boot_time(), } except Exception as e: logger.warning(f"Could not get system info: {e}") - return {"error": "System information unavailable"} \ No newline at end of file + return {"error": "System information unavailable"} diff --git a/backend/app/routes/compliance.py b/backend/app/routes/compliance.py index 4e55ae7b..b7084f6f 100644 --- a/backend/app/routes/compliance.py +++ b/backend/app/routes/compliance.py @@ -2,6 +2,7 @@ Universal Compliance Intelligence API Routes Provides semantic SCAP intelligence and cross-framework compliance data """ + import json import logging from typing import Dict, List, Any, Optional @@ -21,6 +22,7 @@ class SemanticRule(BaseModel): """Semantic rule response model""" + id: str semantic_name: str scap_rule_id: str @@ -37,6 +39,7 @@ class SemanticRule(BaseModel): class FrameworkIntelligence(BaseModel): """Framework intelligence response model""" + framework: str display_name: str semantic_rules_count: int @@ -50,6 +53,7 @@ class FrameworkIntelligence(BaseModel): class ComplianceOverview(BaseModel): """Compliance intelligence overview response model""" + total_frameworks: int semantic_rules_count: int universal_coverage: int @@ -61,9 +65,11 @@ class ComplianceOverview(BaseModel): async def get_semantic_rules( framework: Optional[str] = Query(None, description="Filter by framework"), business_impact: Optional[str] = Query(None, description="Filter by business impact"), - remediation_available: Optional[bool] = Query(None, description="Filter by remediation availability"), + remediation_available: Optional[bool] = Query( + None, description="Filter by remediation availability" + ), db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """Get semantic rules from the rule intelligence database""" try: @@ -78,52 +84,56 @@ async def get_semantic_rules( WHERE 1=1 """ params = {} - + if framework: query += " AND :framework = ANY(applicable_frameworks)" params["framework"] = framework - + if business_impact: query += " AND business_impact = :business_impact" params["business_impact"] = business_impact - + if remediation_available is not None: query += " AND remediation_available = :remediation_available" params["remediation_available"] = remediation_available - + query += " ORDER BY created_at DESC" - + result = db.execute(text(query), params) rules = result.fetchall() - + # Convert to list of dictionaries semantic_rules = [] for rule in rules: - semantic_rules.append({ - "id": str(rule.id), - "semantic_name": rule.semantic_name, - "scap_rule_id": rule.scap_rule_id, - "title": rule.title, - "compliance_intent": rule.compliance_intent, - "business_impact": rule.business_impact, - "risk_level": rule.risk_level, - "frameworks": rule.frameworks if rule.frameworks else [], - "remediation_complexity": rule.remediation_complexity, - "estimated_fix_time": rule.estimated_fix_time, - "remediation_available": rule.remediation_available, - "confidence_score": float(rule.confidence_score) if rule.confidence_score else 1.0 - }) - + semantic_rules.append( + { + "id": str(rule.id), + "semantic_name": rule.semantic_name, + "scap_rule_id": rule.scap_rule_id, + "title": rule.title, + "compliance_intent": rule.compliance_intent, + "business_impact": rule.business_impact, + "risk_level": rule.risk_level, + "frameworks": rule.frameworks if rule.frameworks else [], + "remediation_complexity": rule.remediation_complexity, + "estimated_fix_time": rule.estimated_fix_time, + "remediation_available": rule.remediation_available, + "confidence_score": ( + float(rule.confidence_score) if rule.confidence_score else 1.0 + ), + } + ) + return { "rules": semantic_rules, "total_count": len(semantic_rules), "filters_applied": { "framework": framework, "business_impact": business_impact, - "remediation_available": remediation_available - } + "remediation_available": remediation_available, + }, } - + except Exception as e: logger.error(f"Error retrieving semantic rules: {e}") raise HTTPException(status_code=500, detail=f"Failed to retrieve semantic rules: {str(e)}") @@ -131,8 +141,7 @@ async def get_semantic_rules( @router.get("/framework-intelligence") async def get_framework_intelligence( - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """Get framework intelligence overview and statistics""" try: @@ -150,24 +159,24 @@ async def get_framework_intelligence( WHERE applicable_frameworks IS NOT NULL AND array_length(applicable_frameworks, 1) > 0 GROUP BY unnest(applicable_frameworks) """ - + result = db.execute(text(query)) framework_stats = result.fetchall() - + # Process framework data frameworks = [] framework_config = { - 'stig': 'DISA STIG', - 'cis': 'CIS Controls', - 'nist': 'NIST Cybersecurity', - 'pci_dss': 'PCI DSS' + "stig": "DISA STIG", + "cis": "CIS Controls", + "nist": "NIST Cybersecurity", + "pci_dss": "PCI DSS", } - + for stats in framework_stats: framework_key = stats.framework if framework_key not in framework_config: continue - + # Get cross-framework mappings (rules that appear in multiple frameworks) cross_framework_query = """ SELECT COUNT(*) as cross_framework_count @@ -177,41 +186,43 @@ async def get_framework_intelligence( """ cross_result = db.execute(text(cross_framework_query), {"framework": framework_key}) cross_framework_count = cross_result.fetchone().cross_framework_count or 0 - + remediation_coverage = 0 if stats.rule_count > 0: - remediation_coverage = round((stats.remediation_available_count / stats.rule_count) * 100) - - frameworks.append({ - "framework": framework_key, - "display_name": framework_config[framework_key], - "semantic_rules_count": stats.rule_count, - "cross_framework_mappings": cross_framework_count, - "remediation_coverage": remediation_coverage, - "business_impact_breakdown": { - "high": stats.high_impact_count, - "medium": stats.medium_impact_count, - "low": stats.low_impact_count - }, - "estimated_remediation_time": stats.total_remediation_time or 0, - "compatible_distributions": ["RHEL 9", "Ubuntu 22.04", "Oracle Linux 8"], - "compliance_score": 85 + (framework_key == 'stig' and 10 or 5) # Mock data - }) - - return { - "frameworks": frameworks, - "last_updated": datetime.utcnow().isoformat() - } - + remediation_coverage = round( + (stats.remediation_available_count / stats.rule_count) * 100 + ) + + frameworks.append( + { + "framework": framework_key, + "display_name": framework_config[framework_key], + "semantic_rules_count": stats.rule_count, + "cross_framework_mappings": cross_framework_count, + "remediation_coverage": remediation_coverage, + "business_impact_breakdown": { + "high": stats.high_impact_count, + "medium": stats.medium_impact_count, + "low": stats.low_impact_count, + }, + "estimated_remediation_time": stats.total_remediation_time or 0, + "compatible_distributions": ["RHEL 9", "Ubuntu 22.04", "Oracle Linux 8"], + "compliance_score": 85 + (framework_key == "stig" and 10 or 5), # Mock data + } + ) + + return {"frameworks": frameworks, "last_updated": datetime.utcnow().isoformat()} + except Exception as e: logger.error(f"Error retrieving framework intelligence: {e}") - raise HTTPException(status_code=500, detail=f"Failed to retrieve framework intelligence: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to retrieve framework intelligence: {str(e)}" + ) @router.get("/overview") async def get_compliance_overview( - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """Get overall compliance intelligence overview metrics""" try: @@ -222,49 +233,49 @@ async def get_compliance_overview( SUM(CASE WHEN remediation_available THEN 1 ELSE 0 END) as remediation_ready_count FROM rule_intelligence """ - + result = db.execute(text(rules_query)) stats = result.fetchone() - + total_rules = stats.total_rules or 0 remediation_ready = stats.remediation_ready_count or 0 - + # Calculate universal coverage universal_coverage = 0 remediation_readiness = 0 - + if total_rules > 0: universal_coverage = round((remediation_ready / total_rules) * 100) remediation_readiness = universal_coverage - + # Get unique frameworks count frameworks_query = """ SELECT COUNT(DISTINCT unnest(applicable_frameworks)) as framework_count FROM rule_intelligence WHERE applicable_frameworks IS NOT NULL """ - + framework_result = db.execute(text(frameworks_query)) framework_count = framework_result.fetchone().framework_count or 0 - + return { "total_frameworks": framework_count, "semantic_rules_count": total_rules, "universal_coverage": universal_coverage, "remediation_readiness": remediation_readiness, - "last_intelligence_update": datetime.utcnow().strftime("%H:%M:%S") + "last_intelligence_update": datetime.utcnow().strftime("%H:%M:%S"), } - + except Exception as e: logger.error(f"Error retrieving compliance overview: {e}") - raise HTTPException(status_code=500, detail=f"Failed to retrieve compliance overview: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to retrieve compliance overview: {str(e)}" + ) @router.get("/semantic-analysis/{scan_id}") async def get_semantic_analysis( - scan_id: str, - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + scan_id: str, db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """Get semantic analysis results for a specific scan""" try: @@ -276,37 +287,43 @@ async def get_semantic_analysis( FROM semantic_scan_analysis WHERE scan_id = :scan_id """ - + result = db.execute(text(query), {"scan_id": scan_id}) analysis = result.fetchone() - + if not analysis: raise HTTPException(status_code=404, detail="Semantic analysis not found for this scan") - + return { "scan_id": str(analysis.scan_id), "host_id": str(analysis.host_id), "semantic_rules_count": analysis.semantic_rules_count, - "frameworks_analyzed": json.loads(analysis.frameworks_analyzed) if analysis.frameworks_analyzed else [], + "frameworks_analyzed": ( + json.loads(analysis.frameworks_analyzed) if analysis.frameworks_analyzed else [] + ), "remediation_available_count": analysis.remediation_available_count, - "processing_metadata": json.loads(analysis.processing_metadata) if analysis.processing_metadata else {}, + "processing_metadata": ( + json.loads(analysis.processing_metadata) if analysis.processing_metadata else {} + ), "analysis_data": json.loads(analysis.analysis_data) if analysis.analysis_data else {}, "created_at": analysis.created_at.isoformat() if analysis.created_at else None, - "updated_at": analysis.updated_at.isoformat() if analysis.updated_at else None + "updated_at": analysis.updated_at.isoformat() if analysis.updated_at else None, } - + except HTTPException: raise except Exception as e: logger.error(f"Error retrieving semantic analysis: {e}") - raise HTTPException(status_code=500, detail=f"Failed to retrieve semantic analysis: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to retrieve semantic analysis: {str(e)}" + ) @router.get("/compliance-matrix") async def get_compliance_matrix( host_id: Optional[str] = Query(None, description="Filter by host ID"), db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """Get framework compliance matrix data""" try: @@ -320,63 +337,73 @@ async def get_compliance_matrix( WHERE 1=1 """ params = {} - + if host_id: query += " AND host_id = :host_id" params["host_id"] = host_id - + query += " ORDER BY last_updated DESC" - + result = db.execute(text(query), params) matrix_data = result.fetchall() - + compliance_matrix = [] for row in matrix_data: - compliance_matrix.append({ - "host_id": str(row.host_id), - "framework": row.framework, - "compliance_score": float(row.compliance_score) if row.compliance_score else 0.0, - "total_rules": row.total_rules, - "passed_rules": row.passed_rules, - "failed_rules": row.failed_rules, - "previous_score": float(row.previous_score) if row.previous_score else None, - "trend": row.trend, - "last_scan_id": str(row.last_scan_id) if row.last_scan_id else None, - "last_updated": row.last_updated.isoformat() if row.last_updated else None, - "predicted_next_score": float(row.predicted_next_score) if row.predicted_next_score else None, - "prediction_confidence": float(row.prediction_confidence) if row.prediction_confidence else None - }) - + compliance_matrix.append( + { + "host_id": str(row.host_id), + "framework": row.framework, + "compliance_score": ( + float(row.compliance_score) if row.compliance_score else 0.0 + ), + "total_rules": row.total_rules, + "passed_rules": row.passed_rules, + "failed_rules": row.failed_rules, + "previous_score": float(row.previous_score) if row.previous_score else None, + "trend": row.trend, + "last_scan_id": str(row.last_scan_id) if row.last_scan_id else None, + "last_updated": row.last_updated.isoformat() if row.last_updated else None, + "predicted_next_score": ( + float(row.predicted_next_score) if row.predicted_next_score else None + ), + "prediction_confidence": ( + float(row.prediction_confidence) if row.prediction_confidence else None + ), + } + ) + return { "compliance_matrix": compliance_matrix, "total_entries": len(compliance_matrix), - "last_updated": datetime.utcnow().isoformat() + "last_updated": datetime.utcnow().isoformat(), } - + except Exception as e: logger.error(f"Error retrieving compliance matrix: {e}") - raise HTTPException(status_code=500, detail=f"Failed to retrieve compliance matrix: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to retrieve compliance matrix: {str(e)}" + ) @router.post("/remediation/strategy") async def create_remediation_strategy( request: Dict[str, Any], db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """Create an intelligent remediation strategy based on semantic analysis""" try: # Get the semantic SCAP engine semantic_engine = get_semantic_scap_engine() - + # Extract request parameters host_id = request.get("host_id") framework_goals = request.get("frameworks", ["stig"]) risk_tolerance = request.get("risk_tolerance", "medium") - + if not host_id: raise HTTPException(status_code=400, detail="host_id is required") - + # Get semantic rules for the host (mock data for now) rules_query = """ SELECT * FROM rule_intelligence @@ -390,7 +417,7 @@ async def create_remediation_strategy( estimated_fix_time ASC LIMIT 20 """ - + # For now, return a structured remediation strategy strategy = { "host_id": host_id, @@ -403,38 +430,40 @@ async def create_remediation_strategy( "name": "High Impact Quick Wins", "description": "Address high-impact rules with simple remediation", "estimated_time": 30, - "rules_count": 5 + "rules_count": 5, }, { "phase": 2, "name": "Medium Impact Remediation", "description": "Address medium-impact security controls", "estimated_time": 60, - "rules_count": 8 + "rules_count": 8, }, { "phase": 3, "name": "Complex Security Hardening", "description": "Address complex rules requiring system changes", "estimated_time": 120, - "rules_count": 7 - } + "rules_count": 7, + }, ], "total_estimated_time": 210, "total_rules": 20, "expected_compliance_improvement": { "stig": {"current": 75, "predicted": 92}, - "cis": {"current": 82, "predicted": 95} - } + "cis": {"current": 82, "predicted": 95}, + }, } - + return strategy - + except HTTPException: raise except Exception as e: logger.error(f"Error creating remediation strategy: {e}") - raise HTTPException(status_code=500, detail=f"Failed to create remediation strategy: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to create remediation strategy: {str(e)}" + ) @router.get("/health") @@ -443,21 +472,17 @@ async def compliance_health_check(): try: # Test semantic engine availability semantic_engine = get_semantic_scap_engine() - + return { "status": "healthy", "services": { "semantic_engine": "available", "database": "connected", - "api": "operational" + "api": "operational", }, - "timestamp": datetime.utcnow().isoformat() + "timestamp": datetime.utcnow().isoformat(), } - + except Exception as e: logger.error(f"Compliance health check failed: {e}") - return { - "status": "unhealthy", - "error": str(e), - "timestamp": datetime.utcnow().isoformat() - } \ No newline at end of file + return {"status": "unhealthy", "error": str(e), "timestamp": datetime.utcnow().isoformat()} diff --git a/backend/app/routes/content.py b/backend/app/routes/content.py index eee1bb4b..24b284b6 100644 --- a/backend/app/routes/content.py +++ b/backend/app/routes/content.py @@ -1,6 +1,7 @@ """ SCAP Content Management Routes """ + from fastapi import APIRouter, HTTPException, Depends, status, UploadFile, File from fastapi.security import HTTPBearer from pydantic import BaseModel @@ -44,7 +45,7 @@ async def list_content(token: str = Depends(security)): content_type="benchmark", file_path="/scap/rhel9-stig.xml", upload_date="2024-01-10T09:00:00Z", - profiles=["stig-rhel9-server", "stig-rhel9-workstation"] + profiles=["stig-rhel9-server", "stig-rhel9-workstation"], ), SCAPContent( id="2", @@ -54,25 +55,22 @@ async def list_content(token: str = Depends(security)): content_type="benchmark", file_path="/scap/ubuntu22-cis.xml", upload_date="2024-01-12T14:30:00Z", - profiles=["cis-ubuntu22-l1-server", "cis-ubuntu22-l2-server"] - ) + profiles=["cis-ubuntu22-l1-server", "cis-ubuntu22-l2-server"], + ), ] - + return mock_content @router.post("/upload") -async def upload_content( - file: UploadFile = File(...), - token: str = Depends(security) -): +async def upload_content(file: UploadFile = File(...), token: str = Depends(security)): """Upload new SCAP content file""" - if not file.filename.endswith(('.xml', '.zip', '.bz2')): + if not file.filename.endswith((".xml", ".zip", ".bz2")): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid file type. Only XML, ZIP, and BZ2 files are allowed." + detail="Invalid file type. Only XML, ZIP, and BZ2 files are allowed.", ) - + # Mock upload processing content_info = { "id": "3", @@ -80,9 +78,9 @@ async def upload_content( "size": file.size, "content_type": file.content_type, "status": "uploaded", - "message": "File uploaded successfully. Processing will begin shortly." + "message": "File uploaded successfully. Processing will begin shortly.", } - + logger.info(f"SCAP content uploaded: {file.filename}") return content_info @@ -100,13 +98,10 @@ async def get_content(content_id: str, token: str = Depends(security)): content_type="benchmark", file_path="/scap/rhel9-stig.xml", upload_date="2024-01-10T09:00:00Z", - profiles=["stig-rhel9-server", "stig-rhel9-workstation"] + profiles=["stig-rhel9-server", "stig-rhel9-workstation"], ) - - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="SCAP content not found" - ) + + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="SCAP content not found") @router.get("/{content_id}/profiles") @@ -117,15 +112,15 @@ async def get_content_profiles(content_id: str, token: str = Depends(security)): { "id": "stig-rhel9-server", "title": "Red Hat Enterprise Linux 9 STIG for Servers", - "description": "Security configuration for RHEL 9 servers" + "description": "Security configuration for RHEL 9 servers", }, { - "id": "stig-rhel9-workstation", + "id": "stig-rhel9-workstation", "title": "Red Hat Enterprise Linux 9 STIG for Workstations", - "description": "Security configuration for RHEL 9 workstations" - } + "description": "Security configuration for RHEL 9 workstations", + }, ] - + return mock_profiles @@ -133,4 +128,4 @@ async def get_content_profiles(content_id: str, token: str = Depends(security)): async def delete_content(content_id: str, token: str = Depends(security)): """Delete SCAP content""" logger.info(f"Deleted SCAP content {content_id}") - return {"message": "SCAP content deleted successfully"} \ No newline at end of file + return {"message": "SCAP content deleted successfully"} diff --git a/backend/app/routes/host_groups.py b/backend/app/routes/host_groups.py index 2a851b73..2269f6a7 100644 --- a/backend/app/routes/host_groups.py +++ b/backend/app/routes/host_groups.py @@ -2,6 +2,7 @@ Host Groups API Routes Handles host group creation, management, and host assignment with smart validation """ + import logging import json from typing import List, Optional, Dict, Any @@ -20,8 +21,11 @@ from ..rbac import check_permission from ..services.group_validation_service import ValidationError from ..models.scan_models import ( - GroupScanConfig, GroupScanSession, GroupScanProgress, - HostScanDetail, ActiveScanSession + GroupScanConfig, + GroupScanSession, + GroupScanProgress, + HostScanDetail, + ActiveScanSession, ) logger = logging.getLogger(__name__) @@ -109,12 +113,13 @@ class CompatibilityValidationResponse(BaseModel): @router.get("/", response_model=List[HostGroupResponse]) async def list_host_groups( - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """List all host groups with host counts""" try: - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT hg.id, hg.name, hg.description, hg.color, hg.created_by, hg.created_at, hg.updated_at, hg.os_family, hg.os_version_pattern, hg.architecture, hg.scap_content_id, @@ -130,11 +135,15 @@ async def list_host_groups( hg.scap_content_id, hg.default_profile_id, hg.compliance_framework, hg.auto_scan_enabled, hg.scan_schedule, hg.validation_rules, sc.name ORDER BY hg.name - """)) - + """ + ) + ) + groups = [] for row in result: - logger.info(f"Raw row data for group {row.id}: scap_content_id={row.scap_content_id}, default_profile_id={row.default_profile_id}") + logger.info( + f"Raw row data for group {row.id}: scap_content_id={row.scap_content_id}, default_profile_id={row.default_profile_id}" + ) group_data = { "id": row.id, "name": row.name, @@ -150,16 +159,22 @@ async def list_host_groups( "scap_content_id": row.scap_content_id, "default_profile_id": row.default_profile_id, "compliance_framework": row.compliance_framework, - "auto_scan_enabled": row.auto_scan_enabled if row.auto_scan_enabled is not None else False, + "auto_scan_enabled": ( + row.auto_scan_enabled if row.auto_scan_enabled is not None else False + ), "scan_schedule": row.scan_schedule, - "validation_rules": json.loads(row.validation_rules) if row.validation_rules else None, - "scap_content_name": row.scap_content_name + "validation_rules": ( + json.loads(row.validation_rules) if row.validation_rules else None + ), + "scap_content_name": row.scap_content_name, } - logger.info(f"Group data includes SCAP fields: scap_content_id={group_data.get('scap_content_id')}, default_profile_id={group_data.get('default_profile_id')}") + logger.info( + f"Group data includes SCAP fields: scap_content_id={group_data.get('scap_content_id')}, default_profile_id={group_data.get('default_profile_id')}" + ) groups.append(group_data) - + return groups - + except Exception as e: logger.error(f"Error listing host groups: {e}") raise HTTPException(status_code=500, detail="Failed to list host groups") @@ -167,13 +182,13 @@ async def list_host_groups( @router.get("/{group_id}", response_model=HostGroupResponse) async def get_host_group( - group_id: int, - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + group_id: int, db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """Get a specific host group by ID""" try: - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT hg.id, hg.name, hg.description, hg.color, hg.created_by, hg.created_at, hg.updated_at, hg.os_family, hg.os_version_pattern, hg.architecture, hg.scap_content_id, @@ -189,13 +204,16 @@ async def get_host_group( hg.updated_at, hg.os_family, hg.os_version_pattern, hg.architecture, hg.scap_content_id, hg.default_profile_id, hg.compliance_framework, hg.auto_scan_enabled, hg.scan_schedule, hg.validation_rules, sc.name - """), {"group_id": group_id}) - + """ + ), + {"group_id": group_id}, + ) + row = result.fetchone() - + if not row: raise HTTPException(status_code=404, detail="Host group not found") - + return { "id": row.id, "name": row.name, @@ -211,12 +229,14 @@ async def get_host_group( "scap_content_id": row.scap_content_id, "default_profile_id": row.default_profile_id, "compliance_framework": row.compliance_framework, - "auto_scan_enabled": row.auto_scan_enabled if row.auto_scan_enabled is not None else False, + "auto_scan_enabled": ( + row.auto_scan_enabled if row.auto_scan_enabled is not None else False + ), "scan_schedule": row.scan_schedule, "validation_rules": json.loads(row.validation_rules) if row.validation_rules else None, - "scap_content_name": row.scap_content_name + "scap_content_name": row.scap_content_name, } - + except HTTPException: raise except Exception as e: @@ -228,29 +248,41 @@ async def get_host_group( async def create_host_group( group_data: HostGroupCreate, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """Create a new host group""" try: # Check if group name already exists - existing = db.execute(text(""" + existing = db.execute( + text( + """ SELECT id FROM host_groups WHERE name = :name - """), {"name": group_data.name}).fetchone() - + """ + ), + {"name": group_data.name}, + ).fetchone() + if existing: raise HTTPException(status_code=400, detail="Group name already exists") - + # Validate SCAP content if provided if group_data.scap_content_id: - scap_check = db.execute(text(""" + scap_check = db.execute( + text( + """ SELECT id, name FROM scap_content WHERE id = :content_id - """), {"content_id": group_data.scap_content_id}).fetchone() - + """ + ), + {"content_id": group_data.scap_content_id}, + ).fetchone() + if not scap_check: raise HTTPException(status_code=400, detail="Invalid SCAP content ID") - + # Create the group - result = db.execute(text(""" + result = db.execute( + text( + """ INSERT INTO host_groups ( name, description, color, created_by, created_at, updated_at, os_family, os_version_pattern, architecture, scap_content_id, @@ -267,27 +299,32 @@ async def create_host_group( os_family, os_version_pattern, architecture, scap_content_id, default_profile_id, compliance_framework, auto_scan_enabled, scan_schedule, validation_rules - """), { - "name": group_data.name, - "description": group_data.description, - "color": group_data.color, - "created_by": current_user["id"], - "created_at": datetime.utcnow(), - "updated_at": datetime.utcnow(), - "os_family": group_data.os_family, - "os_version_pattern": group_data.os_version_pattern, - "architecture": group_data.architecture, - "scap_content_id": group_data.scap_content_id, - "default_profile_id": group_data.default_profile_id, - "compliance_framework": group_data.compliance_framework, - "auto_scan_enabled": group_data.auto_scan_enabled or False, - "scan_schedule": group_data.scan_schedule, - "validation_rules": json.dumps(group_data.validation_rules) if group_data.validation_rules else None - }) - + """ + ), + { + "name": group_data.name, + "description": group_data.description, + "color": group_data.color, + "created_by": current_user["id"], + "created_at": datetime.utcnow(), + "updated_at": datetime.utcnow(), + "os_family": group_data.os_family, + "os_version_pattern": group_data.os_version_pattern, + "architecture": group_data.architecture, + "scap_content_id": group_data.scap_content_id, + "default_profile_id": group_data.default_profile_id, + "compliance_framework": group_data.compliance_framework, + "auto_scan_enabled": group_data.auto_scan_enabled or False, + "scan_schedule": group_data.scan_schedule, + "validation_rules": ( + json.dumps(group_data.validation_rules) if group_data.validation_rules else None + ), + }, + ) + group = result.fetchone() db.commit() - + return { "id": group.id, "name": group.name, @@ -303,11 +340,15 @@ async def create_host_group( "scap_content_id": group.scap_content_id, "default_profile_id": group.default_profile_id, "compliance_framework": group.compliance_framework, - "auto_scan_enabled": group.auto_scan_enabled if group.auto_scan_enabled is not None else False, + "auto_scan_enabled": ( + group.auto_scan_enabled if group.auto_scan_enabled is not None else False + ), "scan_schedule": group.scan_schedule, - "validation_rules": json.loads(group.validation_rules) if group.validation_rules else None + "validation_rules": ( + json.loads(group.validation_rules) if group.validation_rules else None + ), } - + except HTTPException: raise except Exception as e: @@ -320,122 +361,154 @@ async def update_host_group( group_id: int, group_data: HostGroupUpdate, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """Update a host group""" try: # Check if group exists - existing = db.execute(text(""" + existing = db.execute( + text( + """ SELECT id FROM host_groups WHERE id = :group_id - """), {"group_id": group_id}).fetchone() - + """ + ), + {"group_id": group_id}, + ).fetchone() + if not existing: raise HTTPException(status_code=404, detail="Group not found") - + # Check if new name conflicts (if name is being updated) if group_data.name: - name_conflict = db.execute(text(""" + name_conflict = db.execute( + text( + """ SELECT id FROM host_groups WHERE name = :name AND id != :group_id - """), {"name": group_data.name, "group_id": group_id}).fetchone() - + """ + ), + {"name": group_data.name, "group_id": group_id}, + ).fetchone() + if name_conflict: raise HTTPException(status_code=400, detail="Group name already exists") - + # Validate SCAP content if provided if group_data.scap_content_id is not None: - scap_check = db.execute(text(""" + scap_check = db.execute( + text( + """ SELECT id FROM scap_content WHERE id = :content_id - """), {"content_id": group_data.scap_content_id}).fetchone() - + """ + ), + {"content_id": group_data.scap_content_id}, + ).fetchone() + if not scap_check: raise HTTPException(status_code=400, detail="Invalid SCAP content ID") - + # Build update query dynamically update_fields = [] update_params = {"group_id": group_id, "updated_at": datetime.utcnow()} - + if group_data.name is not None: update_fields.append("name = :name") update_params["name"] = group_data.name - + if group_data.description is not None: update_fields.append("description = :description") update_params["description"] = group_data.description - + if group_data.color is not None: update_fields.append("color = :color") update_params["color"] = group_data.color - + if group_data.os_family is not None: update_fields.append("os_family = :os_family") update_params["os_family"] = group_data.os_family - + if group_data.os_version_pattern is not None: update_fields.append("os_version_pattern = :os_version_pattern") update_params["os_version_pattern"] = group_data.os_version_pattern - + if group_data.architecture is not None: update_fields.append("architecture = :architecture") update_params["architecture"] = group_data.architecture - + if group_data.scap_content_id is not None: update_fields.append("scap_content_id = :scap_content_id") update_params["scap_content_id"] = group_data.scap_content_id - + if group_data.default_profile_id is not None: update_fields.append("default_profile_id = :default_profile_id") update_params["default_profile_id"] = group_data.default_profile_id - + if group_data.compliance_framework is not None: update_fields.append("compliance_framework = :compliance_framework") update_params["compliance_framework"] = group_data.compliance_framework - + if group_data.auto_scan_enabled is not None: update_fields.append("auto_scan_enabled = :auto_scan_enabled") update_params["auto_scan_enabled"] = group_data.auto_scan_enabled - + if group_data.scan_schedule is not None: update_fields.append("scan_schedule = :scan_schedule") update_params["scan_schedule"] = group_data.scan_schedule - + if group_data.validation_rules is not None: update_fields.append("validation_rules = :validation_rules") - update_params["validation_rules"] = json.dumps(group_data.validation_rules) if group_data.validation_rules else None - + update_params["validation_rules"] = ( + json.dumps(group_data.validation_rules) if group_data.validation_rules else None + ) + update_fields.append("updated_at = :updated_at") - + if not update_fields: raise HTTPException(status_code=400, detail="No fields to update") - + # Update the group - result = db.execute(text(f""" + result = db.execute( + text( + f""" UPDATE host_groups SET {', '.join(update_fields)} WHERE id = :group_id RETURNING id, name, description, color, created_by, created_at, updated_at, os_family, os_version_pattern, architecture, scap_content_id, default_profile_id, compliance_framework, auto_scan_enabled, scan_schedule, validation_rules - """), update_params) - + """ + ), + update_params, + ) + group = result.fetchone() db.commit() - + # Get host count and SCAP content name - count_result = db.execute(text(""" + count_result = db.execute( + text( + """ SELECT COUNT(*) as host_count FROM host_group_memberships WHERE group_id = :group_id - """), {"group_id": group_id}) + """ + ), + {"group_id": group_id}, + ) host_count = count_result.fetchone().host_count - + # Get SCAP content name if applicable scap_content_name = None if group.scap_content_id: - scap_result = db.execute(text(""" + scap_result = db.execute( + text( + """ SELECT name FROM scap_content WHERE id = :content_id - """), {"content_id": group.scap_content_id}) + """ + ), + {"content_id": group.scap_content_id}, + ) scap_row = scap_result.fetchone() if scap_row: scap_content_name = scap_row.name - + return { "id": group.id, "name": group.name, @@ -451,12 +524,16 @@ async def update_host_group( "scap_content_id": group.scap_content_id, "default_profile_id": group.default_profile_id, "compliance_framework": group.compliance_framework, - "auto_scan_enabled": group.auto_scan_enabled if group.auto_scan_enabled is not None else False, + "auto_scan_enabled": ( + group.auto_scan_enabled if group.auto_scan_enabled is not None else False + ), "scan_schedule": group.scan_schedule, - "validation_rules": json.loads(group.validation_rules) if group.validation_rules else None, - "scap_content_name": scap_content_name + "validation_rules": ( + json.loads(group.validation_rules) if group.validation_rules else None + ), + "scap_content_name": scap_content_name, } - + except HTTPException: raise except Exception as e: @@ -466,34 +543,47 @@ async def update_host_group( @router.delete("/{group_id}") async def delete_host_group( - group_id: int, - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + group_id: int, db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """Delete a host group""" try: # Check if group exists - existing = db.execute(text(""" + existing = db.execute( + text( + """ SELECT id FROM host_groups WHERE id = :group_id - """), {"group_id": group_id}).fetchone() - + """ + ), + {"group_id": group_id}, + ).fetchone() + if not existing: raise HTTPException(status_code=404, detail="Group not found") - + # Remove all host assignments first - db.execute(text(""" + db.execute( + text( + """ DELETE FROM host_group_memberships WHERE group_id = :group_id - """), {"group_id": group_id}) - + """ + ), + {"group_id": group_id}, + ) + # Delete the group - db.execute(text(""" + db.execute( + text( + """ DELETE FROM host_groups WHERE id = :group_id - """), {"group_id": group_id}) - + """ + ), + {"group_id": group_id}, + ) + db.commit() - + return {"message": "Host group deleted successfully"} - + except HTTPException: raise except Exception as e: @@ -506,41 +596,55 @@ async def assign_hosts_to_group( group_id: int, request: AssignHostsRequest, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """Assign hosts to a group""" try: # Check if group exists - existing = db.execute(text(""" + existing = db.execute( + text( + """ SELECT id FROM host_groups WHERE id = :group_id - """), {"group_id": group_id}).fetchone() - + """ + ), + {"group_id": group_id}, + ).fetchone() + if not existing: raise HTTPException(status_code=404, detail="Group not found") - + # Remove hosts from any existing groups first (each host can only be in one group) if request.host_ids: - placeholders = ','.join([f"'{host_id}'" for host_id in request.host_ids]) - db.execute(text(f""" + placeholders = ",".join([f"'{host_id}'" for host_id in request.host_ids]) + db.execute( + text( + f""" DELETE FROM host_group_memberships WHERE host_id IN ({placeholders}) - """)) - + """ + ) + ) + # Add hosts to the new group for host_id in request.host_ids: - db.execute(text(""" + db.execute( + text( + """ INSERT INTO host_group_memberships (host_id, group_id, assigned_by, assigned_at) VALUES (:host_id, :group_id, :assigned_by, :assigned_at) - """), { - "host_id": host_id, - "group_id": group_id, - "assigned_by": current_user["id"], - "assigned_at": datetime.utcnow() - }) - + """ + ), + { + "host_id": host_id, + "group_id": group_id, + "assigned_by": current_user["id"], + "assigned_at": datetime.utcnow(), + }, + ) + db.commit() - + return {"message": f"Successfully assigned {len(request.host_ids)} hosts to group"} - + except HTTPException: raise except Exception as e: @@ -553,23 +657,28 @@ async def remove_host_from_group( group_id: int, host_id: str, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """Remove a host from a group""" try: # Remove the host from the group - result = db.execute(text(""" + result = db.execute( + text( + """ DELETE FROM host_group_memberships WHERE group_id = :group_id AND host_id = :host_id - """), {"group_id": group_id, "host_id": host_id}) - + """ + ), + {"group_id": group_id, "host_id": host_id}, + ) + db.commit() - + if result.rowcount == 0: raise HTTPException(status_code=404, detail="Host not found in group") - + return {"message": "Host removed from group successfully"} - + except HTTPException: raise except Exception as e: @@ -579,34 +688,30 @@ async def remove_host_from_group( # Smart validation endpoints + @router.post("/{group_id}/validate-hosts", response_model=CompatibilityValidationResponse) async def validate_host_compatibility( group_id: int, request: ValidateHostsRequest, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """ Validate host compatibility with a group - + Checks OS family, version, architecture, and SCAP content compatibility Returns detailed validation results including suggestions for incompatible hosts """ try: validation_service = GroupValidationService(db) results = validation_service.validate_host_group_compatibility( - host_ids=request.host_ids, - group_id=group_id, - user_role=current_user.get("role") + host_ids=request.host_ids, group_id=group_id, user_role=current_user.get("role") ) - + return results - + except ValidationError as e: - raise HTTPException( - status_code=e.status_code or 400, - detail=e.message - ) + raise HTTPException(status_code=e.status_code or 400, detail=e.message) except Exception as e: logger.error(f"Error validating host compatibility: {e}") raise HTTPException(status_code=500, detail="Failed to validate host compatibility") @@ -616,63 +721,72 @@ async def validate_host_compatibility( async def create_smart_group( request: SmartGroupCreateRequest, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """ Create a smart group based on host characteristics - + Analyzes selected hosts and automatically configures group settings including OS requirements, SCAP content, and validation rules """ try: validation_service = GroupValidationService(db) - + # Analyze hosts to determine group characteristics analysis = validation_service.create_smart_group_from_hosts( host_ids=request.host_ids, group_name=request.group_name, description=request.description, - created_by=current_user["id"] + created_by=current_user["id"], ) - + # If auto_configure is enabled and hosts are homogeneous, create the group if request.auto_configure and "recommendations" in analysis: recommendations = analysis["recommendations"] - + # Create the group with recommended settings group_data = HostGroupCreate( name=request.group_name, - description=request.description or f"Smart group for {recommendations.get('os_family', 'mixed')} hosts", + description=request.description + or f"Smart group for {recommendations.get('os_family', 'mixed')} hosts", os_family=recommendations.get("os_family"), os_version_pattern=recommendations.get("os_version_pattern"), - scap_content_id=recommendations.get("scap_content", {}).get("id") if "scap_content" in recommendations else None, - compliance_framework=recommendations.get("scap_content", {}).get("compliance_framework") if "scap_content" in recommendations else None + scap_content_id=( + recommendations.get("scap_content", {}).get("id") + if "scap_content" in recommendations + else None + ), + compliance_framework=( + recommendations.get("scap_content", {}).get("compliance_framework") + if "scap_content" in recommendations + else None + ), ) - + # Create the group using the existing endpoint logic group_response = await create_host_group(group_data, db, current_user) - + # Assign the hosts to the group assign_request = AssignHostsRequest( host_ids=request.host_ids, validate_compatibility=False, # Already validated - force_assignment=True + force_assignment=True, ) - + await assign_hosts_to_group(group_response["id"], assign_request, db, current_user) - + return { "group": group_response, "analysis": analysis, - "hosts_assigned": len(request.host_ids) + "hosts_assigned": len(request.host_ids), } - + # Return analysis results without creating the group return { "analysis": analysis, - "message": "Group analysis complete. Review recommendations before creating the group." + "message": "Group analysis complete. Review recommendations before creating the group.", } - + except Exception as e: logger.error(f"Error creating smart group: {e}") raise HTTPException(status_code=500, detail="Failed to create smart group") @@ -680,27 +794,22 @@ async def create_smart_group( @router.get("/{group_id}/compatibility-report") async def get_group_compatibility_report( - group_id: int, - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + group_id: int, db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """ Get a comprehensive compatibility report for a group - + Shows all hosts in the group with their compatibility status, issues, and recommendations for improving group coherence """ try: validation_service = GroupValidationService(db) report = validation_service.get_group_compatibility_report(group_id) - + return report - + except ValidationError as e: - raise HTTPException( - status_code=e.status_code or 404, - detail=e.message - ) + raise HTTPException(status_code=e.status_code or 404, detail=e.message) except Exception as e: logger.error(f"Error generating compatibility report: {e}") raise HTTPException(status_code=500, detail="Failed to generate compatibility report") @@ -711,11 +820,11 @@ async def validate_and_assign_hosts( group_id: int, request: AssignHostsRequest, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """ Validate and assign hosts to a group with smart validation - + If validate_compatibility is True (default), checks compatibility before assignment If force_assignment is True, assigns compatible hosts and rejects incompatible ones """ @@ -724,69 +833,85 @@ async def validate_and_assign_hosts( # Validate compatibility first validation_service = GroupValidationService(db) validation_results = validation_service.validate_host_group_compatibility( - host_ids=request.host_ids, - group_id=group_id, - user_role=current_user.get("role") + host_ids=request.host_ids, group_id=group_id, user_role=current_user.get("role") ) - + # Check if there are incompatible hosts if validation_results["incompatible"] and not request.force_assignment: # Return validation results without assigning return { "status": "validation_failed", "message": f"{len(validation_results['incompatible'])} hosts are incompatible", - "validation_results": validation_results + "validation_results": validation_results, } - + # If force_assignment is True, only assign compatible hosts - hosts_to_assign = [h["id"] for h in validation_results["compatible"]] if request.force_assignment else request.host_ids + hosts_to_assign = ( + [h["id"] for h in validation_results["compatible"]] + if request.force_assignment + else request.host_ids + ) else: hosts_to_assign = request.host_ids - + # Check if group exists - existing = db.execute(text(""" + existing = db.execute( + text( + """ SELECT id FROM host_groups WHERE id = :group_id - """), {"group_id": group_id}).fetchone() - + """ + ), + {"group_id": group_id}, + ).fetchone() + if not existing: raise HTTPException(status_code=404, detail="Group not found") - + # Remove hosts from any existing groups first if hosts_to_assign: - placeholders = ','.join([f"'{host_id}'" for host_id in hosts_to_assign]) - db.execute(text(f""" + placeholders = ",".join([f"'{host_id}'" for host_id in hosts_to_assign]) + db.execute( + text( + f""" DELETE FROM host_group_memberships WHERE host_id IN ({placeholders}) - """)) - + """ + ) + ) + # Add hosts to the new group assigned_count = 0 for host_id in hosts_to_assign: - db.execute(text(""" + db.execute( + text( + """ INSERT INTO host_group_memberships (host_id, group_id, assigned_by, assigned_at) VALUES (:host_id, :group_id, :assigned_by, :assigned_at) - """), { - "host_id": host_id, - "group_id": group_id, - "assigned_by": current_user["id"], - "assigned_at": datetime.utcnow() - }) + """ + ), + { + "host_id": host_id, + "group_id": group_id, + "assigned_by": current_user["id"], + "assigned_at": datetime.utcnow(), + }, + ) assigned_count += 1 - + db.commit() - + response = { "status": "success", "message": f"Successfully assigned {assigned_count} hosts to group", "assigned_count": assigned_count, - "total_requested": len(request.host_ids) + "total_requested": len(request.host_ids), } - + if request.validate_compatibility and validation_results.get("incompatible"): response["incompatible_hosts"] = validation_results["incompatible"] response["suggestions"] = validation_results.get("suggestions", {}) - + return response - + except HTTPException: raise except Exception as e: @@ -796,12 +921,13 @@ async def validate_and_assign_hosts( # Group Scan Management Endpoints + @router.post("/{group_id}/scan", response_model=dict) async def initiate_group_scan( group_id: int, scan_config: Optional[GroupScanConfig] = None, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """ Initiate scan for all hosts in a group @@ -809,28 +935,31 @@ async def initiate_group_scan( """ try: # Check if group exists - group_exists = db.execute(text(""" + group_exists = db.execute( + text( + """ SELECT id, name FROM host_groups WHERE id = :group_id - """), {"group_id": group_id}).fetchone() - + """ + ), + {"group_id": group_id}, + ).fetchone() + if not group_exists: raise HTTPException(status_code=404, detail="Host group not found") - + # Initialize group scan service group_scan_service = GroupScanService(db) - + # Create group scan session session = await group_scan_service.initiate_group_scan( - group_id=group_id, - user_id=current_user["id"], - scan_config=scan_config + group_id=group_id, user_id=current_user["id"], scan_config=scan_config ) - + # Start scan execution await group_scan_service.start_group_scan_execution(session.session_id) - + logger.info(f"Group scan initiated for group {group_id} by user {current_user['id']}") - + return { "session_id": session.session_id, "message": f"Group scan initiated for {session.total_hosts} hosts", @@ -838,10 +967,12 @@ async def initiate_group_scan( "group_name": session.group_name, "total_hosts": session.total_hosts, "status": session.status.value, - "estimated_completion": session.estimated_completion.isoformat() if session.estimated_completion else None, - "started_at": session.start_time.isoformat() + "estimated_completion": ( + session.estimated_completion.isoformat() if session.estimated_completion else None + ), + "started_at": session.start_time.isoformat(), } - + except ValueError as e: # Security Fix: Sanitize error messages to prevent information disclosure logger.error(f"Invalid input for group scan: {e}") @@ -853,9 +984,7 @@ async def initiate_group_scan( @router.get("/scan-sessions/{session_id}/progress", response_model=GroupScanProgress) async def get_scan_progress( - session_id: str, - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + session_id: str, db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """ Get real-time progress of group scan @@ -865,7 +994,7 @@ async def get_scan_progress( group_scan_service = GroupScanService(db) progress = await group_scan_service.get_scan_progress(session_id) return progress - + except ValueError as e: # Security Fix: Sanitize error messages to prevent information disclosure logger.error(f"Invalid session ID for scan progress: {e}") @@ -877,9 +1006,7 @@ async def get_scan_progress( @router.get("/scan-sessions/{session_id}/hosts", response_model=List[HostScanDetail]) async def get_host_scan_details( - session_id: str, - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + session_id: str, db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """ Get detailed status of each host in group scan @@ -888,7 +1015,7 @@ async def get_host_scan_details( group_scan_service = GroupScanService(db) host_details = await group_scan_service.get_host_scan_details(session_id) return host_details - + except Exception as e: logger.error(f"Error getting host scan details: {e}") raise HTTPException(status_code=500, detail="Failed to get host scan details") @@ -896,26 +1023,26 @@ async def get_host_scan_details( @router.post("/scan-sessions/{session_id}/cancel") async def cancel_group_scan( - session_id: str, - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + session_id: str, db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """Cancel ongoing group scan""" try: group_scan_service = GroupScanService(db) success = await group_scan_service.cancel_group_scan(session_id) - + if not success: - raise HTTPException(status_code=404, detail="Group scan session not found or already completed") - + raise HTTPException( + status_code=404, detail="Group scan session not found or already completed" + ) + logger.info(f"Group scan {session_id} cancelled by user {current_user['id']}") - + return { "message": "Group scan cancelled successfully", "session_id": session_id, - "cancelled_at": datetime.utcnow().isoformat() + "cancelled_at": datetime.utcnow().isoformat(), } - + except HTTPException: raise except Exception as e: @@ -925,21 +1052,20 @@ async def cancel_group_scan( @router.get("/scan-sessions/active", response_model=List[ActiveScanSession]) async def get_active_scans( - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """Get all active scan sessions for current user""" try: group_scan_service = GroupScanService(db) - + # Get active scans - filter by user unless admin user_id = None if current_user.get("role") not in ["super_admin", "security_admin"]: user_id = current_user["id"] - + active_scans = await group_scan_service.get_active_scans(user_id) return active_scans - + except Exception as e: logger.error(f"Error getting active scans: {e}") raise HTTPException(status_code=500, detail="Failed to get active scans") @@ -947,33 +1073,41 @@ async def get_active_scans( @router.get("/scan-sessions/{session_id}/summary") async def get_group_scan_summary( - session_id: str, - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + session_id: str, db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """Get comprehensive summary of completed group scan""" try: # Get session details - session_result = db.execute(text(""" + session_result = db.execute( + text( + """ SELECT s.session_id, s.group_id, s.group_name, s.total_hosts, s.status, s.start_time, s.completed_at, s.initiated_by FROM group_scan_sessions s WHERE s.session_id = :session_id - """), {"session_id": session_id}).fetchone() - + """ + ), + {"session_id": session_id}, + ).fetchone() + if not session_result: raise HTTPException(status_code=404, detail="Group scan session not found") - + # Get host results summary - results = db.execute(text(""" + results = db.execute( + text( + """ SELECT p.host_id, p.host_name, p.status, p.error_message, sr.total_rules, sr.passed_rules, sr.failed_rules, sr.score FROM group_scan_host_progress p LEFT JOIN scan_results sr ON p.scan_result_id = sr.id WHERE p.session_id = :session_id ORDER BY p.host_name - """), {"session_id": session_id}) - + """ + ), + {"session_id": session_id}, + ) + host_results = [] total_rules = 0 total_passed = 0 @@ -981,9 +1115,9 @@ async def get_group_scan_summary( total_score = 0 successful_scans = 0 failed_scans = 0 - + for row in results: - if row.status == 'completed': + if row.status == "completed": successful_scans += 1 if row.total_rules: total_rules += row.total_rules @@ -991,33 +1125,39 @@ async def get_group_scan_summary( total_failed += row.failed_rules or 0 if row.score: try: - score_value = float(row.score.replace('%', '')) + score_value = float(row.score.replace("%", "")) total_score += score_value except: pass - elif row.status == 'failed': + elif row.status == "failed": failed_scans += 1 - - host_results.append({ - "host_id": row.host_id, - "host_name": row.host_name, - "status": row.status, - "error_message": row.error_message, - "scan_results": { - "total_rules": row.total_rules, - "passed_rules": row.passed_rules, - "failed_rules": row.failed_rules, - "score": row.score - } if row.total_rules else None - }) - + + host_results.append( + { + "host_id": row.host_id, + "host_name": row.host_name, + "status": row.status, + "error_message": row.error_message, + "scan_results": ( + { + "total_rules": row.total_rules, + "passed_rules": row.passed_rules, + "failed_rules": row.failed_rules, + "score": row.score, + } + if row.total_rules + else None + ), + } + ) + # Calculate averages average_score = (total_score / successful_scans) if successful_scans > 0 else 0 duration_minutes = 0 if session_result.completed_at and session_result.start_time: duration = session_result.completed_at - session_result.start_time duration_minutes = int(duration.total_seconds() / 60) - + summary = { "session_id": session_result.session_id, "group_id": session_result.group_id, @@ -1032,12 +1172,14 @@ async def get_group_scan_summary( "average_compliance_score": round(average_score, 1), "scan_duration_minutes": duration_minutes, "started_at": session_result.start_time.isoformat(), - "completed_at": session_result.completed_at.isoformat() if session_result.completed_at else None, - "host_results": host_results + "completed_at": ( + session_result.completed_at.isoformat() if session_result.completed_at else None + ), + "host_results": host_results, } - + return summary - + except HTTPException: raise except Exception as e: @@ -1052,31 +1194,33 @@ async def list_group_scan_sessions( limit: int = 20, offset: int = 0, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """List group scan sessions with filtering options""" try: # Build query conditions where_conditions = [] params = {"limit": limit, "offset": offset} - + if status: where_conditions.append("s.status = :status") params["status"] = status - + if group_id: where_conditions.append("s.group_id = :group_id") params["group_id"] = group_id - + # Add user filtering if not admin if current_user.get("role") not in ["super_admin", "security_admin"]: where_conditions.append("s.initiated_by = :user_id") params["user_id"] = current_user["id"] - + where_clause = "WHERE " + " AND ".join(where_conditions) if where_conditions else "" - + # Get sessions with progress info - result = db.execute(text(f""" + result = db.execute( + text( + f""" SELECT s.session_id, s.group_id, s.group_name, s.total_hosts, s.status, s.start_time, s.completed_at, s.estimated_completion, s.initiated_by, COUNT(CASE WHEN p.status = 'completed' THEN 1 END) as hosts_completed, @@ -1090,43 +1234,52 @@ async def list_group_scan_sessions( s.start_time, s.completed_at, s.estimated_completion, s.initiated_by ORDER BY s.start_time DESC LIMIT :limit OFFSET :offset - """), params) - + """ + ), + params, + ) + sessions = [] for row in result: - progress_percentage = (row.hosts_completed / row.total_hosts) * 100 if row.total_hosts > 0 else 0 - - sessions.append({ - "session_id": row.session_id, - "group_id": row.group_id, - "group_name": row.group_name, - "total_hosts": row.total_hosts, - "status": row.status, - "progress_percentage": round(progress_percentage, 1), - "hosts_completed": row.hosts_completed, - "hosts_failed": row.hosts_failed, - "hosts_scanning": row.hosts_scanning, - "hosts_pending": row.hosts_pending, - "started_at": row.start_time.isoformat(), - "completed_at": row.completed_at.isoformat() if row.completed_at else None, - "estimated_completion": row.estimated_completion.isoformat() if row.estimated_completion else None, - "initiated_by": row.initiated_by - }) - + progress_percentage = ( + (row.hosts_completed / row.total_hosts) * 100 if row.total_hosts > 0 else 0 + ) + + sessions.append( + { + "session_id": row.session_id, + "group_id": row.group_id, + "group_name": row.group_name, + "total_hosts": row.total_hosts, + "status": row.status, + "progress_percentage": round(progress_percentage, 1), + "hosts_completed": row.hosts_completed, + "hosts_failed": row.hosts_failed, + "hosts_scanning": row.hosts_scanning, + "hosts_pending": row.hosts_pending, + "started_at": row.start_time.isoformat(), + "completed_at": row.completed_at.isoformat() if row.completed_at else None, + "estimated_completion": ( + row.estimated_completion.isoformat() if row.estimated_completion else None + ), + "initiated_by": row.initiated_by, + } + ) + # Get total count - count_result = db.execute(text(f""" + count_result = db.execute( + text( + f""" SELECT COUNT(DISTINCT s.session_id) as total FROM group_scan_sessions s {where_clause} - """), params).fetchone() - - return { - "sessions": sessions, - "total": count_result.total, - "limit": limit, - "offset": offset - } - + """ + ), + params, + ).fetchone() + + return {"sessions": sessions, "total": count_result.total, "limit": limit, "offset": offset} + except Exception as e: logger.error(f"Error listing group scan sessions: {e}") - raise HTTPException(status_code=500, detail="Failed to list group scan sessions") \ No newline at end of file + raise HTTPException(status_code=500, detail="Failed to list group scan sessions") diff --git a/backend/app/routes/hosts.py b/backend/app/routes/hosts.py index c31f27c1..f14e4e65 100644 --- a/backend/app/routes/hosts.py +++ b/backend/app/routes/hosts.py @@ -1,6 +1,7 @@ """ Host Management Routes """ + from fastapi import APIRouter, HTTPException, Depends, status from fastapi.security import HTTPBearer from pydantic import BaseModel, Field @@ -13,6 +14,7 @@ from ..database import get_db from sqlalchemy.orm import Session from sqlalchemy import text + # NOTE: json and base64 imports removed - using centralized auth service from ..services.ssh_utils import validate_ssh_key, format_validation_message from ..services.ssh_key_service import extract_ssh_key_metadata @@ -46,7 +48,7 @@ class Host(BaseModel): ssh_key_type: Optional[str] = None ssh_key_bits: Optional[int] = None ssh_key_comment: Optional[str] = None - + # Latest scan information latest_scan_id: Optional[str] = None latest_scan_name: Optional[str] = None @@ -59,7 +61,7 @@ class Host(BaseModel): medium_issues: Optional[int] = None low_issues: Optional[int] = None total_rules: Optional[int] = None - + # Group information group_id: Optional[int] = None group_name: Optional[str] = None @@ -103,7 +105,9 @@ async def list_hosts(db: Session = Depends(get_db), current_user: dict = Depends """List all managed hosts""" try: # Try to get hosts from database with latest scan information and group details - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT h.id, h.hostname, h.ip_address, h.display_name, h.operating_system, h.status, h.port, h.username, h.auth_method, h.created_at, h.updated_at, h.description, s.id as latest_scan_id, s.name as latest_scan_name, s.status as scan_status, @@ -124,23 +128,25 @@ async def list_hosts(db: Session = Depends(get_db), current_user: dict = Depends LEFT JOIN host_group_memberships hgm ON hgm.host_id = h.id LEFT JOIN host_groups hg ON hg.id = hgm.group_id ORDER BY h.created_at DESC - """)) - + """ + ) + ) + hosts = [] for row in result: # Calculate critical issues (high severity issues) critical_issues = row.high_issues or 0 - + # Parse compliance score compliance_score = None if row.compliance_score: try: # Remove % sign if present and convert to float - score_str = str(row.compliance_score).replace('%', '') + score_str = str(row.compliance_score).replace("%", "") compliance_score = float(score_str) except (ValueError, TypeError): pass - + host_data = Host( id=str(row.id), hostname=row.hostname, @@ -155,23 +161,25 @@ async def list_hosts(db: Session = Depends(get_db), current_user: dict = Depends updated_at=row.updated_at.isoformat() if row.updated_at else None, last_check=None, # Column doesn't exist in database ssh_key_fingerprint=None, # Not in database schema - ssh_key_type=None, # Not in database schema - ssh_key_bits=None, # Not in database schema - ssh_key_comment=None, # Not in database schema + ssh_key_type=None, # Not in database schema + ssh_key_bits=None, # Not in database schema + ssh_key_comment=None, # Not in database schema group_id=row.group_id, group_name=row.group_name, group_description=row.group_description, - group_color=row.group_color + group_color=row.group_color, ) - + # Add scan information as additional fields if row.latest_scan_id: host_data.latest_scan_id = str(row.latest_scan_id) host_data.latest_scan_name = row.latest_scan_name host_data.scan_status = row.scan_status host_data.scan_progress = row.scan_progress - host_data.last_scan = row.scan_completed_at.isoformat() if row.scan_completed_at else ( - row.scan_started_at.isoformat() if row.scan_started_at else None + host_data.last_scan = ( + row.scan_completed_at.isoformat() + if row.scan_completed_at + else (row.scan_started_at.isoformat() if row.scan_started_at else None) ) host_data.compliance_score = compliance_score host_data.failed_rules = row.failed_rules or 0 @@ -181,74 +189,83 @@ async def list_hosts(db: Session = Depends(get_db), current_user: dict = Depends host_data.medium_issues = row.medium_issues or 0 host_data.low_issues = row.low_issues or 0 host_data.total_rules = row.total_rules or 0 - + hosts.append(host_data) - + return hosts - + except Exception as e: logger.error(f"Database error in host listing: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to retrieve hosts from database" + detail="Failed to retrieve hosts from database", ) @router.post("/", response_model=Host) -async def create_host(host: HostCreate, db: Session = Depends(get_db), current_user: dict = Depends(get_current_user)): +async def create_host( + host: HostCreate, db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) +): """Add a new host to management""" try: # Insert into database host_id = str(uuid.uuid4()) current_time = datetime.utcnow() - + # Use display_name if provided, otherwise use hostname display_name = host.display_name or host.hostname - + # Handle credential encryption if provided encrypted_creds = None if host.auth_method == "password" and host.password: from ..services.crypto import encrypt_credentials + cred_data = { "username": host.username, "password": host.password, - "auth_method": "password" + "auth_method": "password", } encrypted_creds = encrypt_credentials(json.dumps(cred_data)) logger.info(f"Encrypting password credentials for new host {host.hostname}") elif host.auth_method == "ssh_key" and host.ssh_key: from ..services.crypto import encrypt_credentials + cred_data = { "username": host.username, "ssh_key": host.ssh_key, - "auth_method": "ssh_key" + "auth_method": "ssh_key", } encrypted_creds = encrypt_credentials(json.dumps(cred_data)) logger.info(f"Encrypting SSH key credentials for new host {host.hostname}") - - db.execute(text(""" + + db.execute( + text( + """ INSERT INTO hosts (id, hostname, ip_address, display_name, operating_system, status, port, username, auth_method, encrypted_credentials, is_active, created_at, updated_at) VALUES (:id, :hostname, :ip_address, :display_name, :operating_system, :status, :port, :username, :auth_method, :encrypted_credentials, :is_active, :created_at, :updated_at) - """), { - "id": host_id, - "hostname": host.hostname, - "ip_address": host.ip_address, - "display_name": display_name, - "operating_system": host.operating_system, - "status": "offline", - "port": int(host.port) if host.port else 22, - "username": host.username, - "auth_method": host.auth_method or "ssh_key", - "encrypted_credentials": encrypted_creds, - "is_active": True, - "created_at": current_time, - "updated_at": current_time - }) - + """ + ), + { + "id": host_id, + "hostname": host.hostname, + "ip_address": host.ip_address, + "display_name": display_name, + "operating_system": host.operating_system, + "status": "offline", + "port": int(host.port) if host.port else 22, + "username": host.username, + "auth_method": host.auth_method or "ssh_key", + "encrypted_credentials": encrypted_creds, + "is_active": True, + "created_at": current_time, + "updated_at": current_time, + }, + ) + db.commit() - + new_host = Host( id=host_id, hostname=host.hostname, @@ -257,23 +274,24 @@ async def create_host(host: HostCreate, db: Session = Depends(get_db), current_u operating_system=host.operating_system, status="offline", created_at=current_time.isoformat(), - updated_at=current_time.isoformat() + updated_at=current_time.isoformat(), ) - + logger.info(f"Created new host in database: {host.hostname}") return new_host - + except Exception as e: logger.error(f"Failed to create host in database: {e}") db.rollback() raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to create host" + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to create host" ) @router.get("/{host_id}", response_model=Host) -async def get_host(host_id: str, db: Session = Depends(get_db), current_user: dict = Depends(get_current_user)): +async def get_host( + host_id: str, db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) +): """Get host details by ID""" try: # Validate and convert host_id to UUID @@ -282,11 +300,12 @@ async def get_host(host_id: str, db: Session = Depends(get_db), current_user: di except (ValueError, TypeError) as e: logger.error(f"Invalid host ID format: {host_id} - {e}") raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid host ID format" + status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid host ID format" ) - - result = db.execute(text(""" + + result = db.execute( + text( + """ SELECT h.id, h.hostname, h.ip_address, h.display_name, h.operating_system, h.status, h.port, h.username, h.auth_method, h.created_at, h.updated_at, h.description, hg.id as group_id, hg.name as group_name, hg.description as group_description, hg.color as group_color @@ -294,15 +313,15 @@ async def get_host(host_id: str, db: Session = Depends(get_db), current_user: di LEFT JOIN host_group_memberships hgm ON hgm.host_id = h.id LEFT JOIN host_groups hg ON hg.id = hgm.group_id WHERE h.id = :id - """), {"id": host_uuid}) - + """ + ), + {"id": host_uuid}, + ) + row = result.fetchone() if not row: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Host not found" - ) - + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Host not found") + return Host( id=str(row.id), hostname=row.hostname, @@ -316,27 +335,31 @@ async def get_host(host_id: str, db: Session = Depends(get_db), current_user: di created_at=row.created_at.isoformat() if row.created_at else None, updated_at=row.updated_at.isoformat() if row.updated_at else None, ssh_key_fingerprint=None, # Not in database schema - ssh_key_type=None, # Not in database schema - ssh_key_bits=None, # Not in database schema - ssh_key_comment=None, # Not in database schema + ssh_key_type=None, # Not in database schema + ssh_key_bits=None, # Not in database schema + ssh_key_comment=None, # Not in database schema group_id=row.group_id, group_name=row.group_name, group_description=row.group_description, - group_color=row.group_color + group_color=row.group_color, ) - + except HTTPException: raise except Exception as e: logger.error(f"Failed to get host: {e}") raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to retrieve host" + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to retrieve host" ) @router.put("/{host_id}", response_model=Host) -async def update_host(host_id: str, host_update: HostUpdate, db: Session = Depends(get_db), current_user: dict = Depends(get_current_user)): +async def update_host( + host_id: str, + host_update: HostUpdate, + db: Session = Depends(get_db), + current_user: dict = Depends(get_current_user), +): """Update host information""" try: # Validate and convert host_id to UUID @@ -345,63 +368,73 @@ async def update_host(host_id: str, host_update: HostUpdate, db: Session = Depen except (ValueError, TypeError) as e: logger.error(f"Invalid host ID format: {host_id} - {e}") raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid host ID format" + status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid host ID format" ) - + # Check if host exists - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT id FROM hosts WHERE id = :id - """), {"id": host_uuid}) - + """ + ), + {"id": host_uuid}, + ) + if not result.fetchone(): - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Host not found" - ) - + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Host not found") + # Get current host data for partial updates - current_host_result = db.execute(text(""" + current_host_result = db.execute( + text( + """ SELECT hostname, ip_address, display_name, operating_system, port, username, auth_method, description FROM hosts WHERE id = :id - """), {"id": host_uuid}) - + """ + ), + {"id": host_uuid}, + ) + current_host = current_host_result.fetchone() if not current_host: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Host not found" - ) - + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Host not found") + # Update host - use existing values if new ones not provided current_time = datetime.utcnow() - + # Handle display_name logic properly - new_hostname = host_update.hostname if host_update.hostname is not None else current_host.hostname - new_display_name = (host_update.display_name if host_update.display_name is not None - else current_host.display_name or new_hostname) - + new_hostname = ( + host_update.hostname if host_update.hostname is not None else current_host.hostname + ) + new_display_name = ( + host_update.display_name + if host_update.display_name is not None + else current_host.display_name or new_hostname + ) + # Handle credential updates if provided encrypted_creds = None if host_update.auth_method: if host_update.auth_method == "password" and host_update.password: # Encrypt password credentials from ..services.crypto import encrypt_credentials + cred_data = { "username": host_update.username or current_host.username, "password": host_update.password, - "auth_method": "password" + "auth_method": "password", } encrypted_creds = encrypt_credentials(json.dumps(cred_data)) logger.info(f"Encrypting password credentials for host {host_id}") elif host_update.auth_method == "ssh_key" and host_update.ssh_key: # Encrypt SSH key credentials from ..services.crypto import encrypt_credentials + cred_data = { "username": host_update.username or current_host.username, "ssh_key": host_update.ssh_key, - "auth_method": "ssh_key" + "auth_method": "ssh_key", } encrypted_creds = encrypt_credentials(json.dumps(cred_data)) logger.info(f"Encrypting SSH key credentials for host {host_id}") @@ -409,21 +442,39 @@ async def update_host(host_id: str, host_update: HostUpdate, db: Session = Depen # Clear host-specific credentials when using system default encrypted_creds = None logger.info(f"Clearing host credentials for system default auth on host {host_id}") - + # Update all fields including encrypted credentials update_params = { "id": host_uuid, "hostname": new_hostname, - "ip_address": host_update.ip_address if host_update.ip_address is not None else current_host.ip_address, + "ip_address": ( + host_update.ip_address + if host_update.ip_address is not None + else current_host.ip_address + ), "display_name": new_display_name, - "operating_system": host_update.operating_system if host_update.operating_system is not None else current_host.operating_system, + "operating_system": ( + host_update.operating_system + if host_update.operating_system is not None + else current_host.operating_system + ), "port": host_update.port if host_update.port is not None else current_host.port, - "username": host_update.username if host_update.username is not None else current_host.username, - "auth_method": host_update.auth_method if host_update.auth_method is not None else current_host.auth_method, - "description": host_update.description if host_update.description is not None else current_host.description, - "updated_at": current_time + "username": ( + host_update.username if host_update.username is not None else current_host.username + ), + "auth_method": ( + host_update.auth_method + if host_update.auth_method is not None + else current_host.auth_method + ), + "description": ( + host_update.description + if host_update.description is not None + else current_host.description + ), + "updated_at": current_time, } - + # Build SQL query with optional encrypted_credentials if encrypted_creds is not None or (host_update.auth_method == "system_default"): update_query = """ @@ -455,13 +506,15 @@ async def update_host(host_id: str, host_update: HostUpdate, db: Session = Depen updated_at = :updated_at WHERE id = :id """ - + db.execute(text(update_query), update_params) - + db.commit() - + # Get updated host with group information - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT h.id, h.hostname, h.ip_address, h.display_name, h.operating_system, h.status, h.port, h.username, h.auth_method, h.created_at, h.updated_at, h.description, hg.id as group_id, hg.name as group_name, hg.description as group_description, hg.color as group_color @@ -469,8 +522,11 @@ async def update_host(host_id: str, host_update: HostUpdate, db: Session = Depen LEFT JOIN host_group_memberships hgm ON hgm.host_id = h.id LEFT JOIN host_groups hg ON hg.id = hgm.group_id WHERE h.id = :id - """), {"id": host_uuid}) - + """ + ), + {"id": host_uuid}, + ) + row = result.fetchone() updated_host = Host( id=str(row.id), @@ -485,31 +541,32 @@ async def update_host(host_id: str, host_update: HostUpdate, db: Session = Depen created_at=row.created_at.isoformat() if row.created_at else None, updated_at=row.updated_at.isoformat() if row.updated_at else None, ssh_key_fingerprint=None, # Not in database schema - ssh_key_type=None, # Not in database schema - ssh_key_bits=None, # Not in database schema - ssh_key_comment=None, # Not in database schema + ssh_key_type=None, # Not in database schema + ssh_key_bits=None, # Not in database schema + ssh_key_comment=None, # Not in database schema group_id=row.group_id, group_name=row.group_name, group_description=row.group_description, - group_color=row.group_color + group_color=row.group_color, ) - + logger.info(f"Updated host {host_id}") return updated_host - + except HTTPException: raise except Exception as e: logger.error(f"Failed to update host: {e}") db.rollback() raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to update host" + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to update host" ) @router.delete("/{host_id}") -async def delete_host(host_id: str, db: Session = Depends(get_db), current_user: dict = Depends(get_current_user)): +async def delete_host( + host_id: str, db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) +): """Remove host from management""" try: # Validate and convert host_id to UUID @@ -518,65 +575,87 @@ async def delete_host(host_id: str, db: Session = Depends(get_db), current_user: except (ValueError, TypeError) as e: logger.error(f"Invalid host ID format: {host_id} - {e}") raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid host ID format" + status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid host ID format" ) - + # Check if host exists - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT id FROM hosts WHERE id = :id - """), {"id": host_uuid}) - + """ + ), + {"id": host_uuid}, + ) + if not result.fetchone(): - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Host not found" - ) - + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Host not found") + # Check if host has any scans (optional - you might want to prevent deletion) - scan_result = db.execute(text(""" + scan_result = db.execute( + text( + """ SELECT COUNT(*) as count FROM scans WHERE host_id = :host_id - """), {"host_id": host_uuid}) - + """ + ), + {"host_id": host_uuid}, + ) + scan_count = scan_result.fetchone().count if scan_count > 0: # You can either delete the scans or prevent deletion # For now, we'll delete the scans too - db.execute(text(""" + db.execute( + text( + """ DELETE FROM scan_results WHERE scan_id IN ( SELECT id FROM scans WHERE host_id = :host_id ) - """), {"host_id": host_uuid}) - - db.execute(text(""" + """ + ), + {"host_id": host_uuid}, + ) + + db.execute( + text( + """ DELETE FROM scans WHERE host_id = :host_id - """), {"host_id": host_uuid}) - + """ + ), + {"host_id": host_uuid}, + ) + logger.info(f"Deleted {scan_count} scans for host {host_id}") - + # Delete the host - db.execute(text(""" + db.execute( + text( + """ DELETE FROM hosts WHERE id = :id - """), {"id": host_uuid}) - + """ + ), + {"id": host_uuid}, + ) + db.commit() - + logger.info(f"Deleted host {host_id}") return {"message": "Host deleted successfully"} - + except HTTPException: raise except Exception as e: logger.error(f"Failed to delete host: {e}") db.rollback() raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to delete host" + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to delete host" ) @router.delete("/{host_id}/ssh-key") -async def delete_host_ssh_key(host_id: str, db: Session = Depends(get_db), current_user: dict = Depends(get_current_user)): +async def delete_host_ssh_key( + host_id: str, db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) +): """Delete SSH key from host""" try: # Validate and convert host_id to UUID @@ -585,31 +664,33 @@ async def delete_host_ssh_key(host_id: str, db: Session = Depends(get_db), curre except (ValueError, TypeError) as e: logger.error(f"Invalid host ID format: {host_id} - {e}") raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid host ID format" + status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid host ID format" ) - + # Check if host exists and has SSH key - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT id, auth_method, ssh_key_fingerprint FROM hosts WHERE id = :id - """), {"id": host_uuid}) - + """ + ), + {"id": host_uuid}, + ) + row = result.fetchone() if not row: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Host not found" - ) - + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Host not found") + if not row.ssh_key_fingerprint: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="No SSH key found to delete" + status_code=status.HTTP_400_BAD_REQUEST, detail="No SSH key found to delete" ) - + # Update host to remove SSH key - db.execute(text(""" + db.execute( + text( + """ UPDATE hosts SET ssh_key_fingerprint = NULL, ssh_key_type = NULL, @@ -617,22 +698,21 @@ async def delete_host_ssh_key(host_id: str, db: Session = Depends(get_db), curre ssh_key_comment = NULL, updated_at = :updated_at WHERE id = :id - """), { - "id": host_uuid, - "updated_at": datetime.utcnow() - }) - + """ + ), + {"id": host_uuid, "updated_at": datetime.utcnow()}, + ) + db.commit() - + logger.info(f"Deleted SSH key from host {host_id}") return {"message": "SSH key deleted successfully"} - + except HTTPException: raise except Exception as e: logger.error(f"Failed to delete SSH key from host: {e}") db.rollback() raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to delete SSH key" - ) \ No newline at end of file + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to delete SSH key" + ) diff --git a/backend/app/routes/integration_metrics.py b/backend/app/routes/integration_metrics.py index a3a7a4f5..02b152de 100644 --- a/backend/app/routes/integration_metrics.py +++ b/backend/app/routes/integration_metrics.py @@ -2,6 +2,7 @@ Integration Metrics API Routes Provides endpoints for monitoring integration performance and health """ + from fastapi import APIRouter, HTTPException, Query, Response, Depends from typing import Optional, Dict, Any import json @@ -13,19 +14,20 @@ router = APIRouter() + @router.get("/health") async def integration_health(): """Get integration health status - no auth required""" try: stats = metrics_collector.get_current_stats() summaries = metrics_collector.get_metrics_summary(hours=1) - + # Calculate overall health score total_operations = sum(summary.total_requests for summary in summaries.values()) total_errors = sum(summary.failed_requests for summary in summaries.values()) - + error_rate = (total_errors / total_operations * 100) if total_operations > 0 else 0 - + # Determine health status if error_rate < 1: health_status = "healthy" @@ -33,7 +35,7 @@ async def integration_health(): health_status = "degraded" else: health_status = "unhealthy" - + return { "status": health_status, "timestamp": datetime.utcnow().isoformat(), @@ -41,28 +43,25 @@ async def integration_health(): "total_operations_1h": total_operations, "error_rate_1h": round(error_rate, 2), "recent_errors": len(stats.get("recent_errors", [])), - "active_metrics": stats["total_metrics"] + "active_metrics": stats["total_metrics"], }, - "top_operations": stats.get("top_operations", {}) + "top_operations": stats.get("top_operations", {}), } except Exception as e: - return { - "status": "unknown", - "timestamp": datetime.utcnow().isoformat(), - "error": str(e) - } + return {"status": "unknown", "timestamp": datetime.utcnow().isoformat(), "error": str(e)} + @router.get("/stats") @require_admin() async def get_integration_stats( hours: int = Query(1, ge=1, le=168, description="Hours of data to analyze"), - current_user: Dict[str, Any] = Depends(get_current_user) + current_user: Dict[str, Any] = Depends(get_current_user), ): """Get detailed integration statistics""" try: stats = metrics_collector.get_current_stats() summaries = metrics_collector.get_metrics_summary(hours=hours) - + return { "period_hours": hours, "timestamp": datetime.utcnow().isoformat(), @@ -76,20 +75,21 @@ async def get_integration_stats( "avg_duration_ms": round(summary.average_duration * 1000, 2), "min_duration_ms": round(summary.min_duration * 1000, 2), "max_duration_ms": round(summary.max_duration * 1000, 2), - "p95_duration_ms": round(summary.p95_duration * 1000, 2) + "p95_duration_ms": round(summary.p95_duration * 1000, 2), } for op, summary in summaries.items() - } + }, } except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to get integration stats: {e}") + @router.get("/metrics") @require_admin() async def get_metrics( format: str = Query("json", regex="^(json|prometheus)$"), operation: Optional[str] = Query(None, description="Filter by specific operation"), - current_user: Dict[str, Any] = Depends(get_current_user) + current_user: Dict[str, Any] = Depends(get_current_user), ): """Export metrics in various formats""" try: @@ -98,103 +98,133 @@ async def get_metrics( return Response( content=metrics_data, media_type="text/plain; version=0.0.4; charset=utf-8", - headers={"Content-Type": "text/plain; version=0.0.4; charset=utf-8"} + headers={"Content-Type": "text/plain; version=0.0.4; charset=utf-8"}, ) else: # Filter metrics if operation specified all_metrics = list(metrics_collector.metrics) if operation: all_metrics = [m for m in all_metrics if m.operation == operation] - + # Convert to serializable format metrics_data = [] for metric in all_metrics[-1000:]: # Last 1000 metrics - metrics_data.append({ - "timestamp": datetime.fromtimestamp(metric.timestamp).isoformat(), - "metric_type": metric.metric_type, - "operation": metric.operation, - "value": metric.value, - "success": metric.success, - "error": metric.error, - "labels": metric.labels - }) - + metrics_data.append( + { + "timestamp": datetime.fromtimestamp(metric.timestamp).isoformat(), + "metric_type": metric.metric_type, + "operation": metric.operation, + "value": metric.value, + "success": metric.success, + "error": metric.error, + "labels": metric.labels, + } + ) + return { "total_metrics": len(all_metrics), "returned_metrics": len(metrics_data), - "metrics": metrics_data + "metrics": metrics_data, } except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to export metrics: {e}") + @router.get("/performance") @require_admin() -async def get_performance_overview( - current_user: Dict[str, Any] = Depends(get_current_user) -): +async def get_performance_overview(current_user: Dict[str, Any] = Depends(get_current_user)): """Get performance overview dashboard data""" try: # Get metrics for different time periods last_hour = metrics_collector.get_metrics_summary(hours=1) last_day = metrics_collector.get_metrics_summary(hours=24) - + # Calculate trends performance_data = {} - + for operation in set(list(last_hour.keys()) + list(last_day.keys())): hour_data = last_hour.get(operation) day_data = last_day.get(operation) - + performance_data[operation] = { "last_hour": { "requests": hour_data.total_requests if hour_data else 0, "error_rate": round(hour_data.error_rate, 2) if hour_data else 0, - "avg_duration_ms": round(hour_data.average_duration * 1000, 2) if hour_data else 0, - "p95_duration_ms": round(hour_data.p95_duration * 1000, 2) if hour_data else 0 + "avg_duration_ms": ( + round(hour_data.average_duration * 1000, 2) if hour_data else 0 + ), + "p95_duration_ms": round(hour_data.p95_duration * 1000, 2) if hour_data else 0, }, "last_24h": { "requests": day_data.total_requests if day_data else 0, "error_rate": round(day_data.error_rate, 2) if day_data else 0, - "avg_duration_ms": round(day_data.average_duration * 1000, 2) if day_data else 0, - "p95_duration_ms": round(day_data.p95_duration * 1000, 2) if day_data else 0 - } + "avg_duration_ms": ( + round(day_data.average_duration * 1000, 2) if day_data else 0 + ), + "p95_duration_ms": round(day_data.p95_duration * 1000, 2) if day_data else 0, + }, } - + # Calculate trends if hour_data and day_data: performance_data[operation]["trends"] = { - "error_rate_trend": "up" if hour_data.error_rate > day_data.error_rate else "down" if hour_data.error_rate < day_data.error_rate else "stable", - "performance_trend": "better" if hour_data.average_duration < day_data.average_duration else "worse" if hour_data.average_duration > day_data.average_duration else "stable" + "error_rate_trend": ( + "up" + if hour_data.error_rate > day_data.error_rate + else "down" if hour_data.error_rate < day_data.error_rate else "stable" + ), + "performance_trend": ( + "better" + if hour_data.average_duration < day_data.average_duration + else ( + "worse" + if hour_data.average_duration > day_data.average_duration + else "stable" + ) + ), } - + return { "timestamp": datetime.utcnow().isoformat(), "performance_data": performance_data, "summary": { - "total_operations_1h": sum(data["last_hour"]["requests"] for data in performance_data.values()), - "total_operations_24h": sum(data["last_24h"]["requests"] for data in performance_data.values()), - "avg_error_rate_1h": sum(data["last_hour"]["error_rate"] for data in performance_data.values()) / len(performance_data) if performance_data else 0, - "avg_error_rate_24h": sum(data["last_24h"]["error_rate"] for data in performance_data.values()) / len(performance_data) if performance_data else 0 - } + "total_operations_1h": sum( + data["last_hour"]["requests"] for data in performance_data.values() + ), + "total_operations_24h": sum( + data["last_24h"]["requests"] for data in performance_data.values() + ), + "avg_error_rate_1h": ( + sum(data["last_hour"]["error_rate"] for data in performance_data.values()) + / len(performance_data) + if performance_data + else 0 + ), + "avg_error_rate_24h": ( + sum(data["last_24h"]["error_rate"] for data in performance_data.values()) + / len(performance_data) + if performance_data + else 0 + ), + }, } except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to get performance overview: {e}") + @router.post("/cleanup") @require_admin() -async def cleanup_old_metrics( - current_user: Dict[str, Any] = Depends(get_current_user) -): +async def cleanup_old_metrics(current_user: Dict[str, Any] = Depends(get_current_user)): """Manually trigger cleanup of old metrics""" try: initial_count = len(metrics_collector.metrics) metrics_collector.cleanup_old_metrics() final_count = len(metrics_collector.metrics) - + return { "message": "Metrics cleanup completed", "metrics_removed": initial_count - final_count, - "metrics_remaining": final_count + "metrics_remaining": final_count, } except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to cleanup metrics: {e}") \ No newline at end of file + raise HTTPException(status_code=500, detail=f"Failed to cleanup metrics: {e}") diff --git a/backend/app/routes/mfa.py b/backend/app/routes/mfa.py index 16174320..acfe8d6d 100644 --- a/backend/app/routes/mfa.py +++ b/backend/app/routes/mfa.py @@ -30,17 +30,19 @@ def get_client_ip(request: Request) -> str: # Request/Response Models class MFAEnrollmentRequest(BaseModel): """Request to enroll in MFA""" + verify_password: str class MFAEnrollmentResponse(BaseModel): """Response for MFA enrollment""" + success: bool qr_code: Optional[str] = None backup_codes: Optional[List[str]] = None error_message: Optional[str] = None - - @validator('backup_codes') + + @validator("backup_codes") def mask_sensitive_data(cls, v): # In production, consider not returning backup codes in API response # Instead, display them once and require user to save them @@ -49,12 +51,13 @@ def mask_sensitive_data(cls, v): class MFAValidationRequest(BaseModel): """Request to validate MFA code""" + code: str - - @validator('code') + + @validator("code") def validate_code_format(cls, v): # Remove spaces and validate format - code = v.strip().replace(' ', '') + code = v.strip().replace(" ", "") if not code: raise ValueError("MFA code cannot be empty") if not (len(code) == 6 and code.isdigit()) and not (len(code) == 8 and code.isalnum()): @@ -64,6 +67,7 @@ def validate_code_format(cls, v): class MFAStatusResponse(BaseModel): """MFA status for user""" + mfa_enabled: bool totp_enabled: bool backup_codes_available: int @@ -74,6 +78,7 @@ class MFAStatusResponse(BaseModel): class BackupCodesRegenerateResponse(BaseModel): """Response for backup code regeneration""" + success: bool backup_codes: Optional[List[str]] = None error_message: Optional[str] = None @@ -81,12 +86,13 @@ class BackupCodesRegenerateResponse(BaseModel): class MFADisableRequest(BaseModel): """Request to disable MFA""" + verify_password: str confirm_disable: bool = False # Audit Logging -async def log_mfa_action( +def log_mfa_action( db: Session, user_id: int, action: str, @@ -94,7 +100,7 @@ async def log_mfa_action( ip_address: str, user_agent: str, method: Optional[str] = None, - details: Optional[Dict] = None + details: Optional[Dict] = None, ): """Log MFA action to audit table""" try: @@ -105,17 +111,17 @@ async def log_mfa_action( success=success, ip_address=ip_address, user_agent=user_agent, - details=details or {} + details=details or {}, ) db.add(audit_entry) db.commit() - + # Also log to security audit logger status_text = "SUCCESS" if success else "FAILED" audit_logger.log_security_event( f"MFA_{action.upper()}_{status_text}", f"User {user_id} {action} MFA - Method: {method or 'N/A'}", - ip_address + ip_address, ) except Exception as e: logger.error(f"Failed to log MFA action: {e}") @@ -123,42 +129,43 @@ async def log_mfa_action( @router.get("/status", response_model=MFAStatusResponse) async def get_mfa_status( - current_user: Dict[str, Any] = Depends(get_current_user), - db: Session = Depends(get_db) + current_user: Dict[str, Any] = Depends(get_current_user), db: Session = Depends(get_db) ): """Get user's MFA status""" try: # Get user MFA data from database - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT mfa_enabled, mfa_secret, backup_codes, last_mfa_use, mfa_enrolled_at FROM users WHERE id = :user_id - """), {"user_id": current_user["id"]}) - + """ + ), + {"user_id": current_user["id"]}, + ) + user_data = result.fetchone() if not user_data: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="User not found" - ) - + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found") + backup_codes_count = len(user_data.backup_codes) if user_data.backup_codes else 0 - + return MFAStatusResponse( mfa_enabled=bool(user_data.mfa_enabled), totp_enabled=bool(user_data.mfa_secret), backup_codes_available=backup_codes_count, last_mfa_use=user_data.last_mfa_use, enrollment_date=user_data.mfa_enrolled_at, - supported_methods=["totp", "backup_codes"] + supported_methods=["totp", "backup_codes"], ) - + except HTTPException: raise except Exception as e: logger.error(f"Failed to get MFA status for user {current_user['id']}: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to retrieve MFA status" + detail="Failed to retrieve MFA status", ) @@ -167,103 +174,119 @@ async def enroll_mfa( request: MFAEnrollmentRequest, http_request: Request, current_user: Dict[str, Any] = Depends(get_current_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Enroll user in MFA with TOTP and backup codes""" client_ip = get_client_ip(http_request) user_agent = http_request.headers.get("user-agent", "") - + try: # Verify user's password first from ..auth import pwd_context - - result = db.execute(text(""" + + result = db.execute( + text( + """ SELECT hashed_password, mfa_enabled FROM users WHERE id = :user_id - """), {"user_id": current_user["id"]}) - + """ + ), + {"user_id": current_user["id"]}, + ) + user_data = result.fetchone() if not user_data: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="User not found" - ) - + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found") + if not pwd_context.verify(request.verify_password, user_data.hashed_password): await log_mfa_action( - db, current_user["id"], "enroll_attempt", False, - client_ip, user_agent, details={"reason": "invalid_password"} - ) - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid password" + db, + current_user["id"], + "enroll_attempt", + False, + client_ip, + user_agent, + details={"reason": "invalid_password"}, ) - + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid password") + if user_data.mfa_enabled: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="MFA is already enabled for this user" + detail="MFA is already enabled for this user", ) - + # Enroll user in MFA enrollment_result = mfa_service.enroll_user_mfa(current_user["username"]) - + if not enrollment_result.success: await log_mfa_action( - db, current_user["id"], "enroll", False, - client_ip, user_agent, details={"error": enrollment_result.error_message} + db, + current_user["id"], + "enroll", + False, + client_ip, + user_agent, + details={"error": enrollment_result.error_message}, ) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=enrollment_result.error_message + detail=enrollment_result.error_message, ) - + # Encrypt and store MFA secret encrypted_secret = mfa_service.encrypt_mfa_secret(enrollment_result.secret_key) - + # Hash backup codes for storage hashed_backup_codes = [ - mfa_service.hash_backup_code(code) - for code in enrollment_result.backup_codes + mfa_service.hash_backup_code(code) for code in enrollment_result.backup_codes ] - + # Update user record - db.execute(text(""" + db.execute( + text( + """ UPDATE users SET mfa_secret = :encrypted_secret, backup_codes = :backup_codes, mfa_enrolled_at = CURRENT_TIMESTAMP, mfa_recovery_codes_generated_at = CURRENT_TIMESTAMP WHERE id = :user_id - """), { - "encrypted_secret": encrypted_secret, - "backup_codes": hashed_backup_codes, - "user_id": current_user["id"] - }) + """ + ), + { + "encrypted_secret": encrypted_secret, + "backup_codes": hashed_backup_codes, + "user_id": current_user["id"], + }, + ) db.commit() - + await log_mfa_action( - db, current_user["id"], "enroll", True, - client_ip, user_agent, method="totp" + db, current_user["id"], "enroll", True, client_ip, user_agent, method="totp" ) - + return MFAEnrollmentResponse( success=True, qr_code=enrollment_result.qr_code_data, - backup_codes=enrollment_result.backup_codes + backup_codes=enrollment_result.backup_codes, ) - + except HTTPException: raise except Exception as e: logger.error(f"MFA enrollment failed for user {current_user['id']}: {e}") await log_mfa_action( - db, current_user["id"], "enroll", False, - client_ip, user_agent, details={"error": str(e)} + db, + current_user["id"], + "enroll", + False, + client_ip, + user_agent, + details={"error": str(e)}, ) raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="MFA enrollment failed" + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="MFA enrollment failed" ) @@ -272,98 +295,128 @@ async def validate_mfa_code( request: MFAValidationRequest, http_request: Request, current_user: Dict[str, Any] = Depends(get_current_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Validate MFA code for already enrolled user""" client_ip = get_client_ip(http_request) user_agent = http_request.headers.get("user-agent", "") - + try: # Get user MFA data - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT mfa_enabled, mfa_secret, backup_codes FROM users WHERE id = :user_id - """), {"user_id": current_user["id"]}) - + """ + ), + {"user_id": current_user["id"]}, + ) + user_data = result.fetchone() if not user_data or not user_data.mfa_enabled: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="MFA is not enabled for this user" + status_code=status.HTTP_400_BAD_REQUEST, detail="MFA is not enabled for this user" ) - + # Get recently used codes for replay protection - recent_codes = db.execute(text(""" + recent_codes = db.execute( + text( + """ SELECT code_hash FROM mfa_used_codes WHERE user_id = :user_id AND used_at > NOW() - INTERVAL '5 minutes' - """), {"user_id": current_user["id"]}).fetchall() - + """ + ), + {"user_id": current_user["id"]}, + ).fetchall() + used_codes_cache = {row.code_hash for row in recent_codes} - + # Validate MFA code validation_result = mfa_service.validate_mfa_code( - user_data.mfa_secret, - user_data.backup_codes or [], - request.code, - used_codes_cache + user_data.mfa_secret, user_data.backup_codes or [], request.code, used_codes_cache ) - + if validation_result.valid: # Update last MFA use - db.execute(text(""" + db.execute( + text( + """ UPDATE users SET last_mfa_use = CURRENT_TIMESTAMP WHERE id = :user_id - """), {"user_id": current_user["id"]}) - + """ + ), + {"user_id": current_user["id"]}, + ) + # Record used code for replay protection (TOTP only) if validation_result.method_used.value == "totp": import hashlib - code_hash = hashlib.sha256(f"{request.code}_{int(datetime.now().timestamp() // 30)}".encode()).hexdigest() - used_code = MFAUsedCodes( - user_id=current_user["id"], - code_hash=code_hash - ) + + code_hash = hashlib.sha256( + f"{request.code}_{int(datetime.now().timestamp() // 30)}".encode() + ).hexdigest() + used_code = MFAUsedCodes(user_id=current_user["id"], code_hash=code_hash) db.add(used_code) - + # Remove used backup code if applicable if validation_result.backup_code_used: - updated_codes = [code for code in user_data.backup_codes - if code != validation_result.backup_code_used] - db.execute(text(""" + updated_codes = [ + code + for code in user_data.backup_codes + if code != validation_result.backup_code_used + ] + db.execute( + text( + """ UPDATE users SET backup_codes = :backup_codes WHERE id = :user_id - """), {"backup_codes": updated_codes, "user_id": current_user["id"]}) - + """ + ), + {"backup_codes": updated_codes, "user_id": current_user["id"]}, + ) + db.commit() - + await log_mfa_action( - db, current_user["id"], "validate", True, - client_ip, user_agent, method=validation_result.method_used.value + db, + current_user["id"], + "validate", + True, + client_ip, + user_agent, + method=validation_result.method_used.value, ) - + return {"success": True, "method": validation_result.method_used.value} else: await log_mfa_action( - db, current_user["id"], "validate", False, - client_ip, user_agent, details={"error": validation_result.error_message} + db, + current_user["id"], + "validate", + False, + client_ip, + user_agent, + details={"error": validation_result.error_message}, ) - - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid MFA code" - ) - + + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid MFA code") + except HTTPException: raise except Exception as e: logger.error(f"MFA validation failed for user {current_user['id']}: {e}") await log_mfa_action( - db, current_user["id"], "validate", False, - client_ip, user_agent, details={"error": str(e)} + db, + current_user["id"], + "validate", + False, + client_ip, + user_agent, + details={"error": str(e)}, ) raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="MFA validation failed" + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="MFA validation failed" ) @@ -372,75 +425,90 @@ async def enable_mfa( request: MFAValidationRequest, http_request: Request, current_user: Dict[str, Any] = Depends(get_current_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Enable MFA after successful enrollment verification""" client_ip = get_client_ip(http_request) user_agent = http_request.headers.get("user-agent", "") - + try: # Verify the TOTP code to confirm enrollment - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT mfa_enabled, mfa_secret, backup_codes FROM users WHERE id = :user_id - """), {"user_id": current_user["id"]}) - + """ + ), + {"user_id": current_user["id"]}, + ) + user_data = result.fetchone() if not user_data or not user_data.mfa_secret: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="MFA enrollment not found. Please enroll first." + detail="MFA enrollment not found. Please enroll first.", ) - + if user_data.mfa_enabled: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="MFA is already enabled" + status_code=status.HTTP_400_BAD_REQUEST, detail="MFA is already enabled" ) - + # Validate the provided code validation_result = mfa_service.validate_mfa_code( - user_data.mfa_secret, - user_data.backup_codes or [], - request.code + user_data.mfa_secret, user_data.backup_codes or [], request.code ) - + if not validation_result.valid: await log_mfa_action( - db, current_user["id"], "enable_attempt", False, - client_ip, user_agent, details={"reason": "invalid_code"} + db, + current_user["id"], + "enable_attempt", + False, + client_ip, + user_agent, + details={"reason": "invalid_code"}, ) raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid MFA code. Please try again." + detail="Invalid MFA code. Please try again.", ) - + # Enable MFA - db.execute(text(""" + db.execute( + text( + """ UPDATE users SET mfa_enabled = true, last_mfa_use = CURRENT_TIMESTAMP WHERE id = :user_id - """), {"user_id": current_user["id"]}) + """ + ), + {"user_id": current_user["id"]}, + ) db.commit() - + await log_mfa_action( - db, current_user["id"], "enable", True, - client_ip, user_agent, method="totp" + db, current_user["id"], "enable", True, client_ip, user_agent, method="totp" ) - + return {"success": True, "message": "MFA enabled successfully"} - + except HTTPException: raise except Exception as e: logger.error(f"MFA enable failed for user {current_user['id']}: {e}") await log_mfa_action( - db, current_user["id"], "enable", False, - client_ip, user_agent, details={"error": str(e)} + db, + current_user["id"], + "enable", + False, + client_ip, + user_agent, + details={"error": str(e)}, ) raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to enable MFA" + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to enable MFA" ) @@ -449,83 +517,93 @@ async def regenerate_backup_codes( request: MFAValidationRequest, http_request: Request, current_user: Dict[str, Any] = Depends(get_current_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Regenerate backup codes after MFA validation""" client_ip = get_client_ip(http_request) user_agent = http_request.headers.get("user-agent", "") - + try: # Verify user has MFA enabled - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT mfa_enabled, mfa_secret, backup_codes FROM users WHERE id = :user_id - """), {"user_id": current_user["id"]}) - + """ + ), + {"user_id": current_user["id"]}, + ) + user_data = result.fetchone() if not user_data or not user_data.mfa_enabled: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="MFA is not enabled" + status_code=status.HTTP_400_BAD_REQUEST, detail="MFA is not enabled" ) - + # Validate MFA code first validation_result = mfa_service.validate_mfa_code( - user_data.mfa_secret, - user_data.backup_codes or [], - request.code + user_data.mfa_secret, user_data.backup_codes or [], request.code ) - + if not validation_result.valid: await log_mfa_action( - db, current_user["id"], "regenerate_backup_codes", False, - client_ip, user_agent, details={"reason": "invalid_mfa_code"} - ) - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid MFA code" + db, + current_user["id"], + "regenerate_backup_codes", + False, + client_ip, + user_agent, + details={"reason": "invalid_mfa_code"}, ) - + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid MFA code") + # Generate new backup codes new_backup_codes = mfa_service.regenerate_backup_codes(current_user["username"]) - hashed_backup_codes = [ - mfa_service.hash_backup_code(code) - for code in new_backup_codes - ] - + hashed_backup_codes = [mfa_service.hash_backup_code(code) for code in new_backup_codes] + # Update database - db.execute(text(""" + db.execute( + text( + """ UPDATE users SET backup_codes = :backup_codes, mfa_recovery_codes_generated_at = CURRENT_TIMESTAMP WHERE id = :user_id - """), { - "backup_codes": hashed_backup_codes, - "user_id": current_user["id"] - }) + """ + ), + {"backup_codes": hashed_backup_codes, "user_id": current_user["id"]}, + ) db.commit() - + await log_mfa_action( - db, current_user["id"], "regenerate_backup_codes", True, - client_ip, user_agent, method="backup_codes" + db, + current_user["id"], + "regenerate_backup_codes", + True, + client_ip, + user_agent, + method="backup_codes", ) - - return BackupCodesRegenerateResponse( - success=True, - backup_codes=new_backup_codes - ) - + + return BackupCodesRegenerateResponse(success=True, backup_codes=new_backup_codes) + except HTTPException: raise except Exception as e: logger.error(f"Backup code regeneration failed for user {current_user['id']}: {e}") await log_mfa_action( - db, current_user["id"], "regenerate_backup_codes", False, - client_ip, user_agent, details={"error": str(e)} + db, + current_user["id"], + "regenerate_backup_codes", + False, + client_ip, + user_agent, + details={"error": str(e)}, ) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to regenerate backup codes" + detail="Failed to regenerate backup codes", ) @@ -534,83 +612,96 @@ async def disable_mfa( request: MFADisableRequest, http_request: Request, current_user: Dict[str, Any] = Depends(get_current_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Disable MFA for user (requires password confirmation)""" client_ip = get_client_ip(http_request) user_agent = http_request.headers.get("user-agent", "") - + try: if not request.confirm_disable: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Must confirm MFA disable action" + status_code=status.HTTP_400_BAD_REQUEST, detail="Must confirm MFA disable action" ) - + # Verify password from ..auth import pwd_context - - result = db.execute(text(""" + + result = db.execute( + text( + """ SELECT hashed_password, mfa_enabled FROM users WHERE id = :user_id - """), {"user_id": current_user["id"]}) - + """ + ), + {"user_id": current_user["id"]}, + ) + user_data = result.fetchone() if not user_data: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="User not found" - ) - + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found") + if not pwd_context.verify(request.verify_password, user_data.hashed_password): await log_mfa_action( - db, current_user["id"], "disable_attempt", False, - client_ip, user_agent, details={"reason": "invalid_password"} - ) - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid password" + db, + current_user["id"], + "disable_attempt", + False, + client_ip, + user_agent, + details={"reason": "invalid_password"}, ) - + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid password") + if not user_data.mfa_enabled: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="MFA is not enabled" + status_code=status.HTTP_400_BAD_REQUEST, detail="MFA is not enabled" ) - + # Disable MFA and clear secrets - db.execute(text(""" + db.execute( + text( + """ UPDATE users SET mfa_enabled = false, mfa_secret = NULL, backup_codes = NULL, last_mfa_use = NULL WHERE id = :user_id - """), {"user_id": current_user["id"]}) - + """ + ), + {"user_id": current_user["id"]}, + ) + # Clear used codes - db.execute(text(""" + db.execute( + text( + """ DELETE FROM mfa_used_codes WHERE user_id = :user_id - """), {"user_id": current_user["id"]}) - - db.commit() - - await log_mfa_action( - db, current_user["id"], "disable", True, - client_ip, user_agent + """ + ), + {"user_id": current_user["id"]}, ) - + + db.commit() + + await log_mfa_action(db, current_user["id"], "disable", True, client_ip, user_agent) + return {"success": True, "message": "MFA disabled successfully"} - + except HTTPException: raise except Exception as e: logger.error(f"MFA disable failed for user {current_user['id']}: {e}") await log_mfa_action( - db, current_user["id"], "disable", False, - client_ip, user_agent, details={"error": str(e)} + db, + current_user["id"], + "disable", + False, + client_ip, + user_agent, + details={"error": str(e)}, ) raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to disable MFA" - ) \ No newline at end of file + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to disable MFA" + ) diff --git a/backend/app/routes/monitoring.py b/backend/app/routes/monitoring.py index 9394acc7..4ad9acd4 100644 --- a/backend/app/routes/monitoring.py +++ b/backend/app/routes/monitoring.py @@ -1,6 +1,7 @@ """ Host Monitoring API Routes """ + from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks from sqlalchemy.orm import Session from typing import List, Optional @@ -17,65 +18,70 @@ from pydantic import BaseModel + class HostCheckRequest(BaseModel): host_id: str + @router.post("/hosts/check") async def check_host_status( request: HostCheckRequest, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """ Check status of a specific host """ try: from sqlalchemy import text - + # Get host details - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT id, hostname, ip_address, port, username, auth_method FROM hosts WHERE id = :id - """), {"id": request.host_id}) - + """ + ), + {"id": request.host_id}, + ) + host_row = result.fetchone() if not host_row: raise HTTPException(status_code=404, detail="Host not found") - + host_data = { - 'id': str(host_row.id), - 'hostname': host_row.hostname, - 'ip_address': str(host_row.ip_address), - 'port': host_row.port or 22, - 'username': host_row.username, - 'auth_method': host_row.auth_method, + "id": str(host_row.id), + "hostname": host_row.hostname, + "ip_address": str(host_row.ip_address), + "port": host_row.port or 22, + "username": host_row.username, + "auth_method": host_row.auth_method, # NOTE: encrypted_credentials removed - using centralized auth service } - + # Perform comprehensive check with DB connection for credential access check_result = await host_monitor.comprehensive_host_check(host_data, db) - + # Update database with new status - await host_monitor.update_host_status( - db, request.host_id, check_result['status'] - ) - + await host_monitor.update_host_status(db, request.host_id, check_result["status"]) + return { "host_id": request.host_id, - "status": check_result['status'], - "ping_success": check_result['ping_success'], - "port_open": check_result['port_open'], - "ssh_accessible": check_result['ssh_accessible'], - "response_time_ms": check_result['response_time_ms'], - "error_message": check_result['error_message'], - "timestamp": check_result['timestamp'], + "status": check_result["status"], + "ping_success": check_result["ping_success"], + "port_open": check_result["port_open"], + "ssh_accessible": check_result["ssh_accessible"], + "response_time_ms": check_result["response_time_ms"], + "error_message": check_result["error_message"], + "timestamp": check_result["timestamp"], # SSH credential information - "ssh_credentials_used": check_result.get('ssh_credentials_source'), - "ssh_username": check_result.get('ssh_username'), - "ready_for_scans": check_result['ssh_accessible'], - "credential_details": check_result.get('credential_details') + "ssh_credentials_used": check_result.get("ssh_credentials_source"), + "ssh_username": check_result.get("ssh_username"), + "ready_for_scans": check_result["ssh_accessible"], + "credential_details": check_result.get("credential_details"), } - + except HTTPException: raise except Exception as e: @@ -87,7 +93,7 @@ async def check_host_status( async def check_all_hosts_status( background_tasks: BackgroundTasks, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """ Check status of all hosts (runs in background) @@ -95,12 +101,9 @@ async def check_all_hosts_status( try: # Run monitoring in background background_tasks.add_task(host_monitor.monitor_all_hosts, db) - - return { - "message": "Host monitoring started in background", - "status": "running" - } - + + return {"message": "Host monitoring started in background", "status": "running"} + except Exception as e: logger.error(f"Error starting host monitoring: {e}") raise HTTPException(status_code=500, detail="Failed to start host monitoring") @@ -108,36 +111,41 @@ async def check_all_hosts_status( @router.get("/hosts/status") async def get_hosts_status_summary( - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """ Get summary of all host statuses """ try: from sqlalchemy import text - - result = db.execute(text(""" + + result = db.execute( + text( + """ SELECT status, COUNT(*) as count FROM hosts WHERE is_active = true GROUP BY status - """)) - + """ + ) + ) + status_counts = {} total = 0 for row in result: status_counts[row.status] = row.count total += row.count - + return { "total_hosts": total, "status_breakdown": status_counts, - "online_percentage": round((status_counts.get('online', 0) / total * 100) if total > 0 else 0, 1) + "online_percentage": round( + (status_counts.get("online", 0) / total * 100) if total > 0 else 0, 1 + ), } - + except Exception as e: logger.error(f"Error getting host status summary: {e}") raise HTTPException(status_code=500, detail="Failed to get host status summary") @@ -145,39 +153,42 @@ async def get_hosts_status_summary( @router.post("/hosts/{host_id}/ping") async def ping_host( - host_id: str, - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + host_id: str, db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """ Simple ping test for a specific host """ try: from sqlalchemy import text - + # Get host IP - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT ip_address FROM hosts WHERE id = :id - """), {"id": host_id}) - + """ + ), + {"id": host_id}, + ) + host_row = result.fetchone() if not host_row: raise HTTPException(status_code=404, detail="Host not found") - + ip_address = str(host_row.ip_address) - + # Perform ping ping_success = await host_monitor.ping_host(ip_address) - + return { "host_id": host_id, "ip_address": ip_address, "ping_success": ping_success, - "timestamp": host_monitor.__class__.__module__ + "timestamp": host_monitor.__class__.__module__, } - + except HTTPException: raise except Exception as e: logger.error(f"Error pinging host: {e}") - raise HTTPException(status_code=500, detail="Failed to ping host") \ No newline at end of file + raise HTTPException(status_code=500, detail="Failed to ping host") diff --git a/backend/app/routes/remediation_callback.py b/backend/app/routes/remediation_callback.py index 3315767b..8e393cee 100644 --- a/backend/app/routes/remediation_callback.py +++ b/backend/app/routes/remediation_callback.py @@ -2,6 +2,7 @@ AEGIS Remediation Callback Routes Handles remediation completion notifications from AEGIS """ + from fastapi import APIRouter, HTTPException, Header, Request, status, Depends from pydantic import BaseModel, Field, UUID4 from typing import List, Optional, Dict @@ -26,8 +27,8 @@ class RemediationResult(BaseModel): rule_name: str status: str = Field(..., pattern="^(success|failed|skipped)$") error_message: Optional[str] = None - - + + class RemediationCallbackRequest(BaseModel): remediation_job_id: UUID4 scan_id: UUID4 @@ -41,7 +42,7 @@ class RemediationCallbackRequest(BaseModel): results: List[RemediationResult] started_at: datetime completed_at: datetime - + @router.post("/api/v1/webhooks/remediation-complete") async def handle_remediation_callback( @@ -49,62 +50,56 @@ async def handle_remediation_callback( callback: RemediationCallbackRequest, x_openwatch_signature: Optional[str] = Header(None), x_hub_signature_256: Optional[str] = Header(None), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Handle remediation completion callback from AEGIS""" - + # Verify webhook signature signature = x_openwatch_signature or x_hub_signature_256 if not signature: logger.warning("Remediation callback received without signature") raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Missing webhook signature" + status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing webhook signature" ) - + # Get webhook secret webhook_secret = settings.aegis_webhook_secret or "shared_webhook_secret" - + # Get raw body for signature verification body = await request.body() - + if not verify_webhook_signature(body.decode(), webhook_secret, signature): logger.error("Invalid webhook signature for remediation callback") raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid webhook signature" + status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid webhook signature" ) - + try: # Find the original scan - scan = db.query(Scan).filter( - Scan.id == str(callback.scan_id) - ).first() - + scan = db.query(Scan).filter(Scan.id == str(callback.scan_id)).first() + if not scan: logger.error(f"Scan not found for remediation callback: {callback.scan_id}") raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Scan {callback.scan_id} not found" + status_code=status.HTTP_404_NOT_FOUND, detail=f"Scan {callback.scan_id} not found" ) - + # Verify host matches if str(scan.host_id) != str(callback.openwatch_host_id): - logger.error(f"Host mismatch in remediation callback: {scan.host_id} != {callback.openwatch_host_id}") - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Host ID mismatch" + logger.error( + f"Host mismatch in remediation callback: {scan.host_id} != {callback.openwatch_host_id}" ) - + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Host ID mismatch") + # Update scan with remediation information scan.aegis_remediation_id = str(callback.remediation_job_id) scan.remediation_status = callback.status scan.remediation_completed_at = callback.completed_at - + # Store remediation results in scan metadata if not scan.metadata: scan.metadata = {} - + scan.metadata["remediation"] = { "job_id": str(callback.remediation_job_id), "status": callback.status, @@ -114,11 +109,11 @@ async def handle_remediation_callback( "skipped_rules": callback.skipped_rules, "started_at": callback.started_at.isoformat(), "completed_at": callback.completed_at.isoformat(), - "results": [r.dict() for r in callback.results] + "results": [r.dict() for r in callback.results], } - + db.commit() - + # Log audit event await log_audit_event( db=db, @@ -130,30 +125,30 @@ async def handle_remediation_callback( "remediation_job_id": str(callback.remediation_job_id), "status": callback.status, "successful_rules": callback.successful_rules, - "failed_rules": callback.failed_rules + "failed_rules": callback.failed_rules, }, - ip_address="127.0.0.1" # Internal system + ip_address="127.0.0.1", # Internal system ) - + logger.info(f"Remediation callback processed for scan {scan.id}: {callback.status}") - + # Check if we should trigger a verification scan if callback.status == "completed" and callback.successful_rules > 0: # TODO: Trigger verification scan logger.info(f"Verification scan should be triggered for host {scan.host_id}") - + return { "status": "success", "message": "Remediation callback processed successfully", "scan_id": str(scan.id), - "verification_scan_needed": callback.successful_rules > 0 + "verification_scan_needed": callback.successful_rules > 0, } - + except HTTPException: raise except Exception as e: logger.error(f"Error processing remediation callback: {e}", exc_info=True) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to process remediation callback" - ) \ No newline at end of file + detail="Failed to process remediation callback", + ) diff --git a/backend/app/routes/rule_scanning.py b/backend/app/routes/rule_scanning.py index e8110e55..289882f0 100644 --- a/backend/app/routes/rule_scanning.py +++ b/backend/app/routes/rule_scanning.py @@ -2,6 +2,7 @@ Rule-Specific Scanning API Routes Handles targeted scanning of specific SCAP rules """ + import logging from typing import List, Optional from fastapi import APIRouter, Depends, HTTPException, status @@ -27,6 +28,7 @@ class RuleScanRequest(BaseModel): """Request model for rule-specific scanning""" + host_id: str content_id: int profile_id: str @@ -36,6 +38,7 @@ class RuleScanRequest(BaseModel): class RemediationVerificationRequest(BaseModel): """Request model for remediation verification""" + host_id: str content_id: int aegis_remediation_id: str @@ -47,41 +50,48 @@ class RemediationVerificationRequest(BaseModel): async def scan_specific_rules( request: RuleScanRequest, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """Scan specific SCAP rules on a host""" try: - logger.info(f"Rule-specific scan requested by {current_user['username']} for {len(request.rule_ids)} rules") - + logger.info( + f"Rule-specific scan requested by {current_user['username']} for {len(request.rule_ids)} rules" + ) + # Get SCAP content file path - content_result = db.execute(text(""" + content_result = db.execute( + text( + """ SELECT file_path FROM scap_content WHERE id = :id - """), {"id": request.content_id}).fetchone() - + """ + ), + {"id": request.content_id}, + ).fetchone() + if not content_result: raise HTTPException(status_code=404, detail="SCAP content not found") - + # Perform rule-specific scan scan_results = await rule_scanner.scan_specific_rules( host_id=request.host_id, content_path=content_result.file_path, profile_id=request.profile_id, rule_ids=request.rule_ids, - connection_params=request.connection_params + connection_params=request.connection_params, ) - + # Store scan results in database await _store_rule_scan_results(db, scan_results) - + # Generate remediation recommendations failed_rules = [ {"rule_id": rule["rule_id"], "severity": rule["severity"]} for rule in scan_results["rule_results"] if rule["result"] == "fail" ] - + remediation_priorities = framework_mapper.get_remediation_priorities(failed_rules) - + return { "scan_results": scan_results, "remediation_recommendations": remediation_priorities[:10], # Top 10 priorities @@ -90,10 +100,12 @@ async def scan_specific_rules( "passed": scan_results["passed_rules"], "failed": scan_results["failed_rules"], "compliance_score": scan_results.get("compliance_score", 0), - "automated_remediation_available": sum(1 for r in remediation_priorities if r["automated_remediation"]) - } + "automated_remediation_available": sum( + 1 for r in remediation_priorities if r["automated_remediation"] + ), + }, } - + except Exception as e: logger.error(f"Error in rule-specific scanning: {e}") raise HTTPException(status_code=500, detail=f"Scan failed: {str(e)}") @@ -105,36 +117,41 @@ async def rescan_failed_rules( content_id: int, connection_params: Optional[dict] = None, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """Re-scan only failed rules from a previous scan""" try: # Get SCAP content file path - content_result = db.execute(text(""" + content_result = db.execute( + text( + """ SELECT file_path FROM scap_content WHERE id = :id - """), {"id": content_id}).fetchone() - + """ + ), + {"id": content_id}, + ).fetchone() + if not content_result: raise HTTPException(status_code=404, detail="SCAP content not found") - + # Perform failed rule re-scan scan_results = await rule_scanner.scan_failed_rules_from_previous_scan( previous_scan_id=previous_scan_id, content_path=content_result.file_path, - connection_params=connection_params + connection_params=connection_params, ) - + if "message" in scan_results: return scan_results # No failed rules to re-scan - + # Store results await _store_rule_scan_results(db, scan_results) - + return { "scan_results": scan_results, - "improvement_analysis": _analyze_improvement(previous_scan_id, scan_results) + "improvement_analysis": _analyze_improvement(previous_scan_id, scan_results), } - + except Exception as e: logger.error(f"Error re-scanning failed rules: {e}") raise HTTPException(status_code=500, detail=f"Re-scan failed: {str(e)}") @@ -144,32 +161,37 @@ async def rescan_failed_rules( async def verify_remediation( request: RemediationVerificationRequest, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """Verify specific rules after AEGIS remediation""" try: # Get SCAP content file path - content_result = db.execute(text(""" + content_result = db.execute( + text( + """ SELECT file_path FROM scap_content WHERE id = :id - """), {"id": request.content_id}).fetchone() - + """ + ), + {"id": request.content_id}, + ).fetchone() + if not content_result: raise HTTPException(status_code=404, detail="SCAP content not found") - + # Perform remediation verification verification_report = await rule_scanner.verify_remediation( host_id=request.host_id, content_path=content_result.file_path, aegis_remediation_id=request.aegis_remediation_id, remediated_rules=request.remediated_rules, - connection_params=request.connection_params + connection_params=request.connection_params, ) - + # Update remediation plan status if exists await _update_remediation_plan_status(db, request.aegis_remediation_id, verification_report) - + return verification_report - + except Exception as e: logger.error(f"Error verifying remediation: {e}") raise HTTPException(status_code=500, detail=f"Verification failed: {str(e)}") @@ -181,7 +203,7 @@ async def get_rule_scan_history( host_id: Optional[str] = None, limit: int = 10, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """Get scan history for a specific rule""" try: @@ -192,16 +214,16 @@ async def get_rule_scan_history( WHERE rule_id = :rule_id """ params = {"rule_id": rule_id} - + if host_id: query += " AND host_id = :host_id" params["host_id"] = host_id - + query += " ORDER BY scan_timestamp DESC LIMIT :limit" params["limit"] = limit - + db_results = db.execute(text(query), params).fetchall() - + history = [ { "scan_id": row.scan_id, @@ -209,46 +231,41 @@ async def get_rule_scan_history( "result": row.result, "severity": row.severity, "timestamp": row.scan_timestamp.isoformat(), - "duration_ms": row.duration_ms + "duration_ms": row.duration_ms, } for row in db_results ] - + # Get additional history from files if needed if len(history) < limit: - file_history = await rule_scanner.get_rule_scan_history(rule_id, host_id, limit - len(history)) + file_history = await rule_scanner.get_rule_scan_history( + rule_id, host_id, limit - len(history) + ) history.extend(file_history) - + # Get remediation guidance guidance = rule_scanner.get_rule_remediation_guidance(rule_id) - - return { - "rule_id": rule_id, - "scan_history": history, - "remediation_guidance": guidance - } - + + return {"rule_id": rule_id, "scan_history": history, "remediation_guidance": guidance} + except Exception as e: logger.error(f"Error getting rule scan history: {e}") raise HTTPException(status_code=500, detail="Failed to get rule history") @router.get("/rule/{rule_id}/compliance-info") -async def get_rule_compliance_info( - rule_id: str, - current_user: dict = Depends(get_current_user) -): +async def get_rule_compliance_info(rule_id: str, current_user: dict = Depends(get_current_user)): """Get compliance framework information for a specific rule""" try: # Get unified control information control = framework_mapper.get_unified_control(rule_id) - + if not control: raise HTTPException(status_code=404, detail="Rule not found in framework mappings") - + # Get AEGIS mapping if available aegis_mapping = aegis_mapper.get_aegis_mapping(rule_id) - + return { "rule_id": rule_id, "title": control.title, @@ -262,7 +279,7 @@ async def get_rule_compliance_info( "severity": mapping.severity, "maturity_level": mapping.maturity_level, "implementation_guidance": mapping.implementation_guidance, - "assessment_objectives": mapping.assessment_objectives + "assessment_objectives": mapping.assessment_objectives, } for mapping in control.frameworks ], @@ -271,12 +288,12 @@ async def get_rule_compliance_info( "aegis_rule_id": control.aegis_rule_id, "estimated_duration": aegis_mapping.estimated_duration if aegis_mapping else None, "requires_reboot": aegis_mapping.requires_reboot if aegis_mapping else False, - "category": aegis_mapping.rule_category if aegis_mapping else None + "category": aegis_mapping.rule_category if aegis_mapping else None, }, "tags": control.tags, - "categories": control.categories + "categories": control.categories, } - + except HTTPException: raise except Exception as e: @@ -290,58 +307,65 @@ async def create_remediation_plan( host_id: str, platform: str = "rhel9", db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """Create remediation plan for failed rules from a scan""" try: # Get failed rules from scan failed_rules = [] - + # First try to get from rule_scan_history - history_results = db.execute(text(""" + history_results = db.execute( + text( + """ SELECT rule_id, severity FROM rule_scan_history WHERE scan_id = :scan_id AND result = 'fail' - """), {"scan_id": scan_id}).fetchall() - + """ + ), + {"scan_id": scan_id}, + ).fetchall() + if history_results: failed_rules = [ - {"rule_id": row.rule_id, "severity": row.severity} - for row in history_results + {"rule_id": row.rule_id, "severity": row.severity} for row in history_results ] else: # Fallback to getting from scan results table if exists - scan_result = db.execute(text(""" + scan_result = db.execute( + text( + """ SELECT sr.rule_details FROM scan_results sr JOIN scans s ON sr.scan_id = s.id WHERE s.id = :scan_id OR CAST(s.id AS TEXT) = :scan_id - """), {"scan_id": scan_id}).fetchone() - + """ + ), + {"scan_id": scan_id}, + ).fetchone() + if scan_result and scan_result.rule_details: import json + rule_details = json.loads(scan_result.rule_details) failed_rules = [ {"rule_id": rule["rule_id"], "severity": rule.get("severity", "medium")} for rule in rule_details if rule.get("result") == "fail" ] - + if not failed_rules: raise HTTPException(status_code=404, detail="No failed rules found for scan") - + # Create remediation plan plan = aegis_mapper.create_remediation_plan( - scan_id=scan_id, - host_id=host_id, - failed_rules=failed_rules, - platform=platform + scan_id=scan_id, host_id=host_id, failed_rules=failed_rules, platform=platform ) - + # Store plan in database await _store_remediation_plan(db, plan, current_user["id"]) - + # Generate AEGIS job request aegis_job_request = aegis_mapper.generate_aegis_job_request(plan) - + return { "plan": { "plan_id": plan.plan_id, @@ -352,12 +376,12 @@ async def create_remediation_plan( "dependencies_resolved": plan.dependencies_resolved, "rule_groups": { category: len(rules) for category, rules in plan.rule_groups.items() - } + }, }, "aegis_job_request": aegis_job_request, - "execution_ready": plan.dependencies_resolved + "execution_ready": plan.dependencies_resolved, } - + except HTTPException: raise except Exception as e: @@ -365,11 +389,13 @@ async def create_remediation_plan( raise HTTPException(status_code=500, detail="Failed to create remediation plan") -async def _store_rule_scan_results(db: Session, scan_results: dict): +def _store_rule_scan_results(db: Session, scan_results: dict): """Store rule scan results in database""" try: for rule_result in scan_results.get("rule_results", []): - db.execute(text(""" + db.execute( + text( + """ INSERT INTO rule_scan_history ( id, scan_id, host_id, rule_id, profile_id, result, severity, scan_output, compliance_frameworks, automated_remediation_available, @@ -379,34 +405,43 @@ async def _store_rule_scan_results(db: Session, scan_results: dict): :scan_output, :compliance_frameworks, :automated_remediation_available, :aegis_rule_id, NOW(), :duration_ms ) - """), { - "scan_id": scan_results["scan_id"], - "host_id": scan_results["host_id"], - "rule_id": rule_result["rule_id"], - "profile_id": scan_results.get("profile_id", ""), - "result": rule_result["result"], - "severity": rule_result.get("severity", "unknown"), - "scan_output": rule_result.get("scan_output", ""), - "compliance_frameworks": json.dumps(rule_result.get("compliance_frameworks", [])), - "automated_remediation_available": rule_result.get("automated_remediation_available", False), - "aegis_rule_id": rule_result.get("aegis_rule_id"), - "duration_ms": scan_results.get("duration_seconds", 0) * 1000 - }) - + """ + ), + { + "scan_id": scan_results["scan_id"], + "host_id": scan_results["host_id"], + "rule_id": rule_result["rule_id"], + "profile_id": scan_results.get("profile_id", ""), + "result": rule_result["result"], + "severity": rule_result.get("severity", "unknown"), + "scan_output": rule_result.get("scan_output", ""), + "compliance_frameworks": json.dumps( + rule_result.get("compliance_frameworks", []) + ), + "automated_remediation_available": rule_result.get( + "automated_remediation_available", False + ), + "aegis_rule_id": rule_result.get("aegis_rule_id"), + "duration_ms": scan_results.get("duration_seconds", 0) * 1000, + }, + ) + db.commit() logger.info(f"Stored {len(scan_results.get('rule_results', []))} rule scan results") - + except Exception as e: logger.error(f"Error storing rule scan results: {e}") db.rollback() -async def _store_remediation_plan(db: Session, plan, created_by: int): +def _store_remediation_plan(db: Session, plan, created_by: int): """Store remediation plan in database""" try: import json - - db.execute(text(""" + + db.execute( + text( + """ INSERT INTO remediation_plans ( id, plan_id, scan_id, host_id, total_rules, remediable_rules, remediated_rules, estimated_duration, requires_reboot, status, execution_order, rule_groups, @@ -416,57 +451,69 @@ async def _store_remediation_plan(db: Session, plan, created_by: int): :estimated_duration, :requires_reboot, 'pending', :execution_order, :rule_groups, :created_by, NOW() ) - """), { - "plan_id": plan.plan_id, - "scan_id": plan.scan_id, - "host_id": plan.host_id, - "total_rules": plan.total_rules, - "remediable_rules": plan.remediable_rules, - "estimated_duration": plan.estimated_duration, - "requires_reboot": plan.requires_reboot, - "execution_order": json.dumps(plan.execution_order), - "rule_groups": json.dumps({ - category: [mapping.scap_rule_id for mapping in mappings] - for category, mappings in plan.rule_groups.items() - }), - "created_by": created_by - }) - + """ + ), + { + "plan_id": plan.plan_id, + "scan_id": plan.scan_id, + "host_id": plan.host_id, + "total_rules": plan.total_rules, + "remediable_rules": plan.remediable_rules, + "estimated_duration": plan.estimated_duration, + "requires_reboot": plan.requires_reboot, + "execution_order": json.dumps(plan.execution_order), + "rule_groups": json.dumps( + { + category: [mapping.scap_rule_id for mapping in mappings] + for category, mappings in plan.rule_groups.items() + } + ), + "created_by": created_by, + }, + ) + db.commit() logger.info(f"Stored remediation plan: {plan.plan_id}") - + except Exception as e: logger.error(f"Error storing remediation plan: {e}") db.rollback() -async def _update_remediation_plan_status(db: Session, aegis_remediation_id: str, verification_report: dict): +def _update_remediation_plan_status( + db: Session, aegis_remediation_id: str, verification_report: dict +): """Update remediation plan status after verification""" try: # Determine status based on verification results success_rate = verification_report.get("remediation_success_rate", 0) - + if success_rate >= 100: status = "completed" elif success_rate >= 50: status = "partial" else: status = "failed" - - db.execute(text(""" + + db.execute( + text( + """ UPDATE remediation_plans SET status = :status, remediated_rules = :remediated_rules, completed_at = NOW() WHERE aegis_job_id = :aegis_job_id - """), { - "status": status, - "remediated_rules": verification_report.get("successfully_remediated", 0), - "aegis_job_id": aegis_remediation_id - }) - + """ + ), + { + "status": status, + "remediated_rules": verification_report.get("successfully_remediated", 0), + "aegis_job_id": aegis_remediation_id, + }, + ) + db.commit() - + except Exception as e: logger.error(f"Error updating remediation plan status: {e}") db.rollback() @@ -481,5 +528,5 @@ def _analyze_improvement(previous_scan_id: str, current_results: dict) -> dict: "current_compliance_score": current_results.get("compliance_score", 0), "rules_improved": 0, # Would calculate from comparison "rules_regressed": 0, - "net_improvement": True - } \ No newline at end of file + "net_improvement": True, + } diff --git a/backend/app/routes/scan_templates.py b/backend/app/routes/scan_templates.py index 691a85f4..7d230aa5 100644 --- a/backend/app/routes/scan_templates.py +++ b/backend/app/routes/scan_templates.py @@ -1,6 +1,7 @@ """ Scan Template Routes - Quick Scan Configuration """ + from fastapi import APIRouter, HTTPException, Depends, status from pydantic import BaseModel from typing import List, Optional @@ -28,8 +29,7 @@ class ScanTemplate(BaseModel): @router.get("/") async def list_scan_templates( - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """List available scan templates""" # For now, return predefined templates @@ -44,40 +44,38 @@ async def list_scan_templates( "scope": "system", "isDefault": True, "estimatedDuration": "5-10 min", - "ruleCount": 120 + "ruleCount": 120, }, { "id": "security-audit", - "name": "Security Audit", + "name": "Security Audit", "description": "Comprehensive security configuration review", "contentId": 1, "profileId": "xccdf_org.ssgproject.content_profile_stig", "scope": "system", "isDefault": False, "estimatedDuration": "15-25 min", - "ruleCount": 340 + "ruleCount": 340, }, { "id": "vulnerability-scan", "name": "Vulnerability Check", - "description": "Scan for known security vulnerabilities", + "description": "Scan for known security vulnerabilities", "contentId": 1, "profileId": "xccdf_org.ssgproject.content_profile_cis", "scope": "system", "isDefault": False, "estimatedDuration": "10-15 min", - "ruleCount": 200 - } + "ruleCount": 200, + }, ] - + return {"templates": templates} @router.get("/host/{host_id}") async def get_host_scan_templates( - host_id: str, - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + host_id: str, db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """Get scan templates available for a specific host""" # For now, return the same system templates @@ -90,30 +88,25 @@ async def get_host_scan_templates( async def create_scan_template( template: ScanTemplate, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """Create a new scan template""" # Basic validation if not template.name or not template.profileId: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Template name and profile ID are required" + detail="Template name and profile ID are required", ) - + # In full implementation, would save to database # For now, just return success - return { - "message": "Scan template created successfully", - "template_id": template.id - } + return {"message": "Scan template created successfully", "template_id": template.id} @router.delete("/{template_id}") async def delete_scan_template( - template_id: str, - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + template_id: str, db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """Delete a scan template""" # In full implementation, would delete from database - return {"message": f"Scan template {template_id} deleted successfully"} \ No newline at end of file + return {"message": f"Scan template {template_id} deleted successfully"} diff --git a/backend/app/routes/scans.py b/backend/app/routes/scans.py index 89a70c22..ce86dd3b 100644 --- a/backend/app/routes/scans.py +++ b/backend/app/routes/scans.py @@ -2,6 +2,7 @@ SCAP Scanning API Routes Handles scan job creation, monitoring, and results """ + import uuid import json import asyncio @@ -33,41 +34,42 @@ error_service = ErrorClassificationService() sanitization_service = get_error_sanitization_service() + def sanitize_http_error( - request: Request, - current_user: dict, - exception: Exception, + request: Request, + current_user: dict, + exception: Exception, fallback_message: str, - status_code: int = 500 + status_code: int = 500, ) -> HTTPException: """Helper to sanitize HTTP errors and prevent information disclosure""" try: # Get client information client_ip = request.client.host if request.client else "unknown" user_id = current_user.get("sub") if current_user else None - + # Classify the error internally classified_error = asyncio.create_task( error_service.classify_error(exception, {"http_endpoint": str(request.url.path)}) ) - + # For synchronous context, use a generic approach sanitized_error = sanitization_service.sanitize_error( { - 'error_code': 'HTTP_ERROR', - 'category': 'execution', - 'severity': 'error', - 'message': str(exception), - 'technical_details': {'original_error': str(exception)}, - 'user_guidance': fallback_message, - 'can_retry': False + "error_code": "HTTP_ERROR", + "category": "execution", + "severity": "error", + "message": str(exception), + "technical_details": {"original_error": str(exception)}, + "user_guidance": fallback_message, + "can_retry": False, }, user_id=user_id, - source_ip=client_ip + source_ip=client_ip, ) - + return HTTPException(status_code=status_code, detail=sanitized_error.message) - + except Exception as sanitization_error: # Fallback if sanitization fails logger.error(f"Error sanitization failed: {sanitization_error}") @@ -150,29 +152,39 @@ async def validate_scan_configuration( validation_request: ValidationRequest, request: Request, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ) -> ValidationResultResponse: """Pre-flight validation for scan configuration""" try: logger.info(f"Pre-flight validation requested for host {validation_request.host_id}") - + # Get host details - host_result = db.execute(text(""" + host_result = db.execute( + text( + """ SELECT id, display_name, hostname, port, username, auth_method, encrypted_credentials FROM hosts WHERE id = :id AND is_active = true - """), {"id": validation_request.host_id}).fetchone() - + """ + ), + {"id": validation_request.host_id}, + ).fetchone() + if not host_result: raise HTTPException(status_code=404, detail="Host not found or inactive") - + # Get SCAP content details - content_result = db.execute(text(""" + content_result = db.execute( + text( + """ SELECT id, name, file_path, profiles FROM scap_content WHERE id = :id - """), {"id": validation_request.content_id}).fetchone() - + """ + ), + {"id": validation_request.content_id}, + ).fetchone() + if not content_result: raise HTTPException(status_code=404, detail="SCAP content not found") - + # Validate profile exists if content_result.profiles: try: @@ -182,23 +194,23 @@ async def validate_scan_configuration( raise HTTPException(status_code=400, detail="Profile not found in SCAP content") except: raise HTTPException(status_code=400, detail="Invalid SCAP content profiles") - + # Resolve credentials try: from ..services.auth_service import get_auth_service + auth_service = get_auth_service(db) - + use_default = host_result.auth_method in ["default", "system_default"] target_id = None if use_default else host_result.id - + credential_data = auth_service.resolve_credential( - target_id=target_id, - use_default=use_default + target_id=target_id, use_default=use_default ) - + if not credential_data: raise HTTPException(status_code=400, detail="No credentials available for host") - + # Extract credential value based on auth method if credential_data.auth_method.value == "password": credential_value = credential_data.password @@ -206,15 +218,17 @@ async def validate_scan_configuration( credential_value = credential_data.private_key else: credential_value = credential_data.password or "" - + except Exception as e: logger.error(f"Credential resolution failed for validation: {e}") raise sanitize_http_error( - request, current_user, e, - "Unable to resolve authentication credentials for target host", - 400 + request, + current_user, + e, + "Unable to resolve authentication credentials for target host", + 400, ) - + # Get client information for security audit client_ip = request.client.host if request.client else "unknown" user_id = current_user.get("sub") if current_user else None @@ -229,48 +243,50 @@ async def validate_scan_configuration( auth_method=credential_data.auth_method.value, credential=credential_value, user_id=user_id, - source_ip=client_ip + source_ip=client_ip, + ) + + logger.info( + f"Validation completed: can_proceed={internal_result.can_proceed}, " + f"errors={len(internal_result.errors)}, warnings={len(internal_result.warnings)}" ) - - logger.info(f"Validation completed: can_proceed={internal_result.can_proceed}, " - f"errors={len(internal_result.errors)}, warnings={len(internal_result.warnings)}") - + # Convert to sanitized response using Security Fix 5 system info sanitization sanitized_result = error_service.get_sanitized_validation_result( internal_result, user_id=user_id, source_ip=client_ip, user_role=user_role, - is_admin=is_admin + is_admin=is_admin, ) - + return sanitized_result - + except HTTPException: raise except Exception as e: # Log full technical details server-side logger.error(f"Validation error: {e}", exc_info=True) - + # Create sanitized error for user sanitization_service = get_error_sanitization_service() - classified_error = await error_service.classify_error(e, { - "operation": "scan_validation", - "host_id": validation_request.host_id, - "content_id": validation_request.content_id - }) - + classified_error = await error_service.classify_error( + e, + { + "operation": "scan_validation", + "host_id": validation_request.host_id, + "content_id": validation_request.content_id, + }, + ) + sanitized_error = sanitization_service.sanitize_error( classified_error.dict(), user_id=current_user.get("sub") if current_user else None, - source_ip=request.client.host if request.client else "unknown" + source_ip=request.client.host if request.client else "unknown", ) - + # Return generic error message to prevent information disclosure - raise HTTPException( - status_code=500, - detail=f"Validation failed: {sanitized_error.message}" - ) + raise HTTPException(status_code=500, detail=f"Validation failed: {sanitized_error.message}") @router.post("/hosts/{host_id}/quick-scan", response_model=QuickScanResponse) @@ -279,15 +295,17 @@ async def quick_scan( quick_scan_request: QuickScanRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ) -> QuickScanResponse: """Start scan with intelligent defaults - Zero to Scan in 3 Clicks""" try: - logger.info(f"Quick scan requested for host {host_id} with template {quick_scan_request.template_id}") - + logger.info( + f"Quick scan requested for host {host_id} with template {quick_scan_request.template_id}" + ) + # Initialize intelligence service intelligence_service = ScanIntelligenceService(db) - + # Auto-detect profile if not specified suggested_profile = None if quick_scan_request.template_id == "auto": @@ -298,27 +316,37 @@ async def quick_scan( # Use specified template - for now, map to default content template_id = quick_scan_request.template_id content_id = 1 # Default SCAP content - + # Still get suggestion for response metadata suggested_profile = await intelligence_service.suggest_scan_profile(host_id) - + # Get host details for validation - host_result = db.execute(text(""" + host_result = db.execute( + text( + """ SELECT id, display_name, hostname, port, username, auth_method, encrypted_credentials FROM hosts WHERE id = :id AND is_active = true - """), {"id": host_id}).fetchone() - + """ + ), + {"id": host_id}, + ).fetchone() + if not host_result: raise HTTPException(status_code=404, detail="Host not found or inactive") - + # Get SCAP content details - content_result = db.execute(text(""" + content_result = db.execute( + text( + """ SELECT id, name, file_path, profiles FROM scap_content WHERE id = :id - """), {"id": content_id}).fetchone() - + """ + ), + {"id": content_id}, + ).fetchone() + if not content_result: raise HTTPException(status_code=404, detail="SCAP content not found") - + # Validate profile exists in content profiles = [] if content_result.profiles: @@ -329,75 +357,85 @@ async def quick_scan( # Fall back to first available profile if profile_ids: template_id = profile_ids[0] - logger.warning(f"Requested profile not found, using fallback: {template_id}") + logger.warning( + f"Requested profile not found, using fallback: {template_id}" + ) else: - raise HTTPException(status_code=400, detail="No profiles available in SCAP content") + raise HTTPException( + status_code=400, detail="No profiles available in SCAP content" + ) except: raise HTTPException(status_code=400, detail="Invalid SCAP content profiles") - + # Generate scan name scan_name = quick_scan_request.name if not scan_name: profile_name = suggested_profile.name if suggested_profile else "Quick Scan" scan_name = f"{profile_name} - {host_result.display_name or host_result.hostname}" - + # Create scan record with UUID primary key scan_id = str(uuid.uuid4()) - + # Pre-flight validation (async, non-blocking for optimistic UI) validation_task = None try: from ..services.auth_service import get_auth_service + auth_service = get_auth_service(db) - + use_default = host_result.auth_method in ["default", "system_default"] target_id = None if use_default else host_result.id - + credential_data = auth_service.resolve_credential( - target_id=target_id, - use_default=use_default + target_id=target_id, use_default=use_default ) - + if credential_data: # Queue async validation validation_task = background_tasks.add_task( - self._async_validation_check, - scan_id, host_result, credential_data + self._async_validation_check, scan_id, host_result, credential_data ) except Exception as e: logger.warning(f"Pre-flight validation setup failed: {e}") - + # Create scan immediately (optimistic UI) - result = db.execute(text(""" + result = db.execute( + text( + """ INSERT INTO scans (id, name, host_id, content_id, profile_id, status, progress, scan_options, started_by, started_at, remediation_requested, verification_scan) VALUES (:id, :name, :host_id, :content_id, :profile_id, :status, :progress, :scan_options, :started_by, :started_at, :remediation_requested, :verification_scan) RETURNING id - """), { - "id": scan_id, - "name": scan_name, - "host_id": host_id, - "content_id": content_id, - "profile_id": template_id, - "status": "pending", - "progress": 0, - "scan_options": json.dumps({ - "quick_scan": True, - "template_id": quick_scan_request.template_id, - "priority": quick_scan_request.priority, - "email_notify": quick_scan_request.email_notify - }), - "started_by": current_user["id"], - "started_at": datetime.utcnow(), - "remediation_requested": False, - "verification_scan": False - }) - + """ + ), + { + "id": scan_id, + "name": scan_name, + "host_id": host_id, + "content_id": content_id, + "profile_id": template_id, + "status": "pending", + "progress": 0, + "scan_options": json.dumps( + { + "quick_scan": True, + "template_id": quick_scan_request.template_id, + "priority": quick_scan_request.priority, + "email_notify": quick_scan_request.email_notify, + } + ), + "started_by": current_user["id"], + "started_at": datetime.utcnow(), + "remediation_requested": False, + "verification_scan": False, + }, + ) + # Commit the scan record db.commit() - + # Start scan as background task background_tasks.add_task( execute_scan_task, @@ -407,18 +445,15 @@ async def quick_scan( "port": host_result.port, "username": host_result.username, "auth_method": host_result.auth_method, - "encrypted_credentials": host_result.encrypted_credentials + "encrypted_credentials": host_result.encrypted_credentials, }, content_path=content_result.file_path, profile_id=template_id, - scan_options={ - "quick_scan": True, - "priority": quick_scan_request.priority - } + scan_options={"quick_scan": True, "priority": quick_scan_request.priority}, ) - + logger.info(f"Quick scan created and started: {scan_id}") - + # Calculate estimated completion estimated_time = None if suggested_profile: @@ -432,12 +467,13 @@ async def quick_scan( estimated_time = datetime.utcnow().timestamp() + (avg_minutes * 60) except: pass - + return QuickScanResponse( id=scan_id, message="Scan created and started successfully", status="pending", - suggested_profile=suggested_profile or ProfileSuggestion( + suggested_profile=suggested_profile + or ProfileSuggestion( profile_id=template_id, content_id=content_id, name="Quick Scan", @@ -445,11 +481,11 @@ async def quick_scan( reasoning=["Manual template selection"], estimated_duration="10-15 min", rule_count=150, - priority=suggested_profile.priority if suggested_profile else "normal" + priority=suggested_profile.priority if suggested_profile else "normal", ), - estimated_completion=estimated_time + estimated_completion=estimated_time, ) - + except HTTPException: raise except Exception as e: @@ -458,19 +494,21 @@ async def quick_scan( try: classified_error = await error_service.classify_error(e, {"operation": "quick_scan"}) raise HTTPException( - status_code=500, + status_code=500, detail={ "message": classified_error.message, "category": classified_error.category.value, "user_guidance": classified_error.user_guidance, "can_retry": classified_error.can_retry, - "error_code": classified_error.error_code - } + "error_code": classified_error.error_code, + }, ) except Exception as fallback_error: # Fallback to generic error if classification fails logger.error(f"Quick scan creation failed with classification error: {fallback_error}") - raise HTTPException(status_code=500, detail="Failed to create scan due to system configuration error") + raise HTTPException( + status_code=500, detail="Failed to create scan due to system configuration error" + ) async def _async_validation_check(scan_id: str, host_result, credential_data): @@ -485,21 +523,21 @@ async def create_bulk_scan( bulk_scan_request: BulkScanRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ) -> BulkScanResponse: """Create and start bulk scan session for multiple hosts""" try: logger.info(f"Bulk scan requested for {len(bulk_scan_request.host_ids)} hosts") - + if not bulk_scan_request.host_ids: raise HTTPException(status_code=400, detail="No host IDs provided") - + if len(bulk_scan_request.host_ids) > 100: raise HTTPException(status_code=400, detail="Maximum 100 hosts per bulk scan") - + # Initialize orchestrator orchestrator = BulkScanOrchestrator(db) - + # Create bulk scan session session = await orchestrator.create_bulk_scan_session( host_ids=bulk_scan_request.host_ids, @@ -507,22 +545,24 @@ async def create_bulk_scan( name_prefix=bulk_scan_request.name_prefix, priority=bulk_scan_request.priority, user_id=current_user["id"], - stagger_delay=bulk_scan_request.stagger_delay + stagger_delay=bulk_scan_request.stagger_delay, ) - + # Start the bulk scan session start_result = await orchestrator.start_bulk_scan_session(session.id) - + logger.info(f"Bulk scan session created and started: {session.id}") - + return BulkScanResponse( session_id=session.id, message=f"Bulk scan session created for {session.total_hosts} hosts", total_hosts=session.total_hosts, - estimated_completion=session.estimated_completion.timestamp() if session.estimated_completion else 0, - scan_ids=session.scan_ids or [] + estimated_completion=( + session.estimated_completion.timestamp() if session.estimated_completion else 0 + ), + scan_ids=session.scan_ids or [], ) - + except HTTPException: raise except Exception as e: @@ -532,16 +572,14 @@ async def create_bulk_scan( @router.get("/bulk-scan/{session_id}/progress") async def get_bulk_scan_progress( - session_id: str, - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + session_id: str, db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """Get real-time progress of a bulk scan session""" try: orchestrator = BulkScanOrchestrator(db) progress = await orchestrator.get_bulk_scan_progress(session_id) return progress - + except ValueError as e: raise HTTPException(status_code=404, detail=str(e)) except Exception as e: @@ -551,23 +589,28 @@ async def get_bulk_scan_progress( @router.post("/bulk-scan/{session_id}/cancel") async def cancel_bulk_scan( - session_id: str, - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + session_id: str, db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """Cancel a running bulk scan session""" try: # Update session status to cancelled - result = db.execute(text(""" + result = db.execute( + text( + """ UPDATE scan_sessions SET status = 'cancelled' WHERE id = :session_id - """), {"session_id": session_id}) - + """ + ), + {"session_id": session_id}, + ) + if result.rowcount == 0: raise HTTPException(status_code=404, detail="Bulk scan session not found") - + # Cancel individual scans that are still pending - db.execute(text(""" + db.execute( + text( + """ UPDATE scans SET status = 'cancelled', error_message = 'Cancelled by user' WHERE id IN ( SELECT unnest(ARRAY( @@ -575,13 +618,16 @@ async def cancel_bulk_scan( FROM scan_sessions WHERE id = :session_id )) ) AND status IN ('pending', 'running') - """), {"session_id": session_id}) - + """ + ), + {"session_id": session_id}, + ) + db.commit() - + logger.info(f"Bulk scan session cancelled: {session_id}") return {"message": "Bulk scan session cancelled successfully"} - + except HTTPException: raise except Exception as e: @@ -595,64 +641,73 @@ async def list_scan_sessions( limit: int = 20, offset: int = 0, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """List scan sessions for monitoring and management""" try: # Build query conditions where_conditions = [] params = {"limit": limit, "offset": offset} - + if status: where_conditions.append("status = :status") params["status"] = status - + # Add user filtering if not admin if current_user.get("role") not in ["SUPER_ADMIN", "SECURITY_ADMIN"]: where_conditions.append("created_by = :user_id") params["user_id"] = current_user["id"] - + where_clause = "WHERE " + " AND ".join(where_conditions) if where_conditions else "" - + # Get sessions - result = db.execute(text(f""" + result = db.execute( + text( + f""" SELECT id, name, total_hosts, completed_hosts, failed_hosts, running_hosts, status, created_by, created_at, started_at, completed_at, estimated_completion FROM scan_sessions {where_clause} ORDER BY created_at DESC LIMIT :limit OFFSET :offset - """), params) - + """ + ), + params, + ) + sessions = [] for row in result: - sessions.append({ - "session_id": row.id, - "name": row.name, - "total_hosts": row.total_hosts, - "completed_hosts": row.completed_hosts, - "failed_hosts": row.failed_hosts, - "running_hosts": row.running_hosts, - "status": row.status, - "created_by": row.created_by, - "created_at": row.created_at.isoformat() if row.created_at else None, - "started_at": row.started_at.isoformat() if row.started_at else None, - "completed_at": row.completed_at.isoformat() if row.completed_at else None, - "estimated_completion": row.estimated_completion.isoformat() if row.estimated_completion else None - }) - + sessions.append( + { + "session_id": row.id, + "name": row.name, + "total_hosts": row.total_hosts, + "completed_hosts": row.completed_hosts, + "failed_hosts": row.failed_hosts, + "running_hosts": row.running_hosts, + "status": row.status, + "created_by": row.created_by, + "created_at": row.created_at.isoformat() if row.created_at else None, + "started_at": row.started_at.isoformat() if row.started_at else None, + "completed_at": row.completed_at.isoformat() if row.completed_at else None, + "estimated_completion": ( + row.estimated_completion.isoformat() if row.estimated_completion else None + ), + } + ) + # Get total count - count_result = db.execute(text(f""" + count_result = db.execute( + text( + f""" SELECT COUNT(*) as total FROM scan_sessions {where_clause} - """), params).fetchone() - - return { - "sessions": sessions, - "total": count_result.total, - "limit": limit, - "offset": offset - } - + """ + ), + params, + ).fetchone() + + return {"sessions": sessions, "total": count_result.total, "limit": limit, "offset": offset} + except Exception as e: logger.error(f"Error listing scan sessions: {e}") raise HTTPException(status_code=500, detail="Failed to list scan sessions") @@ -660,75 +715,82 @@ async def list_scan_sessions( @router.post("/{scan_id}/recover") async def recover_scan( - scan_id: str, - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + scan_id: str, db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """Attempt to recover a failed scan with intelligent retry""" try: # Get failed scan details - scan_result = db.execute(text(""" + scan_result = db.execute( + text( + """ SELECT s.id, s.name, s.host_id, s.content_id, s.profile_id, s.status, s.error_message, h.hostname, h.port, h.username, h.auth_method FROM scans s JOIN hosts h ON s.host_id = h.id WHERE s.id = :scan_id AND s.status = 'failed' - """), {"scan_id": scan_id}).fetchone() - + """ + ), + {"scan_id": scan_id}, + ).fetchone() + if not scan_result: raise HTTPException(status_code=404, detail="Failed scan not found") - + # Classify the original error to determine recovery strategy original_error = Exception(scan_result.error_message or "Unknown error") classified_error = await error_service.classify_error( - original_error, - {"scan_id": scan_id, "hostname": scan_result.hostname} + original_error, {"scan_id": scan_id, "hostname": scan_result.hostname} ) - + # Determine if retry is possible if not classified_error.can_retry: return { "can_recover": False, "message": "Scan cannot be automatically recovered", "error_classification": classified_error.dict(), - "recommended_actions": classified_error.user_guidance + "recommended_actions": classified_error.user_guidance, } - + # Calculate retry delay retry_delay = classified_error.retry_after or 60 - + # Create recovery scan recovery_scan_id = str(uuid.uuid4()) - db.execute(text(""" + db.execute( + text( + """ INSERT INTO scans (id, name, host_id, content_id, profile_id, status, progress, started_by, started_at, scan_options) VALUES (:id, :name, :host_id, :content_id, :profile_id, :status, :progress, :started_by, :started_at, :scan_options) - """), { - "id": recovery_scan_id, - "name": f"Recovery: {scan_result.name}", - "host_id": scan_result.host_id, - "content_id": scan_result.content_id, - "profile_id": scan_result.profile_id, - "status": "pending", - "progress": 0, - "started_by": current_user["id"], - "started_at": datetime.utcnow(), - "scan_options": json.dumps({"recovery_scan": True, "original_scan_id": scan_id}) - }) + """ + ), + { + "id": recovery_scan_id, + "name": f"Recovery: {scan_result.name}", + "host_id": scan_result.host_id, + "content_id": scan_result.content_id, + "profile_id": scan_result.profile_id, + "status": "pending", + "progress": 0, + "started_by": current_user["id"], + "started_at": datetime.utcnow(), + "scan_options": json.dumps({"recovery_scan": True, "original_scan_id": scan_id}), + }, + ) db.commit() - + logger.info(f"Recovery scan created: {recovery_scan_id} for failed scan {scan_id}") - + return { "can_recover": True, "recovery_scan_id": recovery_scan_id, "message": f"Recovery scan created and will start in {retry_delay} seconds", "error_classification": classified_error.dict(), - "estimated_retry_time": (datetime.utcnow().timestamp() + retry_delay) + "estimated_retry_time": (datetime.utcnow().timestamp() + retry_delay), } - + except HTTPException: raise except Exception as e: @@ -742,40 +804,45 @@ async def apply_automated_fix( fix_request: AutomatedFixRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """Apply an automated fix to a host""" try: # Get host details - host_result = db.execute(text(""" + host_result = db.execute( + text( + """ SELECT id, display_name, hostname, port, username, auth_method FROM hosts WHERE id = :id AND is_active = true - """), {"id": host_id}).fetchone() - + """ + ), + {"id": host_id}, + ).fetchone() + if not host_result: raise HTTPException(status_code=404, detail="Host not found or inactive") - + logger.info(f"Applying automated fix {fix_request.fix_id} to host {host_id}") - + # For now, return a mock response - in production this would execute the fix # This would integrate with the actual fix execution system - + # Mock execution time based on fix type estimated_time = 30 # Default 30 seconds if "install" in fix_request.fix_id.lower(): estimated_time = 120 elif "update" in fix_request.fix_id.lower(): estimated_time = 60 - + # Create a mock job ID for tracking job_id = str(uuid.uuid4()) - + # In production, this would: # 1. Queue the fix execution as a background task # 2. Track progress in database # 3. Execute commands via SSH # 4. Validate results if requested - + return { "job_id": job_id, "fix_id": fix_request.fix_id, @@ -783,9 +850,9 @@ async def apply_automated_fix( "status": "queued", "estimated_completion": (datetime.utcnow().timestamp() + estimated_time), "message": f"Automated fix {fix_request.fix_id} queued for execution", - "validate_after": fix_request.validate_after + "validate_after": fix_request.validate_after, } - + except HTTPException: raise except Exception as e: @@ -800,34 +867,29 @@ async def list_scans( limit: int = 50, offset: int = 0, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """List scans with optional filtering""" try: # Quick fix: Check if there are any scans at all scan_count_result = db.execute(text("SELECT COUNT(*) as count FROM scans")).fetchone() if scan_count_result.count == 0: - return { - "scans": [], - "total": 0, - "limit": limit, - "offset": offset - } - + return {"scans": [], "total": 0, "limit": limit, "offset": offset} + # Build query where_conditions = [] params = {"limit": limit, "offset": offset} - + if host_id: where_conditions.append("s.host_id = %(host_id)s") params["host_id"] = host_id - + if status: where_conditions.append("s.status = %(status)s") params["status"] = status - + where_clause = "WHERE " + " AND ".join(where_conditions) if where_conditions else "" - + query = """ SELECT s.id, s.name, s.host_id, s.content_id, s.profile_id, s.status, s.progress, s.started_at, s.completed_at, s.started_by, @@ -844,10 +906,12 @@ async def list_scans( {} ORDER BY s.started_at DESC LIMIT :limit OFFSET :offset - """.format(where_clause) - + """.format( + where_clause + ) + result = db.execute(text(query), params) - + scans = [] for row in result: scan_data = { @@ -861,7 +925,7 @@ async def list_scans( "ip_address": row.ip_address, "operating_system": row.operating_system, "status": row.host_status, - "last_check": row.last_check.isoformat() if row.last_check else None + "last_check": row.last_check.isoformat() if row.last_check else None, }, "content_id": row.content_id, "content_name": row.content_name, @@ -874,9 +938,9 @@ async def list_scans( "started_by": row.started_by, "error_message": row.error_message, "result_file": row.result_file, - "report_file": row.report_file + "report_file": row.report_file, } - + # Add results if available if row.total_rules is not None: scan_data["scan_result"] = { @@ -890,11 +954,11 @@ async def list_scans( "severity_high": row.severity_high, "severity_medium": row.severity_medium, "severity_low": row.severity_low, - "created_at": row.completed_at.isoformat() if row.completed_at else None + "created_at": row.completed_at.isoformat() if row.completed_at else None, } - + scans.append(scan_data) - + # Get total count count_query = """ SELECT COUNT(*) as total @@ -902,16 +966,13 @@ async def list_scans( LEFT JOIN hosts h ON s.host_id = h.id LEFT JOIN scap_content c ON s.content_id = c.id {} - """.format(where_clause) + """.format( + where_clause + ) total_result = db.execute(text(count_query), params).fetchone() - - return { - "scans": scans, - "total": total_result.total, - "limit": limit, - "offset": offset - } - + + return {"scans": scans, "total": total_result.total, "limit": limit, "offset": offset} + except Exception as e: logger.error(f"Error listing scans: {e}") raise HTTPException(status_code=500, detail="Failed to retrieve scans") @@ -922,67 +983,84 @@ async def create_scan( scan_request: ScanRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """Create and start a new SCAP scan""" try: # Validate host exists - host_result = db.execute(text(""" + host_result = db.execute( + text( + """ SELECT id, display_name, hostname, port, username, auth_method, encrypted_credentials FROM hosts WHERE id = :id AND is_active = true - """), {"id": scan_request.host_id}).fetchone() - + """ + ), + {"id": scan_request.host_id}, + ).fetchone() + if not host_result: raise HTTPException(status_code=404, detail="Host not found or inactive") - + # Validate SCAP content exists - content_result = db.execute(text(""" + content_result = db.execute( + text( + """ SELECT id, name, file_path, profiles FROM scap_content WHERE id = :id - """), {"id": scan_request.content_id}).fetchone() - + """ + ), + {"id": scan_request.content_id}, + ).fetchone() + if not content_result: raise HTTPException(status_code=404, detail="SCAP content not found") - + # Validate profile exists in content profiles = [] if content_result.profiles: try: import json + profiles = json.loads(content_result.profiles) profile_ids = [p.get("id") for p in profiles if p.get("id")] if scan_request.profile_id not in profile_ids: raise HTTPException(status_code=400, detail="Profile not found in SCAP content") except: raise HTTPException(status_code=400, detail="Invalid SCAP content profiles") - + # Create scan record with UUID primary key import json + scan_id = str(uuid.uuid4()) - result = db.execute(text(""" + result = db.execute( + text( + """ INSERT INTO scans (id, name, host_id, content_id, profile_id, status, progress, scan_options, started_by, started_at, remediation_requested, verification_scan) VALUES (:id, :name, :host_id, :content_id, :profile_id, :status, :progress, :scan_options, :started_by, :started_at, :remediation_requested, :verification_scan) RETURNING id - """), { - "id": scan_id, - "name": scan_request.name, - "host_id": scan_request.host_id, - "content_id": scan_request.content_id, - "profile_id": scan_request.profile_id, - "status": "pending", - "progress": 0, - "scan_options": json.dumps(scan_request.scan_options), - "started_by": current_user["id"], - "started_at": datetime.utcnow(), - "remediation_requested": False, - "verification_scan": False - }) - + """ + ), + { + "id": scan_id, + "name": scan_request.name, + "host_id": scan_request.host_id, + "content_id": scan_request.content_id, + "profile_id": scan_request.profile_id, + "status": "pending", + "progress": 0, + "scan_options": json.dumps(scan_request.scan_options), + "started_by": current_user["id"], + "started_at": datetime.utcnow(), + "remediation_requested": False, + "verification_scan": False, + }, + ) + # Commit the scan record db.commit() - + # Start scan as background task background_tasks.add_task( execute_scan_task, @@ -992,21 +1070,21 @@ async def create_scan( "port": host_result.port, "username": host_result.username, "auth_method": host_result.auth_method, - "encrypted_credentials": host_result.encrypted_credentials + "encrypted_credentials": host_result.encrypted_credentials, }, content_path=content_result.file_path, profile_id=scan_request.profile_id, - scan_options=scan_request.scan_options + scan_options=scan_request.scan_options, ) - + logger.info(f"Scan created and started: {scan_id}") - + return { "id": scan_id, "message": "Scan created and started successfully", - "status": "pending" + "status": "pending", } - + except HTTPException: raise except Exception as e: @@ -1015,14 +1093,14 @@ async def create_scan( try: classified_error = await error_service.classify_error(e, {"operation": "create_scan"}) raise HTTPException( - status_code=500, + status_code=500, detail={ "message": classified_error.message, "category": classified_error.category.value, "user_guidance": classified_error.user_guidance, "can_retry": classified_error.can_retry, - "error_code": classified_error.error_code - } + "error_code": classified_error.error_code, + }, ) except: # Fallback to generic error if classification fails @@ -1031,13 +1109,13 @@ async def create_scan( @router.get("/{scan_id}") async def get_scan( - scan_id: str, - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + scan_id: str, db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """Get scan details""" try: - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT s.id, s.name, s.host_id, s.content_id, s.profile_id, s.status, s.progress, s.result_file, s.report_file, s.error_message, s.scan_options, s.started_at, s.completed_at, s.started_by, @@ -1048,19 +1126,23 @@ async def get_scan( JOIN hosts h ON s.host_id = h.id JOIN scap_content c ON s.content_id = c.id WHERE s.id = :id - """), {"id": scan_id}).fetchone() - + """ + ), + {"id": scan_id}, + ).fetchone() + if not result: raise HTTPException(status_code=404, detail="Scan not found") - + scan_options = {} if result.scan_options: try: import json + scan_options = json.loads(result.scan_options) except: pass - + scan_data = { "id": result.id, "name": result.name, @@ -1080,18 +1162,23 @@ async def get_scan( "started_at": result.started_at.isoformat() if result.started_at else None, "completed_at": result.completed_at.isoformat() if result.completed_at else None, "started_by": result.started_by, - "celery_task_id": result.celery_task_id + "celery_task_id": result.celery_task_id, } - + # Add results summary if scan is completed if result.status == "completed": - results = db.execute(text(""" + results = db.execute( + text( + """ SELECT total_rules, passed_rules, failed_rules, error_rules, unknown_rules, not_applicable_rules, score, severity_high, severity_medium, severity_low FROM scan_results WHERE scan_id = :scan_id - """), {"scan_id": scan_id}).fetchone() - + """ + ), + {"scan_id": scan_id}, + ).fetchone() + if results: scan_data["results"] = { "total_rules": results.total_rules, @@ -1103,11 +1190,11 @@ async def get_scan( "score": results.score, "severity_high": results.severity_high, "severity_medium": results.severity_medium, - "severity_low": results.severity_low + "severity_low": results.severity_low, } - + return scan_data - + except HTTPException: raise except Exception as e: @@ -1120,45 +1207,50 @@ async def update_scan( scan_id: str, scan_update: ScanUpdate, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """Update scan status (internal use)""" try: # Check if scan exists - existing = db.execute(text(""" + existing = db.execute( + text( + """ SELECT id FROM scans WHERE id = :id - """), {"id": scan_id}).fetchone() - + """ + ), + {"id": scan_id}, + ).fetchone() + if not existing: raise HTTPException(status_code=404, detail="Scan not found") - + # Build update query updates = [] params = {"id": scan_id} - + if scan_update.status is not None: updates.append("status = :status") params["status"] = scan_update.status - + if scan_update.progress is not None: updates.append("progress = :progress") params["progress"] = scan_update.progress - + if scan_update.error_message is not None: updates.append("error_message = :error_message") params["error_message"] = scan_update.error_message - + if scan_update.status == "completed": updates.append("completed_at = :completed_at") params["completed_at"] = datetime.utcnow() - + if updates: - query = "UPDATE scans SET {} WHERE id = :id".format(', '.join(updates)) + query = "UPDATE scans SET {} WHERE id = :id".format(", ".join(updates)) db.execute(text(query), params) db.commit() - + return {"message": "Scan updated successfully"} - + except HTTPException: raise except Exception as e: @@ -1168,51 +1260,62 @@ async def update_scan( @router.delete("/{scan_id}") async def delete_scan( - scan_id: str, - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + scan_id: str, db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """Delete scan and its results""" try: # Check if scan exists and get status - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT status, result_file, report_file FROM scans WHERE id = :id - """), {"id": scan_id}).fetchone() - + """ + ), + {"id": scan_id}, + ).fetchone() + if not result: raise HTTPException(status_code=404, detail="Scan not found") - + # Don't allow deletion of running scans if result.status in ["pending", "running"]: - raise HTTPException( - status_code=409, - detail="Cannot delete running scan" - ) - + raise HTTPException(status_code=409, detail="Cannot delete running scan") + # Delete result files import os + for file_path in [result.result_file, result.report_file]: if file_path and os.path.exists(file_path): try: os.unlink(file_path) except Exception as e: logger.warning(f"Failed to delete file {file_path}: {e}") - + # Delete scan results first (foreign key constraint) - db.execute(text(""" + db.execute( + text( + """ DELETE FROM scan_results WHERE scan_id = :scan_id - """), {"scan_id": scan_id}) - + """ + ), + {"scan_id": scan_id}, + ) + # Delete scan record - db.execute(text(""" + db.execute( + text( + """ DELETE FROM scans WHERE id = :id - """), {"id": scan_id}) - + """ + ), + {"id": scan_id}, + ) + db.commit() - + logger.info(f"Scan deleted: {scan_id}") return {"message": "Scan deleted successfully"} - + except HTTPException: raise except Exception as e: @@ -1222,49 +1325,54 @@ async def delete_scan( @router.post("/{scan_id}/stop") async def stop_scan( - scan_id: str, - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + scan_id: str, db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """Stop a running scan""" try: # Check if scan exists and is running - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT status, celery_task_id FROM scans WHERE id = :id - """), {"id": scan_id}).fetchone() - + """ + ), + {"id": scan_id}, + ).fetchone() + if not result: raise HTTPException(status_code=404, detail="Scan not found") - + if result.status not in ["pending", "running"]: raise HTTPException( - status_code=400, - detail=f"Cannot stop scan with status: {result.status}" + status_code=400, detail=f"Cannot stop scan with status: {result.status}" ) - + # Try to revoke Celery task if available if result.celery_task_id: try: from celery import current_app + current_app.control.revoke(result.celery_task_id, terminate=True) except Exception as e: logger.warning(f"Failed to revoke Celery task: {e}") - + # Update scan status - db.execute(text(""" + db.execute( + text( + """ UPDATE scans SET status = 'stopped', completed_at = :completed_at, error_message = 'Scan stopped by user' WHERE id = :id - """), { - "id": scan_id, - "completed_at": datetime.utcnow() - }) + """ + ), + {"id": scan_id, "completed_at": datetime.utcnow()}, + ) db.commit() - + logger.info(f"Scan stopped: {scan_id}") return {"message": "Scan stopped successfully"} - + except HTTPException: raise except Exception as e: @@ -1274,33 +1382,36 @@ async def stop_scan( @router.get("/{scan_id}/report/html") async def get_scan_html_report( - scan_id: str, - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + scan_id: str, db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """Download scan HTML report""" try: # Get scan details - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT report_file FROM scans WHERE id = :id - """), {"id": scan_id}).fetchone() - + """ + ), + {"id": scan_id}, + ).fetchone() + if not result or not result.report_file: raise HTTPException(status_code=404, detail="Report not found") - + # Check if file exists import os + if not os.path.exists(result.report_file): raise HTTPException(status_code=404, detail="Report file not found") - + # Return file from fastapi.responses import FileResponse + return FileResponse( - result.report_file, - media_type="text/html", - filename=f"scan_{scan_id}_report.html" + result.report_file, media_type="text/html", filename=f"scan_{scan_id}_report.html" ) - + except HTTPException: raise except Exception as e: @@ -1310,82 +1421,92 @@ async def get_scan_html_report( @router.get("/{scan_id}/report/json") async def get_scan_json_report( - scan_id: str, - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + scan_id: str, db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """Export scan results as JSON""" try: # Get full scan details with results scan_data = await get_scan(scan_id, db, current_user) - + # Add enhanced rule results with remediation if available if scan_data.get("status") == "completed" and scan_data.get("result_file"): try: # Get the SCAP content file path for remediation extraction content_file = None - content_result = db.execute(text(""" + content_result = db.execute( + text( + """ SELECT file_path FROM scap_content WHERE id = :content_id - """), {"content_id": scan_data.get("content_id")}).fetchone() - + """ + ), + {"content_id": scan_data.get("content_id")}, + ).fetchone() + if content_result: content_file = content_result.file_path - + # Temporarily disable enhanced parsing for performance (was taking 40+ seconds) # TODO: Implement caching or optimize the parsing logic enhanced_parsing_enabled = False - + if enhanced_parsing_enabled: # Use enhanced SCAP scanner parsing from ..services.scap_scanner import SCAPScanner + scanner = SCAPScanner() - enhanced_results = scanner._parse_scan_results(scan_data["result_file"], content_file) + enhanced_results = scanner._parse_scan_results( + scan_data["result_file"], content_file + ) else: enhanced_results = {} - + # Add enhanced rule details with remediation if "rule_details" in enhanced_results and enhanced_results["rule_details"]: - scan_data['rule_results'] = enhanced_results["rule_details"] - logger.info(f"Added {len(enhanced_results['rule_details'])} enhanced rules with remediation") + scan_data["rule_results"] = enhanced_results["rule_details"] + logger.info( + f"Added {len(enhanced_results['rule_details'])} enhanced rules with remediation" + ) else: # Fallback to basic parsing for backward compatibility import xml.etree.ElementTree as ET import os - + if os.path.exists(scan_data["result_file"]): tree = ET.parse(scan_data["result_file"]) root = tree.getroot() - + # Extract basic rule results - namespaces = {'xccdf': 'http://checklists.nist.gov/xccdf/1.2'} + namespaces = {"xccdf": "http://checklists.nist.gov/xccdf/1.2"} rule_results = [] - - for rule_result in root.findall('.//xccdf:rule-result', namespaces): - rule_id = rule_result.get('idref', '') - result_elem = rule_result.find('xccdf:result', namespaces) - + + for rule_result in root.findall(".//xccdf:rule-result", namespaces): + rule_id = rule_result.get("idref", "") + result_elem = rule_result.find("xccdf:result", namespaces) + if result_elem is not None: - rule_results.append({ - 'rule_id': rule_id, - 'result': result_elem.text, - 'severity': rule_result.get('severity', 'unknown'), - 'title': '', - 'description': '', - 'rationale': '', - 'remediation': {}, - 'references': [] - }) - - scan_data['rule_results'] = rule_results + rule_results.append( + { + "rule_id": rule_id, + "result": result_elem.text, + "severity": rule_result.get("severity", "unknown"), + "title": "", + "description": "", + "rationale": "", + "remediation": {}, + "references": [], + } + ) + + scan_data["rule_results"] = rule_results logger.info(f"Added {len(rule_results)} basic rules (fallback mode)") - + except Exception as e: logger.error(f"Error extracting enhanced rule data: {e}") # Maintain backward compatibility - don't break if enhancement fails - scan_data['rule_results'] = [] - + scan_data["rule_results"] = [] + return scan_data - + except HTTPException: raise except Exception as e: @@ -1393,68 +1514,61 @@ async def get_scan_json_report( raise HTTPException(status_code=500, detail="Failed to generate JSON report") -@router.get("/{scan_id}/report/csv") +@router.get("/{scan_id}/report/csv") async def get_scan_csv_report( - scan_id: str, - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + scan_id: str, db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """Export scan results as CSV""" try: # Get scan data scan_data = await get_scan_json_report(scan_id, db, current_user) - + # Create CSV content import csv import io - + output = io.StringIO() writer = csv.writer(output) - + # Write headers - writer.writerow(['Scan Information']) - writer.writerow(['ID', scan_data.get('id')]) - writer.writerow(['Name', scan_data.get('name')]) - writer.writerow(['Host', scan_data.get('host_name')]) - writer.writerow(['Status', scan_data.get('status')]) - writer.writerow(['Score', scan_data.get('results', {}).get('score', 'N/A')]) + writer.writerow(["Scan Information"]) + writer.writerow(["ID", scan_data.get("id")]) + writer.writerow(["Name", scan_data.get("name")]) + writer.writerow(["Host", scan_data.get("host_name")]) + writer.writerow(["Status", scan_data.get("status")]) + writer.writerow(["Score", scan_data.get("results", {}).get("score", "N/A")]) writer.writerow([]) - + # Write summary - writer.writerow(['Summary Statistics']) - writer.writerow(['Metric', 'Value']) - if scan_data.get('results'): - results = scan_data['results'] - writer.writerow(['Total Rules', results.get('total_rules')]) - writer.writerow(['Passed', results.get('passed_rules')]) - writer.writerow(['Failed', results.get('failed_rules')]) - writer.writerow(['Errors', results.get('error_rules')]) - writer.writerow(['High Severity', results.get('severity_high')]) - writer.writerow(['Medium Severity', results.get('severity_medium')]) - writer.writerow(['Low Severity', results.get('severity_low')]) + writer.writerow(["Summary Statistics"]) + writer.writerow(["Metric", "Value"]) + if scan_data.get("results"): + results = scan_data["results"] + writer.writerow(["Total Rules", results.get("total_rules")]) + writer.writerow(["Passed", results.get("passed_rules")]) + writer.writerow(["Failed", results.get("failed_rules")]) + writer.writerow(["Errors", results.get("error_rules")]) + writer.writerow(["High Severity", results.get("severity_high")]) + writer.writerow(["Medium Severity", results.get("severity_medium")]) + writer.writerow(["Low Severity", results.get("severity_low")]) writer.writerow([]) - + # Write rule results if available - if 'rule_results' in scan_data: - writer.writerow(['Rule Results']) - writer.writerow(['Rule ID', 'Result', 'Severity']) - for rule in scan_data['rule_results']: - writer.writerow([ - rule.get('rule_id'), - rule.get('result'), - rule.get('severity') - ]) - + if "rule_results" in scan_data: + writer.writerow(["Rule Results"]) + writer.writerow(["Rule ID", "Result", "Severity"]) + for rule in scan_data["rule_results"]: + writer.writerow([rule.get("rule_id"), rule.get("result"), rule.get("severity")]) + # Return CSV from fastapi.responses import Response + return Response( content=output.getvalue(), media_type="text/csv", - headers={ - "Content-Disposition": f"attachment; filename=scan_{scan_id}_report.csv" - } + headers={"Content-Disposition": f"attachment; filename=scan_{scan_id}_report.csv"}, ) - + except HTTPException: raise except Exception as e: @@ -1464,14 +1578,14 @@ async def get_scan_csv_report( @router.get("/{scan_id}/failed-rules") async def get_scan_failed_rules( - scan_id: str, - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + scan_id: str, db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """Get failed rules from a completed scan for AEGIS integration""" try: # Verify scan exists and is completed - scan_result = db.execute(text(""" + scan_result = db.execute( + text( + """ SELECT s.id, s.name, s.host_id, s.status, s.result_file, s.content_id, s.profile_id, h.hostname, h.ip_address, h.display_name as host_name, c.name as content_name, c.filename as content_filename, @@ -1481,15 +1595,24 @@ async def get_scan_failed_rules( JOIN scap_content c ON s.content_id = c.id LEFT JOIN scan_results sr ON sr.scan_id = s.id WHERE s.id = :scan_id - """), {"scan_id": scan_id}).fetchone() - + """ + ), + {"scan_id": scan_id}, + ).fetchone() + if not scan_result: raise HTTPException(status_code=404, detail="Scan not found") - + if scan_result.status != "completed": - raise HTTPException(status_code=400, detail=f"Scan not completed (status: {scan_result.status})") - - if not scan_result.result_file or not scan_result.failed_rules or scan_result.failed_rules == 0: + raise HTTPException( + status_code=400, detail=f"Scan not completed (status: {scan_result.status})" + ) + + if ( + not scan_result.result_file + or not scan_result.failed_rules + or scan_result.failed_rules == 0 + ): return { "scan_id": scan_id, "host_id": str(scan_result.host_id), @@ -1499,52 +1622,52 @@ async def get_scan_failed_rules( "total_rules": scan_result.total_rules or 0, "failed_rules_count": 0, "compliance_score": scan_result.score, - "failed_rules": [] + "failed_rules": [], } - + # Parse the SCAP result file to extract failed rules import xml.etree.ElementTree as ET import os - + failed_rules = [] if os.path.exists(scan_result.result_file): try: tree = ET.parse(scan_result.result_file) root = tree.getroot() - + # Extract failed rule results - namespaces = {'xccdf': 'http://checklists.nist.gov/xccdf/1.2'} - - for rule_result in root.findall('.//xccdf:rule-result', namespaces): - result_elem = rule_result.find('xccdf:result', namespaces) - - if result_elem is not None and result_elem.text == 'fail': - rule_id = rule_result.get('idref', '') - severity = rule_result.get('severity', 'unknown') - + namespaces = {"xccdf": "http://checklists.nist.gov/xccdf/1.2"} + + for rule_result in root.findall(".//xccdf:rule-result", namespaces): + result_elem = rule_result.find("xccdf:result", namespaces) + + if result_elem is not None and result_elem.text == "fail": + rule_id = rule_result.get("idref", "") + severity = rule_result.get("severity", "unknown") + # Extract additional metadata if available - check_elem = rule_result.find('xccdf:check', namespaces) + check_elem = rule_result.find("xccdf:check", namespaces) check_content_ref = "" if check_elem is not None: - content_ref = check_elem.find('xccdf:check-content-ref', namespaces) + content_ref = check_elem.find("xccdf:check-content-ref", namespaces) if content_ref is not None: - check_content_ref = content_ref.get('href', '') - + check_content_ref = content_ref.get("href", "") + failed_rule = { - 'rule_id': rule_id, - 'severity': severity, - 'result': 'fail', - 'check_content_ref': check_content_ref, - 'remediation_available': True # Assume remediation available for AEGIS + "rule_id": rule_id, + "severity": severity, + "result": "fail", + "check_content_ref": check_content_ref, + "remediation_available": True, # Assume remediation available for AEGIS } - + failed_rules.append(failed_rule) - + except Exception as e: logger.error(f"Error parsing scan results for failed rules: {e}") # Return basic info even if parsing fails pass - + response_data = { "scan_id": scan_id, "host_id": str(scan_result.host_id), @@ -1557,12 +1680,12 @@ async def get_scan_failed_rules( "total_rules": scan_result.total_rules or 0, "failed_rules_count": len(failed_rules), "compliance_score": scan_result.score, - "failed_rules": failed_rules + "failed_rules": failed_rules, } - + logger.info(f"Retrieved {len(failed_rules)} failed rules for scan {scan_id}") return response_data - + except HTTPException: raise except Exception as e: @@ -1575,27 +1698,37 @@ async def create_verification_scan( verification_request: VerificationScanRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """Create a verification scan after AEGIS remediation""" try: # Validate host exists and is active - host_result = db.execute(text(""" + host_result = db.execute( + text( + """ SELECT id, display_name, hostname, port, username, auth_method, encrypted_credentials FROM hosts WHERE id = :id AND is_active = true - """), {"id": verification_request.host_id}).fetchone() - + """ + ), + {"id": verification_request.host_id}, + ).fetchone() + if not host_result: raise HTTPException(status_code=404, detail="Host not found or inactive") - + # Validate SCAP content exists - content_result = db.execute(text(""" + content_result = db.execute( + text( + """ SELECT id, name, file_path, profiles FROM scap_content WHERE id = :id - """), {"id": verification_request.content_id}).fetchone() - + """ + ), + {"id": verification_request.content_id}, + ).fetchone() + if not content_result: raise HTTPException(status_code=404, detail="SCAP content not found") - + # Validate profile exists in content profiles = [] if content_result.profiles: @@ -1606,43 +1739,48 @@ async def create_verification_scan( raise HTTPException(status_code=400, detail="Profile not found in SCAP content") except: raise HTTPException(status_code=400, detail="Invalid SCAP content profiles") - + # Generate scan name scan_name = verification_request.name or f"Verification Scan - {host_result.hostname}" if verification_request.original_scan_id: scan_name += f" (Post-Remediation)" - + # Create verification scan record scan_options = { "verification_scan": True, "original_scan_id": verification_request.original_scan_id, - "remediation_job_id": verification_request.remediation_job_id + "remediation_job_id": verification_request.remediation_job_id, } - - result = db.execute(text(""" + + result = db.execute( + text( + """ INSERT INTO scans (name, host_id, content_id, profile_id, status, progress, scan_options, started_by, started_at, verification_scan) VALUES (:name, :host_id, :content_id, :profile_id, :status, :progress, :scan_options, :started_by, :started_at, :verification_scan) RETURNING id - """), { - "name": scan_name, - "host_id": verification_request.host_id, - "content_id": verification_request.content_id, - "profile_id": verification_request.profile_id, - "status": "pending", - "progress": 0, - "scan_options": json.dumps(scan_options), - "started_by": current_user["id"], - "started_at": datetime.utcnow(), - "verification_scan": True - }) - + """ + ), + { + "name": scan_name, + "host_id": verification_request.host_id, + "content_id": verification_request.content_id, + "profile_id": verification_request.profile_id, + "status": "pending", + "progress": 0, + "scan_options": json.dumps(scan_options), + "started_by": current_user["id"], + "started_at": datetime.utcnow(), + "verification_scan": True, + }, + ) + # Get the generated scan ID scan_id = result.fetchone().id db.commit() - + # Start verification scan as background task background_tasks.add_task( execute_scan_task, @@ -1652,32 +1790,32 @@ async def create_verification_scan( "port": host_result.port, "username": host_result.username, "auth_method": host_result.auth_method, - "encrypted_credentials": host_result.encrypted_credentials + "encrypted_credentials": host_result.encrypted_credentials, }, content_path=content_result.file_path, profile_id=verification_request.profile_id, - scan_options=scan_options + scan_options=scan_options, ) - + logger.info(f"Verification scan created and started: {scan_id}") - + response = { "id": scan_id, "message": "Verification scan created and started successfully", "status": "pending", "verification_scan": True, "host_id": verification_request.host_id, - "host_name": host_result.display_name or host_result.hostname + "host_name": host_result.display_name or host_result.hostname, } - + # Add reference info if provided if verification_request.original_scan_id: response["original_scan_id"] = verification_request.original_scan_id if verification_request.remediation_job_id: response["remediation_job_id"] = verification_request.remediation_job_id - + return response - + except HTTPException: raise except Exception as e: @@ -1691,14 +1829,16 @@ async def rescan_rule( rescan_request: RuleRescanRequest, background_tasks: BackgroundTasks, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """Rescan a specific rule from a completed scan""" try: logger.info(f"Rule rescan requested for scan {scan_id}, rule {rescan_request.rule_id}") - + # Get the original scan details - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT s.id, s.host_id, s.content_id, s.profile_id, s.name, h.hostname, h.ip_address, h.port, h.username, h.auth_method, h.encrypted_credentials, c.file_path, c.filename @@ -1706,45 +1846,55 @@ async def rescan_rule( JOIN hosts h ON s.host_id = h.id JOIN scap_content c ON s.content_id = c.id WHERE s.id = :scan_id - """), {"scan_id": scan_id}) - + """ + ), + {"scan_id": scan_id}, + ) + scan_data = result.fetchone() if not scan_data: raise HTTPException(status_code=404, detail="Original scan not found") - + # Validate that the host is still active if not scan_data.encrypted_credentials: raise HTTPException(status_code=400, detail="Host credentials not available") - + # Validate that the SCAP content file exists if not scan_data.file_path: raise HTTPException(status_code=400, detail="SCAP content file not found") - + # Create a new scan record for the rule rescan scan_name = rescan_request.name or f"Rule Rescan: {rescan_request.rule_id}" - - result = db.execute(text(""" + + result = db.execute( + text( + """ INSERT INTO scans (name, host_id, content_id, profile_id, status, progress, started_by, started_at, scan_options) VALUES (:name, :host_id, :content_id, :profile_id, :status, :progress, :started_by, :started_at, :scan_options) RETURNING id - """), { - "name": scan_name, - "host_id": scan_data.host_id, - "content_id": scan_data.content_id, - "profile_id": scan_data.profile_id, - "status": "pending", - "progress": 0, - "started_by": current_user["id"], - "started_at": datetime.utcnow(), - "scan_options": json.dumps({"rule_id": rescan_request.rule_id, "rescan_type": "rule"}) - }) - + """ + ), + { + "name": scan_name, + "host_id": scan_data.host_id, + "content_id": scan_data.content_id, + "profile_id": scan_data.profile_id, + "status": "pending", + "progress": 0, + "started_by": current_user["id"], + "started_at": datetime.utcnow(), + "scan_options": json.dumps( + {"rule_id": rescan_request.rule_id, "rescan_type": "rule"} + ), + }, + ) + new_scan_id = result.fetchone()[0] - + db.commit() - + # Prepare host data for scan execution host_data = { "id": scan_data.host_id, @@ -1753,9 +1903,9 @@ async def rescan_rule( "port": scan_data.port, "username": scan_data.username, "auth_method": scan_data.auth_method, - "encrypted_credentials": scan_data.encrypted_credentials # This will be decrypted by the task + "encrypted_credentials": scan_data.encrypted_credentials, # This will be decrypted by the task } - + # Execute the rule-specific scan as background task scan_options = {"rule_id": rescan_request.rule_id, "rescan_type": "rule"} background_tasks.add_task( @@ -1764,19 +1914,19 @@ async def rescan_rule( host_data, scan_data.file_path, scan_data.profile_id, - scan_options + scan_options, ) - + logger.info(f"Rule rescan task scheduled: {new_scan_id}") - + return { "message": "Rule rescan initiated successfully", "scan_id": str(new_scan_id), "status": "pending", "rule_id": rescan_request.rule_id, - "original_scan_id": scan_id + "original_scan_id": scan_id, } - + except HTTPException: raise except Exception as e: @@ -1786,52 +1936,66 @@ async def rescan_rule( @router.post("/{scan_id}/remediate") async def start_remediation( - scan_id: str, - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + scan_id: str, db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """Send failed rules to AEGIS for automated remediation""" try: # Get scan details and failed rules - scan_result = db.execute(text(""" + scan_result = db.execute( + text( + """ SELECT s.id, s.name, s.host_id, h.hostname, h.ip_address, sr.failed_rules, sr.severity_high, sr.severity_medium, sr.severity_low FROM scans s JOIN hosts h ON s.host_id = h.id LEFT JOIN scan_results sr ON s.id = sr.scan_id WHERE s.id = :scan_id AND s.status = 'completed' - """), {"scan_id": scan_id}).fetchone() - + """ + ), + {"scan_id": scan_id}, + ).fetchone() + if not scan_result: raise HTTPException(status_code=404, detail="Completed scan not found") - + if scan_result.failed_rules == 0: raise HTTPException(status_code=400, detail="No failed rules to remediate") - + # Get the actual failed rules - failed_rules = db.execute(text(""" + failed_rules = db.execute( + text( + """ SELECT rule_id, title, severity, description FROM scan_rule_results WHERE scan_id = :scan_id AND status = 'failed' ORDER BY CASE severity WHEN 'high' THEN 1 WHEN 'medium' THEN 2 ELSE 3 END - """), {"scan_id": scan_id}).fetchall() - + """ + ), + {"scan_id": scan_id}, + ).fetchall() + # Mock AEGIS integration - in reality this would call AEGIS API import uuid + remediation_job_id = str(uuid.uuid4()) - + # Update scan with remediation request - db.execute(text(""" + db.execute( + text( + """ UPDATE scans SET remediation_requested = true, aegis_remediation_id = :job_id, remediation_status = 'pending' WHERE id = :scan_id - """), {"scan_id": scan_id, "job_id": remediation_job_id}) + """ + ), + {"scan_id": scan_id, "job_id": remediation_job_id}, + ) db.commit() - + logger.info(f"Remediation job created: {remediation_job_id} for scan {scan_id}") - + return { "job_id": remediation_job_id, "message": f"Remediation job created for {len(failed_rules)} failed rules", @@ -1841,13 +2005,13 @@ async def start_remediation( "severity_breakdown": { "high": scan_result.severity_high, "medium": scan_result.severity_medium, - "low": scan_result.severity_low + "low": scan_result.severity_low, }, - "status": "pending" + "status": "pending", } - + except HTTPException: raise except Exception as e: logger.error(f"Error starting remediation for scan {scan_id}: {e}") - raise HTTPException(status_code=500, detail="Failed to start remediation job") \ No newline at end of file + raise HTTPException(status_code=500, detail="Failed to start remediation job") diff --git a/backend/app/routes/scap_content.py b/backend/app/routes/scap_content.py index 49a10328..585ebc05 100644 --- a/backend/app/routes/scap_content.py +++ b/backend/app/routes/scap_content.py @@ -2,6 +2,7 @@ SCAP Content Management API Routes Handles SCAP content upload, validation, and management """ + import os import hashlib import tempfile @@ -32,50 +33,57 @@ datastream_processor = SCAPDataStreamProcessor() framework_mapper = ComplianceFrameworkMapper() + @router.get("/") async def list_scap_content( - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """List all uploaded SCAP content""" try: - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT id, name, filename, content_type, description, version, profiles, uploaded_at, uploaded_by, file_path FROM scap_content ORDER BY uploaded_at DESC - """)) - + """ + ) + ) + content_list = [] for row in result: profiles = [] if row.profiles: try: import json + profiles = json.loads(row.profiles) except: profiles = [] - - content_list.append({ - "id": row.id, - "name": row.name, - "filename": row.filename, - "content_type": row.content_type, - "description": row.description, - "version": row.version, - "profiles": profiles, - "uploaded_at": row.uploaded_at.isoformat(), - "uploaded_by": row.uploaded_by, - "os_family": "unknown", - "os_version": "unknown", - "compliance_framework": "unknown", - "source": "manual", - "status": "current", - "update_available": False - }) - + + content_list.append( + { + "id": row.id, + "name": row.name, + "filename": row.filename, + "content_type": row.content_type, + "description": row.description, + "version": row.version, + "profiles": profiles, + "uploaded_at": row.uploaded_at.isoformat(), + "uploaded_by": row.uploaded_by, + "os_family": "unknown", + "os_version": "unknown", + "compliance_framework": "unknown", + "source": "manual", + "status": "current", + "update_available": False, + } + ) + return {"scap_content": content_list} - + except Exception as e: logger.error(f"Error listing SCAP content: {e}") raise HTTPException(status_code=500, detail="Failed to retrieve SCAP content") @@ -83,46 +91,55 @@ async def list_scap_content( @router.get("/statistics") async def get_scap_content_stats( - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """Get SCAP content statistics""" try: # Get content counts - simplified since we don't have os_family, status, etc. columns - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT COUNT(*) as total_content FROM scap_content - """)).fetchone() - + """ + ) + ).fetchone() + total_content = result.total_content if result else 0 - + # Simplified os_stats since we don't have os_family column - os_stats = [{ - "os_family": "unknown", - "total_content": total_content, - "versions": 1, - "outdated": 0, - "updates_available": 0, - "total_profiles": 0 - }] if total_content > 0 else [] - + os_stats = ( + [ + { + "os_family": "unknown", + "total_content": total_content, + "versions": 1, + "outdated": 0, + "updates_available": 0, + "total_profiles": 0, + } + ] + if total_content > 0 + else [] + ) + # Get overall statistics - simplified overall_result = { "total_content": total_content, "os_types": 1 if total_content > 0 else 0, "frameworks": 1 if total_content > 0 else 0, "outdated": 0, - "updates_available": 0 + "updates_available": 0, } - + return { "overall": { "total_content": overall_result["total_content"], "os_types": overall_result["os_types"], "frameworks": overall_result["frameworks"], "outdated": overall_result["outdated"], - "updates_available": overall_result["updates_available"] + "updates_available": overall_result["updates_available"], }, - "by_os_family": os_stats + "by_os_family": os_stats, } except Exception as e: logger.error(f"Error getting SCAP content stats: {e}") @@ -135,91 +152,109 @@ async def upload_scap_content( name: str = Form(...), description: str = Form(""), db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """Upload and validate SCAP content file""" try: # Validate file type - allowed_extensions = ['.xml', '.zip'] + allowed_extensions = [".xml", ".zip"] file_ext = Path(file.filename).suffix.lower() if file_ext not in allowed_extensions: raise HTTPException( - status_code=400, - detail=f"Invalid file type. Allowed: {', '.join(allowed_extensions)}" + status_code=400, + detail=f"Invalid file type. Allowed: {', '.join(allowed_extensions)}", ) - + # Read file content content = await file.read() if len(content) == 0: raise HTTPException(status_code=400, detail="Empty file uploaded") - + # Calculate file hash file_hash = hashlib.sha256(content).hexdigest() - + # Check if file already exists - existing = db.execute(text(""" + existing = db.execute( + text( + """ SELECT id FROM scap_content WHERE file_hash = :hash - """), {"hash": file_hash}).fetchone() - + """ + ), + {"hash": file_hash}, + ).fetchone() + if existing: raise HTTPException(status_code=409, detail="File already exists") - + # Save file to temporary location for validation with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as temp_file: temp_file.write(content) temp_path = temp_file.name - + try: # Validate SCAP content using data-stream processor validation_result = datastream_processor.validate_datastream(temp_path) - + # Extract profiles with metadata profiles = datastream_processor.extract_profiles_with_metadata(temp_path) - + # Extract content components for framework mapping content_components = datastream_processor.extract_content_components(temp_path) - + # Create permanent storage location content_id = str(uuid.uuid4()) storage_dir = Path("/app/data/scap") / content_id storage_dir.mkdir(parents=True, exist_ok=True) - + permanent_path = storage_dir / file.filename - with open(permanent_path, 'wb') as f: + with open(permanent_path, "wb") as f: f.write(content) - + # Extract OS and framework information os_family, os_version = _extract_os_info(file.filename, validation_result) compliance_framework = _extract_framework_info(file.filename, validation_result) - + # Save to database with complete metadata import json - db.execute(text(""" + + db.execute( + text( + """ INSERT INTO scap_content (name, filename, file_path, content_type, profiles, description, version, uploaded_by, file_hash, uploaded_at) VALUES (:name, :filename, :file_path, :content_type, :profiles, :description, :version, :uploaded_by, :file_hash, NOW()) - """), { - "name": name, - "filename": file.filename, - "file_path": str(permanent_path), - "content_type": validation_result.get("content_type", validation_result.get("document_type", "unknown")), - "profiles": json.dumps(profiles), - "description": description, - "version": validation_result.get("version", ""), - "uploaded_by": current_user["id"], - "file_hash": file_hash - }) + """ + ), + { + "name": name, + "filename": file.filename, + "file_path": str(permanent_path), + "content_type": validation_result.get( + "content_type", validation_result.get("document_type", "unknown") + ), + "profiles": json.dumps(profiles), + "description": description, + "version": validation_result.get("version", ""), + "uploaded_by": current_user["id"], + "file_hash": file_hash, + }, + ) db.commit() - + # Get the inserted record - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT id FROM scap_content WHERE file_hash = :hash - """), {"hash": file_hash}).fetchone() - + """ + ), + {"hash": file_hash}, + ).fetchone() + logger.info(f"SCAP content uploaded: {name} ({file.filename})") - + return { "id": result.id, "message": "SCAP content uploaded successfully", @@ -229,18 +264,18 @@ async def upload_scap_content( "format": content_components.get("format", "unknown"), "rules_count": len(content_components.get("rules", [])), "os_family": "unknown", - "os_version": "unknown", - "compliance_framework": "unknown" - } + "os_version": "unknown", + "compliance_framework": "unknown", + }, } - + finally: # Clean up temp file try: os.unlink(temp_path) except: pass - + except SCAPContentError as e: logger.error(f"SCAP validation error: {e}") raise HTTPException(status_code=400, detail=str(e)) @@ -254,29 +289,33 @@ async def upload_scap_content( @router.get("/{content_id}") async def get_scap_content( - content_id: int, - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + content_id: int, db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """Get SCAP content details""" try: - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT id, name, filename, content_type, description, version, profiles, uploaded_at, uploaded_by, file_path FROM scap_content WHERE id = :id - """), {"id": content_id}).fetchone() - + """ + ), + {"id": content_id}, + ).fetchone() + if not result: raise HTTPException(status_code=404, detail="SCAP content not found") - + profiles = [] if result.profiles: try: import json + profiles = json.loads(result.profiles) except: profiles = [] - + return { "id": result.id, "name": result.name, @@ -287,9 +326,9 @@ async def get_scap_content( "profiles": profiles, "uploaded_at": result.uploaded_at.isoformat(), "uploaded_by": result.uploaded_by, - "has_file": os.path.exists(result.file_path) + "has_file": os.path.exists(result.file_path), } - + except HTTPException: raise except Exception as e: @@ -299,31 +338,35 @@ async def get_scap_content( @router.get("/{content_id}/profiles") async def get_scap_profiles( - content_id: int, - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + content_id: int, db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """Get available profiles for SCAP content""" try: - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT profiles, file_path FROM scap_content WHERE id = :id - """), {"id": content_id}).fetchone() - + """ + ), + {"id": content_id}, + ).fetchone() + if not result: raise HTTPException(status_code=404, detail="SCAP content not found") - + profiles = [] if result.profiles: try: import json + profiles = json.loads(result.profiles) except: # Re-extract profiles from file if cached version is invalid if os.path.exists(result.file_path): profiles = scap_scanner.extract_profiles(result.file_path) - + return {"profiles": profiles} - + except HTTPException: raise except Exception as e: @@ -333,46 +376,56 @@ async def get_scap_profiles( @router.delete("/{content_id}") async def delete_scap_content( - content_id: int, - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + content_id: int, db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """Delete SCAP content""" try: # Check if content exists and get file path - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT file_path FROM scap_content WHERE id = :id - """), {"id": content_id}).fetchone() - + """ + ), + {"id": content_id}, + ).fetchone() + if not result: raise HTTPException(status_code=404, detail="SCAP content not found") - + # Check for any scans using this content (active or completed) - scan_count = db.execute(text(""" + scan_count = db.execute( + text( + """ SELECT COUNT(*) as count FROM scans WHERE content_id = :id - """), {"id": content_id}).fetchone() - + """ + ), + {"id": content_id}, + ).fetchone() + if scan_count.count > 0: # Get scan details for better error message - scan_details = db.execute(text(""" + scan_details = db.execute( + text( + """ SELECT status, COUNT(*) as count FROM scans WHERE content_id = :id GROUP BY status ORDER BY status - """), {"id": content_id}).fetchall() - + """ + ), + {"id": content_id}, + ).fetchall() + status_summary = [] for row in scan_details: status_summary.append(f"{row.count} {row.status}") - + detail_msg = f"Cannot delete SCAP content that has {scan_count.count} associated scan(s): {', '.join(status_summary)}. Please delete the scans first or contact an administrator." - - raise HTTPException( - status_code=409, - detail=detail_msg - ) - + + raise HTTPException(status_code=409, detail=detail_msg) + # Delete file from storage file_path = result.file_path if os.path.exists(file_path): @@ -387,76 +440,86 @@ async def delete_scap_content( pass except Exception as e: logger.warning(f"Failed to delete file {file_path}: {e}") - + # Delete from database - db.execute(text(""" + db.execute( + text( + """ DELETE FROM scap_content WHERE id = :id - """), {"id": content_id}) + """ + ), + {"id": content_id}, + ) db.commit() - + logger.info(f"SCAP content deleted: {content_id}") return {"message": "SCAP content deleted successfully"} - + except HTTPException: # Re-raise HTTPExceptions (404, 409) with their specific messages raise except IntegrityError as e: # Handle database constraint violations db.rollback() - error_msg = str(e.orig) if hasattr(e, 'orig') else str(e) - + error_msg = str(e.orig) if hasattr(e, "orig") else str(e) + if "foreign key constraint" in error_msg.lower(): - logger.warning(f"Foreign key constraint violation when deleting SCAP content {content_id}: {error_msg}") + logger.warning( + f"Foreign key constraint violation when deleting SCAP content {content_id}: {error_msg}" + ) raise HTTPException( status_code=409, - detail="Cannot delete SCAP content because it is referenced by existing scan results. Please delete associated scans first." + detail="Cannot delete SCAP content because it is referenced by existing scan results. Please delete associated scans first.", ) else: - logger.error(f"Database integrity error when deleting SCAP content {content_id}: {error_msg}") + logger.error( + f"Database integrity error when deleting SCAP content {content_id}: {error_msg}" + ) raise HTTPException( status_code=500, - detail="Database constraint violation prevented deletion. Please contact an administrator." + detail="Database constraint violation prevented deletion. Please contact an administrator.", ) except OSError as e: # Handle file system errors logger.error(f"File system error when deleting SCAP content {content_id}: {e}") raise HTTPException( status_code=500, - detail="Failed to delete SCAP content files from storage. The database record may have been removed." + detail="Failed to delete SCAP content files from storage. The database record may have been removed.", ) except Exception as e: # Handle any other unexpected errors logger.error(f"Unexpected error deleting SCAP content {content_id}: {e}", exc_info=True) raise HTTPException( status_code=500, - detail="An unexpected error occurred during deletion. Please try again or contact an administrator." + detail="An unexpected error occurred during deletion. Please try again or contact an administrator.", ) @router.get("/{content_id}/download") async def download_scap_content( - content_id: int, - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + content_id: int, db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """Download SCAP content file""" try: - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT filename, file_path FROM scap_content WHERE id = :id - """), {"id": content_id}).fetchone() - + """ + ), + {"id": content_id}, + ).fetchone() + if not result: raise HTTPException(status_code=404, detail="SCAP content not found") - + if not os.path.exists(result.file_path): raise HTTPException(status_code=404, detail="SCAP content file not found") - + return FileResponse( - path=result.file_path, - filename=result.filename, - media_type='application/octet-stream' + path=result.file_path, filename=result.filename, media_type="application/octet-stream" ) - + except HTTPException: raise except Exception as e: @@ -466,10 +529,9 @@ async def download_scap_content( # Repository Management Endpoints + @router.get("/repositories/status") -async def get_repository_status( - current_user: dict = Depends(get_current_user) -): +async def get_repository_status(current_user: dict = Depends(get_current_user)): """Get status of all SCAP repositories""" try: status = scap_repository_manager.get_repository_status() @@ -483,20 +545,17 @@ async def get_repository_status( async def sync_repositories( repository_ids: Optional[List[str]] = None, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """Synchronize SCAP content from repositories""" try: if scap_repository_manager.sync_running: raise HTTPException(status_code=409, detail="Sync already in progress") - + # Start sync in background results = await scap_repository_manager.sync_repositories(db, repository_ids) - - return { - "message": "Repository sync completed", - "results": results - } + + return {"message": "Repository sync completed", "results": results} except HTTPException: raise except Exception as e: @@ -507,97 +566,100 @@ async def sync_repositories( def _extract_os_info(filename: str, validation_result: Dict) -> Tuple[str, str]: """Extract OS family and version from filename and validation result""" filename_lower = filename.lower() - + # Check filename patterns - if 'rhel' in filename_lower: - os_family = 'rhel' - if 'rhel_9' in filename_lower or 'rhel9' in filename_lower: - os_version = '9' - elif 'rhel_8' in filename_lower or 'rhel8' in filename_lower: - os_version = '8' + if "rhel" in filename_lower: + os_family = "rhel" + if "rhel_9" in filename_lower or "rhel9" in filename_lower: + os_version = "9" + elif "rhel_8" in filename_lower or "rhel8" in filename_lower: + os_version = "8" else: - os_version = 'unknown' - elif 'ubuntu' in filename_lower: - os_family = 'ubuntu' - if '22.04' in filename_lower or '22_04' in filename_lower or '2204' in filename_lower: - os_version = '22.04' - elif '20.04' in filename_lower or '20_04' in filename_lower or '2004' in filename_lower: - os_version = '20.04' + os_version = "unknown" + elif "ubuntu" in filename_lower: + os_family = "ubuntu" + if "22.04" in filename_lower or "22_04" in filename_lower or "2204" in filename_lower: + os_version = "22.04" + elif "20.04" in filename_lower or "20_04" in filename_lower or "2004" in filename_lower: + os_version = "20.04" else: - os_version = 'unknown' - elif 'oracle' in filename_lower: - os_family = 'oracle_linux' - if 'oracle_linux_8' in filename_lower: - os_version = '8' + os_version = "unknown" + elif "oracle" in filename_lower: + os_family = "oracle_linux" + if "oracle_linux_8" in filename_lower: + os_version = "8" else: - os_version = 'unknown' - elif 'centos' in filename_lower: - os_family = 'centos' - os_version = 'unknown' + os_version = "unknown" + elif "centos" in filename_lower: + os_family = "centos" + os_version = "unknown" else: # Try to extract from validation result - title = validation_result.get('title', '').lower() - if 'red hat' in title: - os_family = 'rhel' - elif 'ubuntu' in title: - os_family = 'ubuntu' - elif 'oracle' in title: - os_family = 'oracle_linux' + title = validation_result.get("title", "").lower() + if "red hat" in title: + os_family = "rhel" + elif "ubuntu" in title: + os_family = "ubuntu" + elif "oracle" in title: + os_family = "oracle_linux" else: - os_family = 'unknown' - os_version = 'unknown' - + os_family = "unknown" + os_version = "unknown" + return os_family, os_version def _extract_framework_info(filename: str, validation_result: Dict) -> str: """Extract compliance framework from filename and validation result""" filename_lower = filename.lower() - title_lower = validation_result.get('title', '').lower() - + title_lower = validation_result.get("title", "").lower() + # Check for framework indicators - if 'stig' in filename_lower or 'stig' in title_lower or 'disa' in filename_lower: - return 'DISA-STIG' - elif 'cis' in filename_lower or 'cis' in title_lower: - return 'CIS-Controls' - elif 'nist' in filename_lower or 'nist' in title_lower: - return 'NIST-800-53' - elif 'pci' in filename_lower or 'pci-dss' in title_lower: - return 'PCI-DSS' - elif 'hipaa' in filename_lower: - return 'HIPAA' - elif 'cmmc' in filename_lower: - return 'CMMC-2.0' + if "stig" in filename_lower or "stig" in title_lower or "disa" in filename_lower: + return "DISA-STIG" + elif "cis" in filename_lower or "cis" in title_lower: + return "CIS-Controls" + elif "nist" in filename_lower or "nist" in title_lower: + return "NIST-800-53" + elif "pci" in filename_lower or "pci-dss" in title_lower: + return "PCI-DSS" + elif "hipaa" in filename_lower: + return "HIPAA" + elif "cmmc" in filename_lower: + return "CMMC-2.0" else: - return 'unknown' + return "unknown" @router.get("/{content_id}/compliance-analysis") async def get_compliance_analysis( - content_id: int, - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + content_id: int, db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """Get compliance framework analysis for SCAP content""" try: # Get content info - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT file_path, profiles FROM scap_content WHERE id = :id - """), {"id": content_id}).fetchone() - + """ + ), + {"id": content_id}, + ).fetchone() + if not result: raise HTTPException(status_code=404, detail="SCAP content not found") - + # Extract content components content_components = datastream_processor.extract_content_components(result.file_path) - + # Get compliance framework summary rule_ids = [rule["id"] for rule in content_components.get("rules", [])] framework_summary = framework_mapper.get_framework_summary(rule_ids) - + # Get compliance matrix compliance_matrix = framework_mapper.export_compliance_matrix(rule_ids) - + return { "content_id": content_id, "total_rules": len(rule_ids), @@ -606,10 +668,10 @@ async def get_compliance_analysis( "content_components": { "format": content_components.get("format"), "profiles": content_components.get("profiles", []), - "data_streams": content_components.get("data_streams", []) - } + "data_streams": content_components.get("data_streams", []), + }, } - + except Exception as e: logger.error(f"Error analyzing compliance content: {e}") raise HTTPException(status_code=500, detail="Failed to analyze compliance content") @@ -617,29 +679,32 @@ async def get_compliance_analysis( @router.post("/{content_id}/validate-datastream") async def validate_datastream_content( - content_id: int, - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + content_id: int, db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """Create comprehensive validation report for SCAP content""" try: # Get content file path - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT file_path, filename FROM scap_content WHERE id = :id - """), {"id": content_id}).fetchone() - + """ + ), + {"id": content_id}, + ).fetchone() + if not result: raise HTTPException(status_code=404, detail="SCAP content not found") - + # Create validation report validation_report = datastream_processor.create_content_validation_report(result.file_path) - + return { "content_id": content_id, "filename": result.filename, - "validation_report": validation_report + "validation_report": validation_report, } - + except Exception as e: logger.error(f"Error validating datastream content: {e}") raise HTTPException(status_code=500, detail="Failed to validate datastream content") @@ -647,8 +712,7 @@ async def validate_datastream_content( @router.get("/framework-mappings") async def get_framework_mappings( - framework: Optional[str] = None, - current_user: dict = Depends(get_current_user) + framework: Optional[str] = None, current_user: dict = Depends(get_current_user) ): """Get compliance framework mappings""" try: @@ -658,9 +722,11 @@ async def get_framework_mappings( "frameworks": list(framework_mapper.control_families.keys()), "mappings_available": len(framework_mapper.framework_mappings), "supported_platforms": ["rhel8", "rhel9", "ubuntu20", "ubuntu22"], - "framework_info": framework_mapper.control_families.get(framework) if framework else None + "framework_info": ( + framework_mapper.control_families.get(framework) if framework else None + ), } - + except Exception as e: logger.error(f"Error getting framework mappings: {e}") raise HTTPException(status_code=500, detail="Failed to get framework mappings") @@ -668,57 +734,55 @@ async def get_framework_mappings( @router.put("/repositories/{repository_id}/enable") async def enable_repository( - repository_id: str, - enabled: bool = True, - current_user: dict = Depends(get_current_user) + repository_id: str, enabled: bool = True, current_user: dict = Depends(get_current_user) ): """Enable or disable a repository""" try: scap_repository_manager.enable_repository(repository_id, enabled) - return { - "message": f"Repository {repository_id} {'enabled' if enabled else 'disabled'}" - } + return {"message": f"Repository {repository_id} {'enabled' if enabled else 'disabled'}"} except Exception as e: logger.error(f"Error updating repository: {e}") raise HTTPException(status_code=500, detail="Failed to update repository") @router.get("/environment/info") -async def get_environment_info( - current_user: dict = Depends(get_current_user) -): +async def get_environment_info(current_user: dict = Depends(get_current_user)): """Get environment information (connected/air-gapped)""" try: # Determine environment type based on repository connectivity repositories = scap_repository_manager.get_repository_status()["repositories"] - + # Simple connectivity test has_internet = False try: import asyncio import aiohttp - + async def test_connectivity(): try: async with aiohttp.ClientSession() as session: - async with session.get('https://www.google.com', timeout=5) as response: + async with session.get("https://www.google.com", timeout=5) as response: return response.status == 200 except: return False - + has_internet = await test_connectivity() except: has_internet = False - - environment_type = 'connected' if has_internet else 'air-gapped' - + + environment_type = "connected" if has_internet else "air-gapped" + return { "type": environment_type, "repositories": repositories, "auto_sync_enabled": has_internet and any(r["enabled"] for r in repositories), - "last_global_sync": scap_repository_manager.last_global_sync.isoformat() if scap_repository_manager.last_global_sync else None, - "next_scheduled_sync": None # Would be calculated based on schedule settings + "last_global_sync": ( + scap_repository_manager.last_global_sync.isoformat() + if scap_repository_manager.last_global_sync + else None + ), + "next_scheduled_sync": None, # Would be calculated based on schedule settings } except Exception as e: logger.error(f"Error getting environment info: {e}") - raise HTTPException(status_code=500, detail="Failed to get environment information") \ No newline at end of file + raise HTTPException(status_code=500, detail="Failed to get environment information") diff --git a/backend/app/routes/security_config.py b/backend/app/routes/security_config.py index 59dc8e89..ea3d67e5 100644 --- a/backend/app/routes/security_config.py +++ b/backend/app/routes/security_config.py @@ -16,14 +16,14 @@ from ..auth import get_current_user from ..rbac import require_permission, Permission from ..services.security_config import ( - SecurityConfigManager, - ConfigScope, - get_security_config_manager + SecurityConfigManager, + ConfigScope, + get_security_config_manager, ) from ..services.credential_validation import ( - SecurityPolicyLevel, - SecurityPolicyConfig, - get_credential_validator + SecurityPolicyLevel, + SecurityPolicyConfig, + get_credential_validator, ) logger = logging.getLogger(__name__) @@ -34,6 +34,7 @@ # Pydantic models class SecurityPolicyRequest(BaseModel): """Request model for security policy configuration""" + policy_level: SecurityPolicyLevel = Field(..., description="Security policy enforcement level") enforce_fips: bool = Field(True, description="Enforce FIPS 140-2 compliance") minimum_rsa_bits: int = Field(3072, description="Minimum RSA key size in bits") @@ -45,6 +46,7 @@ class SecurityPolicyRequest(BaseModel): class SecurityConfigResponse(BaseModel): """Response model for security configuration""" + scope: str target_id: Optional[str] effective_config: Dict[str, Any] @@ -55,6 +57,7 @@ class SecurityConfigResponse(BaseModel): class TemplateResponse(BaseModel): """Response model for security templates""" + name: str description: str policy_level: str @@ -64,6 +67,7 @@ class TemplateResponse(BaseModel): class ValidationResponse(BaseModel): """Response model for SSH key validation""" + is_valid: bool is_secure: bool is_fips_compliant: bool @@ -82,7 +86,7 @@ async def get_security_config( target_id: Optional[str] = None, target_type: Optional[str] = None, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """ Get effective security configuration for a target. @@ -91,19 +95,19 @@ async def get_security_config( try: config_manager = get_security_config_manager(db) summary = config_manager.get_config_summary(target_id, target_type) - + if "error" in summary: raise HTTPException(status_code=500, detail=summary["error"]) - + return SecurityConfigResponse( scope=target_type or "system", target_id=target_id, effective_config=summary["effective_config"], inheritance_chain=summary["inheritance_chain"], compliance_level=summary["compliance_level"], - last_updated=datetime.utcnow().isoformat() + last_updated=datetime.utcnow().isoformat(), ) - + except Exception as e: logger.error(f"Failed to get security config: {e}") raise HTTPException(status_code=500, detail="Failed to retrieve security configuration") @@ -116,14 +120,14 @@ async def update_security_config( scope: ConfigScope, target_id: Optional[str] = None, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """ Update security configuration for a specific scope. """ try: config_manager = get_security_config_manager(db) - + # Convert request to SecurityPolicyConfig config = SecurityPolicyConfig( policy_level=policy.policy_level, @@ -132,31 +136,32 @@ async def update_security_config( minimum_ecdsa_bits=policy.minimum_ecdsa_bits, allow_dsa_keys=policy.allow_dsa_keys, minimum_password_length=policy.minimum_password_length, - require_complex_passwords=policy.require_complex_passwords + require_complex_passwords=policy.require_complex_passwords, ) - + # Validate configuration is_valid, validation_messages = config_manager.validate_config(config) if not is_valid: raise HTTPException( - status_code=400, - detail=f"Invalid configuration: {'; '.join(validation_messages)}" + status_code=400, detail=f"Invalid configuration: {'; '.join(validation_messages)}" ) - + # Update configuration success = config_manager.set_config( scope=scope, config=config, target_id=target_id, - created_by=current_user.get('id', 'unknown') + created_by=current_user.get("id", "unknown"), ) - + if success: - logger.info(f"Security config updated by {current_user.get('username')} for {scope.value}:{target_id}") + logger.info( + f"Security config updated by {current_user.get('username')} for {scope.value}:{target_id}" + ) return {"message": "Security configuration updated successfully"} else: raise HTTPException(status_code=500, detail="Failed to update security configuration") - + except HTTPException: raise except Exception as e: @@ -171,27 +176,31 @@ async def apply_security_template( scope: ConfigScope, target_id: Optional[str] = None, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """ Apply a predefined security configuration template. """ try: config_manager = get_security_config_manager(db) - + success = config_manager.apply_template( template_name=template_name, scope=scope, target_id=target_id, - created_by=current_user.get('id', 'unknown') + created_by=current_user.get("id", "unknown"), ) - + if success: - logger.info(f"Applied template '{template_name}' by {current_user.get('username')} to {scope.value}:{target_id}") + logger.info( + f"Applied template '{template_name}' by {current_user.get('username')} to {scope.value}:{target_id}" + ) return {"message": f"Template '{template_name}' applied successfully"} else: - raise HTTPException(status_code=400, detail=f"Failed to apply template '{template_name}'") - + raise HTTPException( + status_code=400, detail=f"Failed to apply template '{template_name}'" + ) + except HTTPException: raise except Exception as e: @@ -201,8 +210,7 @@ async def apply_security_template( @router.get("/templates", response_model=List[TemplateResponse]) async def list_security_templates( - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """ List all available security configuration templates. @@ -210,9 +218,9 @@ async def list_security_templates( try: config_manager = get_security_config_manager(db) templates = config_manager.list_templates() - + return [TemplateResponse(**template) for template in templates] - + except Exception as e: logger.error(f"Failed to list security templates: {e}") raise HTTPException(status_code=500, detail="Failed to list security templates") @@ -229,7 +237,7 @@ class SSHKeyValidationRequest(BaseModel): async def validate_ssh_key( request: SSHKeyValidationRequest, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """ Validate SSH key against current security policies. @@ -238,17 +246,18 @@ async def validate_ssh_key( try: # Get effective configuration for the target config_manager = get_security_config_manager(db) - effective_config = config_manager.get_effective_config(request.target_id, request.target_type) - + effective_config = config_manager.get_effective_config( + request.target_id, request.target_type + ) + # Create validator with effective configuration validator = get_credential_validator( - policy_level=effective_config.policy_level, - enforce_fips=effective_config.enforce_fips + policy_level=effective_config.policy_level, enforce_fips=effective_config.enforce_fips ) - + # Perform strict validation assessment = validator.validate_ssh_key_strict(request.key_content, request.passphrase) - + return ValidationResponse( is_valid=assessment.is_valid, is_secure=assessment.is_secure, @@ -259,9 +268,9 @@ async def validate_ssh_key( error_message=assessment.error_message, warnings=assessment.warnings, recommendations=assessment.recommendations, - compliance_notes=assessment.compliance_notes + compliance_notes=assessment.compliance_notes, ) - + except Exception as e: logger.error(f"SSH key validation error: {e}") raise HTTPException(status_code=500, detail="SSH key validation failed") @@ -277,7 +286,7 @@ async def audit_credential( target_id: Optional[str] = None, target_type: Optional[str] = None, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """ Perform comprehensive security audit of credentials. @@ -286,26 +295,22 @@ async def audit_credential( # Get effective configuration config_manager = get_security_config_manager(db) effective_config = config_manager.get_effective_config(target_id, target_type) - + # Create validator with effective configuration validator = get_credential_validator( - policy_level=effective_config.policy_level, - enforce_fips=effective_config.enforce_fips + policy_level=effective_config.policy_level, enforce_fips=effective_config.enforce_fips ) - + # Perform audit audit_result = validator.audit_credential_security( - username=username, - auth_method=auth_method, - private_key=private_key, - password=password + username=username, auth_method=auth_method, private_key=private_key, password=password ) - + # Log audit activity logger.info(f"Credential audit performed by {current_user.get('username')} for {username}") - + return audit_result - + except Exception as e: logger.error(f"Credential audit error: {e}") raise HTTPException(status_code=500, detail="Credential audit failed") @@ -314,28 +319,27 @@ async def audit_credential( @router.get("/compliance/summary") @require_permission(Permission.AUDIT_READ) async def get_compliance_summary( - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """ Get system-wide compliance summary. """ try: config_manager = get_security_config_manager(db) - + # Get system configuration system_summary = config_manager.get_config_summary() - + # TODO: Add credential compliance statistics # This would scan all stored credentials for compliance - + return { "system_config": system_summary, "compliance_level": system_summary.get("compliance_level", "unknown"), "last_updated": datetime.utcnow().isoformat(), - "assessed_by": current_user.get('username') + "assessed_by": current_user.get("username"), } - + except Exception as e: logger.error(f"Compliance summary error: {e}") - raise HTTPException(status_code=500, detail="Failed to generate compliance summary") \ No newline at end of file + raise HTTPException(status_code=500, detail="Failed to generate compliance summary") diff --git a/backend/app/routes/system_settings.py b/backend/app/routes/system_settings.py index 954d8835..a9f70f25 100644 --- a/backend/app/routes/system_settings.py +++ b/backend/app/routes/system_settings.py @@ -2,6 +2,7 @@ System Settings API Routes Handles system-wide configuration including SSH credentials """ + from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.orm import Session from sqlalchemy import text @@ -64,41 +65,46 @@ class SystemCredentialsResponse(SystemCredentialsBase): @router.get("/credentials", response_model=List[SystemCredentialsResponse]) @require_permission(Permission.SYSTEM_CREDENTIALS) async def list_system_credentials( - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """List all system credentials (admin only)""" try: - - result = db.execute(text(""" + + result = db.execute( + text( + """ SELECT id, name, description, username, auth_method, is_default, is_active, created_at, updated_at, ssh_key_fingerprint, ssh_key_type, ssh_key_bits, ssh_key_comment FROM system_credentials WHERE is_active = true ORDER BY is_default DESC, name ASC - """)) - + """ + ) + ) + credentials = [] for row in result: - credentials.append(SystemCredentialsResponse( - id=row.id, - name=row.name, - description=row.description, - username=row.username, - auth_method=row.auth_method, - is_default=row.is_default, - is_active=row.is_active, - created_at=row.created_at.isoformat(), - updated_at=row.updated_at.isoformat(), - ssh_key_fingerprint=row.ssh_key_fingerprint, - ssh_key_type=row.ssh_key_type, - ssh_key_bits=row.ssh_key_bits, - ssh_key_comment=row.ssh_key_comment - )) - + credentials.append( + SystemCredentialsResponse( + id=row.id, + name=row.name, + description=row.description, + username=row.username, + auth_method=row.auth_method, + is_default=row.is_default, + is_active=row.is_active, + created_at=row.created_at.isoformat(), + updated_at=row.updated_at.isoformat(), + ssh_key_fingerprint=row.ssh_key_fingerprint, + ssh_key_type=row.ssh_key_type, + ssh_key_bits=row.ssh_key_bits, + ssh_key_comment=row.ssh_key_comment, + ) + ) + return credentials - + except Exception as e: logger.error(f"Error listing system credentials: {e}") raise HTTPException(status_code=500, detail="Failed to retrieve system credentials") @@ -109,69 +115,85 @@ async def list_system_credentials( async def create_system_credentials( credentials: SystemCredentialsCreate, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """Create new system credentials (admin only)""" try: - + # If setting as default, unset other defaults if credentials.is_default: - db.execute(text(""" + db.execute( + text( + """ UPDATE system_credentials SET is_default = false WHERE is_default = true - """)) - + """ + ) + ) + # Validate SSH key if provided if credentials.private_key and credentials.auth_method in ["ssh_key", "both"]: logger.info(f"Validating SSH key for system credentials '{credentials.name}'") validation_result = validate_ssh_key(credentials.private_key) - + if not validation_result.is_valid: - logger.error(f"SSH key validation failed for system credentials '{credentials.name}': {validation_result.error_message}") + logger.error( + f"SSH key validation failed for system credentials '{credentials.name}': {validation_result.error_message}" + ) raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Invalid SSH key: {validation_result.error_message}" + detail=f"Invalid SSH key: {validation_result.error_message}", ) - + # Log warnings if any if validation_result.warnings: - logger.warning(f"SSH key warnings for system credentials '{credentials.name}': {'; '.join(validation_result.warnings)}") - + logger.warning( + f"SSH key warnings for system credentials '{credentials.name}': {'; '.join(validation_result.warnings)}" + ) + # Log recommendations if validation_result.recommendations: - logger.info(f"SSH key recommendations for system credentials '{credentials.name}': {'; '.join(validation_result.recommendations)}") - + logger.info( + f"SSH key recommendations for system credentials '{credentials.name}': {'; '.join(validation_result.recommendations)}" + ) + # Extract SSH key metadata for storage and display ssh_key_fingerprint = None ssh_key_type = None ssh_key_bits = None ssh_key_comment = None - + if credentials.private_key: - metadata = extract_ssh_key_metadata(credentials.private_key, credentials.private_key_passphrase) - ssh_key_fingerprint = metadata.get('fingerprint') - ssh_key_type = metadata.get('key_type') - ssh_key_bits = int(metadata.get('key_bits')) if metadata.get('key_bits') else None - ssh_key_comment = metadata.get('key_comment') - - if metadata.get('error'): - logger.warning(f"Failed to extract SSH key metadata for '{credentials.name}': {metadata.get('error')}") - + metadata = extract_ssh_key_metadata( + credentials.private_key, credentials.private_key_passphrase + ) + ssh_key_fingerprint = metadata.get("fingerprint") + ssh_key_type = metadata.get("key_type") + ssh_key_bits = int(metadata.get("key_bits")) if metadata.get("key_bits") else None + ssh_key_comment = metadata.get("key_comment") + + if metadata.get("error"): + logger.warning( + f"Failed to extract SSH key metadata for '{credentials.name}': {metadata.get('error')}" + ) + # Encrypt sensitive data encrypted_password = None encrypted_private_key = None encrypted_passphrase = None - + if credentials.password: encrypted_password = encrypt_data(credentials.password.encode()) if credentials.private_key: encrypted_private_key = encrypt_data(credentials.private_key.encode()) if credentials.private_key_passphrase: encrypted_passphrase = encrypt_data(credentials.private_key_passphrase.encode()) - + current_time = datetime.utcnow() - + # Insert credentials - result = db.execute(text(""" + result = db.execute( + text( + """ INSERT INTO system_credentials (name, description, username, auth_method, encrypted_password, encrypted_private_key, private_key_passphrase, ssh_key_fingerprint, @@ -182,30 +204,33 @@ async def create_system_credentials( :ssh_key_type, :ssh_key_bits, :ssh_key_comment, :is_default, :is_active, :created_by, :created_at, :updated_at) RETURNING id - """), { - "name": credentials.name, - "description": credentials.description, - "username": credentials.username, - "auth_method": credentials.auth_method, - "encrypted_password": encrypted_password, - "encrypted_private_key": encrypted_private_key, - "private_key_passphrase": encrypted_passphrase, - "ssh_key_fingerprint": ssh_key_fingerprint, - "ssh_key_type": ssh_key_type, - "ssh_key_bits": ssh_key_bits, - "ssh_key_comment": ssh_key_comment, - "is_default": credentials.is_default, - "is_active": True, - "created_by": current_user.get('id'), - "created_at": current_time, - "updated_at": current_time - }) - + """ + ), + { + "name": credentials.name, + "description": credentials.description, + "username": credentials.username, + "auth_method": credentials.auth_method, + "encrypted_password": encrypted_password, + "encrypted_private_key": encrypted_private_key, + "private_key_passphrase": encrypted_passphrase, + "ssh_key_fingerprint": ssh_key_fingerprint, + "ssh_key_type": ssh_key_type, + "ssh_key_bits": ssh_key_bits, + "ssh_key_comment": ssh_key_comment, + "is_default": credentials.is_default, + "is_active": True, + "created_by": current_user.get("id"), + "created_at": current_time, + "updated_at": current_time, + }, + ) + credential_id = result.fetchone().id db.commit() - + logger.info(f"Created system credentials '{credentials.name}' (ID: {credential_id})") - + return SystemCredentialsResponse( id=credential_id, name=credentials.name, @@ -219,9 +244,9 @@ async def create_system_credentials( ssh_key_fingerprint=ssh_key_fingerprint, ssh_key_type=ssh_key_type, ssh_key_bits=ssh_key_bits, - ssh_key_comment=ssh_key_comment + ssh_key_comment=ssh_key_comment, ) - + except HTTPException: raise except Exception as e: @@ -232,44 +257,47 @@ async def create_system_credentials( @router.get("/credentials/default") async def get_default_credentials( - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """Get default system credentials for internal use""" try: - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT id, name, username, auth_method, encrypted_password, encrypted_private_key, private_key_passphrase FROM system_credentials WHERE is_default = true AND is_active = true LIMIT 1 - """)) - + """ + ) + ) + row = result.fetchone() if not row: return None - + # Decrypt credentials for internal use password = None private_key = None passphrase = None - + if row.encrypted_password: encrypted_pw = row.encrypted_password if isinstance(encrypted_pw, memoryview): - encrypted_pw = encrypted_pw.tobytes().decode('utf-8') + encrypted_pw = encrypted_pw.tobytes().decode("utf-8") password = decrypt_data(encrypted_pw).decode() if row.encrypted_private_key: encrypted_key = row.encrypted_private_key if isinstance(encrypted_key, memoryview): - encrypted_key = encrypted_key.tobytes().decode('utf-8') + encrypted_key = encrypted_key.tobytes().decode("utf-8") private_key = decrypt_data(encrypted_key).decode() if row.private_key_passphrase: encrypted_phrase = row.private_key_passphrase if isinstance(encrypted_phrase, memoryview): - encrypted_phrase = encrypted_phrase.tobytes().decode('utf-8') + encrypted_phrase = encrypted_phrase.tobytes().decode("utf-8") passphrase = decrypt_data(encrypted_phrase).decode() - + return { "id": row.id, "name": row.name, @@ -277,9 +305,9 @@ async def get_default_credentials( "auth_method": row.auth_method, "password": password, "private_key": private_key, - "private_key_passphrase": passphrase + "private_key_passphrase": passphrase, } - + except Exception as e: logger.error(f"Error getting default credentials: {e}") return None @@ -291,65 +319,84 @@ async def update_system_credentials( credential_id: int, credentials: SystemCredentialsUpdate, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """Update system credentials (admin only)""" try: - + # Check if credentials exist - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT id FROM system_credentials WHERE id = :id - """), {"id": credential_id}) - + """ + ), + {"id": credential_id}, + ) + if not result.fetchone(): raise HTTPException(status_code=404, detail="Credentials not found") - + # If setting as default, unset other defaults if credentials.is_default: - db.execute(text(""" + db.execute( + text( + """ UPDATE system_credentials SET is_default = false WHERE is_default = true - """)) - + """ + ) + ) + # Validate SSH key if provided if credentials.private_key and credentials.auth_method in ["ssh_key", "both"]: logger.info(f"Validating SSH key for system credentials update (ID: {credential_id})") validation_result = validate_ssh_key(credentials.private_key) - + if not validation_result.is_valid: - logger.error(f"SSH key validation failed for system credentials update (ID: {credential_id}): {validation_result.error_message}") + logger.error( + f"SSH key validation failed for system credentials update (ID: {credential_id}): {validation_result.error_message}" + ) raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Invalid SSH key: {validation_result.error_message}" + detail=f"Invalid SSH key: {validation_result.error_message}", ) - + # Log warnings if any if validation_result.warnings: - logger.warning(f"SSH key warnings for system credentials update (ID: {credential_id}): {'; '.join(validation_result.warnings)}") - + logger.warning( + f"SSH key warnings for system credentials update (ID: {credential_id}): {'; '.join(validation_result.warnings)}" + ) + # Log recommendations if validation_result.recommendations: - logger.info(f"SSH key recommendations for system credentials update (ID: {credential_id}): {'; '.join(validation_result.recommendations)}") - + logger.info( + f"SSH key recommendations for system credentials update (ID: {credential_id}): {'; '.join(validation_result.recommendations)}" + ) + # Extract SSH key metadata if private key is being updated ssh_key_fingerprint = None ssh_key_type = None ssh_key_bits = None ssh_key_comment = None - + if credentials.private_key: - metadata = extract_ssh_key_metadata(credentials.private_key, credentials.private_key_passphrase) - ssh_key_fingerprint = metadata.get('fingerprint') - ssh_key_type = metadata.get('key_type') - ssh_key_bits = int(metadata.get('key_bits')) if metadata.get('key_bits') else None - ssh_key_comment = metadata.get('key_comment') - - if metadata.get('error'): - logger.warning(f"Failed to extract SSH key metadata for update (ID: {credential_id}): {metadata.get('error')}") - + metadata = extract_ssh_key_metadata( + credentials.private_key, credentials.private_key_passphrase + ) + ssh_key_fingerprint = metadata.get("fingerprint") + ssh_key_type = metadata.get("key_type") + ssh_key_bits = int(metadata.get("key_bits")) if metadata.get("key_bits") else None + ssh_key_comment = metadata.get("key_comment") + + if metadata.get("error"): + logger.warning( + f"Failed to extract SSH key metadata for update (ID: {credential_id}): {metadata.get('error')}" + ) + # Build update query dynamically updates = [] params = {"id": credential_id, "updated_at": datetime.utcnow()} - + if credentials.name is not None: updates.append("name = :name") params["name"] = credentials.name @@ -368,14 +415,18 @@ async def update_system_credentials( if credentials.is_active is not None: updates.append("is_active = :is_active") params["is_active"] = credentials.is_active - + # Handle encrypted fields if credentials.password is not None: updates.append("encrypted_password = :encrypted_password") - params["encrypted_password"] = encrypt_data(credentials.password.encode()) if credentials.password else None + params["encrypted_password"] = ( + encrypt_data(credentials.password.encode()) if credentials.password else None + ) if credentials.private_key is not None: updates.append("encrypted_private_key = :encrypted_private_key") - params["encrypted_private_key"] = encrypt_data(credentials.private_key.encode()) if credentials.private_key else None + params["encrypted_private_key"] = ( + encrypt_data(credentials.private_key.encode()) if credentials.private_key else None + ) # Update SSH key metadata when private key changes updates.append("ssh_key_fingerprint = :ssh_key_fingerprint") params["ssh_key_fingerprint"] = ssh_key_fingerprint if credentials.private_key else None @@ -387,23 +438,32 @@ async def update_system_credentials( params["ssh_key_comment"] = ssh_key_comment if credentials.private_key else None if credentials.private_key_passphrase is not None: updates.append("private_key_passphrase = :private_key_passphrase") - params["private_key_passphrase"] = encrypt_data(credentials.private_key_passphrase.encode()) if credentials.private_key_passphrase else None - + params["private_key_passphrase"] = ( + encrypt_data(credentials.private_key_passphrase.encode()) + if credentials.private_key_passphrase + else None + ) + if updates: updates.append("updated_at = :updated_at") # Security Fix: Use safe string concatenation instead of f-string query = "UPDATE system_credentials SET " + ", ".join(updates) + " WHERE id = :id" db.execute(text(query), params) db.commit() - + # Return updated credentials - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT id, name, description, username, auth_method, is_default, is_active, created_at, updated_at, ssh_key_fingerprint, ssh_key_type, ssh_key_bits, ssh_key_comment FROM system_credentials WHERE id = :id - """), {"id": credential_id}) - + """ + ), + {"id": credential_id}, + ) + row = result.fetchone() return SystemCredentialsResponse( id=row.id, @@ -418,9 +478,9 @@ async def update_system_credentials( ssh_key_fingerprint=row.ssh_key_fingerprint, ssh_key_type=row.ssh_key_type, ssh_key_bits=row.ssh_key_bits, - ssh_key_comment=row.ssh_key_comment + ssh_key_comment=row.ssh_key_comment, ) - + except HTTPException: raise except Exception as e: @@ -434,39 +494,55 @@ async def update_system_credentials( async def delete_system_credentials( credential_id: int, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """Delete system credentials (admin only)""" try: - + # Check if credentials exist - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT id, is_default FROM system_credentials WHERE id = :id - """), {"id": credential_id}) - + """ + ), + {"id": credential_id}, + ) + row = result.fetchone() if not row: raise HTTPException(status_code=404, detail="Credentials not found") - + # Prevent deletion of default credentials if it's the only one if row.is_default: - count_result = db.execute(text(""" + count_result = db.execute( + text( + """ SELECT COUNT(*) as count FROM system_credentials WHERE is_active = true - """)) + """ + ) + ) if count_result.fetchone().count <= 1: - raise HTTPException(status_code=400, detail="Cannot delete the last active credential set") - + raise HTTPException( + status_code=400, detail="Cannot delete the last active credential set" + ) + # Soft delete (mark as inactive) - db.execute(text(""" + db.execute( + text( + """ UPDATE system_credentials SET is_active = false, updated_at = :updated_at WHERE id = :id - """), {"id": credential_id, "updated_at": datetime.utcnow()}) - + """ + ), + {"id": credential_id, "updated_at": datetime.utcnow()}, + ) + db.commit() - + logger.info(f"Deleted system credentials ID: {credential_id}") return {"message": "Credentials deleted successfully"} - + except HTTPException: raise except Exception as e: @@ -480,35 +556,42 @@ async def delete_system_credentials( async def delete_ssh_key_from_credentials( credential_id: int, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """Delete SSH key from system credentials (admin only)""" try: - + # Check if credentials exist and have SSH key - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT id, auth_method, ssh_key_fingerprint FROM system_credentials WHERE id = :id AND is_active = true - """), {"id": credential_id}) - + """ + ), + {"id": credential_id}, + ) + row = result.fetchone() if not row: raise HTTPException(status_code=404, detail="Credentials not found") - + if not row.ssh_key_fingerprint: raise HTTPException(status_code=400, detail="No SSH key found to delete") - + # Check if this would leave the credential with no authentication method if row.auth_method == "ssh_key": raise HTTPException( - status_code=400, - detail="Cannot delete SSH key - it's the only authentication method. Update to use password authentication first." + status_code=400, + detail="Cannot delete SSH key - it's the only authentication method. Update to use password authentication first.", ) - + # Remove SSH key and update auth method if necessary new_auth_method = "password" if row.auth_method == "both" else row.auth_method - - db.execute(text(""" + + db.execute( + text( + """ UPDATE system_credentials SET encrypted_private_key = NULL, private_key_passphrase = NULL, @@ -519,17 +602,16 @@ async def delete_ssh_key_from_credentials( auth_method = :auth_method, updated_at = :updated_at WHERE id = :id - """), { - "id": credential_id, - "auth_method": new_auth_method, - "updated_at": datetime.utcnow() - }) - + """ + ), + {"id": credential_id, "auth_method": new_auth_method, "updated_at": datetime.utcnow()}, + ) + db.commit() - + logger.info(f"Deleted SSH key from system credentials ID: {credential_id}") return {"message": "SSH key deleted successfully"} - + except HTTPException: raise except Exception as e: @@ -542,64 +624,76 @@ async def delete_ssh_key_from_credentials( scheduler_instance = None -async def restore_scheduler_state(): +def restore_scheduler_state(): """ Restore scheduler state from database on application startup """ global scheduler_instance - + try: from sqlalchemy.orm import sessionmaker from sqlalchemy import text from ..database import engine - + # Create a session to check scheduler state SessionLocal = sessionmaker(bind=engine) db = SessionLocal() - + try: # Check if scheduler was previously enabled - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT setting_value FROM system_settings WHERE setting_key = 'scheduler_enabled' AND setting_value = 'true' - """)) - + """ + ) + ) + was_enabled = result.fetchone() is not None - + if was_enabled: # Get the interval setting - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT setting_value FROM system_settings WHERE setting_key = 'scheduler_interval_minutes' - """)) + """ + ) + ) row = result.fetchone() interval_minutes = int(row.setting_value) if row else 5 - + # Start the scheduler from apscheduler.schedulers.background import BackgroundScheduler from ..tasks.monitoring_tasks import periodic_host_monitoring import atexit - + scheduler_instance = BackgroundScheduler() scheduler_instance.add_job( func=periodic_host_monitoring, trigger="interval", minutes=interval_minutes, - id='host_monitoring', - name='Monitor host availability' + id="host_monitoring", + name="Monitor host availability", ) scheduler_instance.start() - + # Shut down the scheduler when exiting the app - atexit.register(lambda: scheduler_instance.shutdown() if scheduler_instance else None) - - logger.info(f"Host monitoring scheduler restored (every {interval_minutes} minutes)") + atexit.register( + lambda: scheduler_instance.shutdown() if scheduler_instance else None + ) + + logger.info( + f"Host monitoring scheduler restored (every {interval_minutes} minutes)" + ) else: logger.info("Scheduler was not previously enabled, staying stopped") - + finally: db.close() - + except Exception as e: logger.warning(f"Failed to restore scheduler state: {e}") # Don't raise - this is optional initialization @@ -651,48 +745,55 @@ class AlertSettingsResponse(AlertSettingsBase): @router.get("/scheduler", response_model=SchedulerSettings) async def get_scheduler_settings( - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """Get current scheduler settings""" try: global scheduler_instance - + # Check if scheduler is running is_running = scheduler_instance is not None and scheduler_instance.running - + # Get enabled state from database (fallback to runtime state) enabled_from_db = False try: - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT setting_value FROM system_settings WHERE setting_key = 'scheduler_enabled' - """)) + """ + ) + ) row = result.fetchone() - enabled_from_db = row and row.setting_value == 'true' + enabled_from_db = row and row.setting_value == "true" except: pass - + # Use database state if available, otherwise fall back to runtime state is_enabled = enabled_from_db if enabled_from_db is not None else is_running - + # Get settings from database or use defaults try: - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT setting_value FROM system_settings WHERE setting_key = 'scheduler_interval_minutes' - """)) + """ + ) + ) row = result.fetchone() interval_minutes = int(row.setting_value) if row else 5 except: interval_minutes = 5 - + return SchedulerSettings( enabled=is_enabled, interval_minutes=interval_minutes, - status="running" if is_running else "stopped" + status="running" if is_running else "stopped", ) - + except Exception as e: logger.error(f"Error getting scheduler settings: {e}") raise HTTPException(status_code=500, detail="Failed to get scheduler settings") @@ -703,21 +804,23 @@ async def get_scheduler_settings( async def start_scheduler( request: SchedulerStartRequest, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """Start the host monitoring scheduler""" try: global scheduler_instance - + # Stop existing scheduler if running if scheduler_instance and scheduler_instance.running: scheduler_instance.shutdown() scheduler_instance = None - + # Save interval setting to database try: # Create system_settings table if it doesn't exist - db.execute(text(""" + db.execute( + text( + """ CREATE TABLE IF NOT EXISTS system_settings ( id SERIAL PRIMARY KEY, setting_key VARCHAR(255) UNIQUE NOT NULL, @@ -725,53 +828,62 @@ async def start_scheduler( created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ) - """)) - + """ + ) + ) + # Insert or update the interval setting - db.execute(text(""" + db.execute( + text( + """ INSERT INTO system_settings (setting_key, setting_value, updated_at) VALUES ('scheduler_interval_minutes', :interval, :updated_at) ON CONFLICT (setting_key) DO UPDATE SET setting_value = :interval, updated_at = :updated_at - """), { - "interval": str(request.interval_minutes), - "updated_at": datetime.utcnow() - }) - + """ + ), + {"interval": str(request.interval_minutes), "updated_at": datetime.utcnow()}, + ) + # Save enabled state - db.execute(text(""" + db.execute( + text( + """ INSERT INTO system_settings (setting_key, setting_value, updated_at) VALUES ('scheduler_enabled', 'true', :updated_at) ON CONFLICT (setting_key) DO UPDATE SET setting_value = 'true', updated_at = :updated_at - """), { - "updated_at": datetime.utcnow() - }) + """ + ), + {"updated_at": datetime.utcnow()}, + ) db.commit() except Exception as db_error: logger.warning(f"Failed to save scheduler settings to database: {db_error}") - + # Start new scheduler with custom interval from apscheduler.schedulers.background import BackgroundScheduler from ..tasks.monitoring_tasks import periodic_host_monitoring import atexit - + scheduler_instance = BackgroundScheduler() scheduler_instance.add_job( func=periodic_host_monitoring, trigger="interval", minutes=request.interval_minutes, - id='host_monitoring', - name='Monitor host availability' + id="host_monitoring", + name="Monitor host availability", ) scheduler_instance.start() - + # Shut down the scheduler when exiting the app atexit.register(lambda: scheduler_instance.shutdown() if scheduler_instance else None) - + logger.info(f"Host monitoring scheduler started (every {request.interval_minutes} minutes)") - return {"message": f"Scheduler started successfully (every {request.interval_minutes} minutes)"} - + return { + "message": f"Scheduler started successfully (every {request.interval_minutes} minutes)" + } + except Exception as e: logger.error(f"Error starting scheduler: {e}") raise HTTPException(status_code=500, detail="Failed to start scheduler") @@ -780,27 +892,29 @@ async def start_scheduler( @router.post("/scheduler/stop") @require_permission(Permission.SYSTEM_CONFIG) async def stop_scheduler( - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """Stop the host monitoring scheduler""" try: global scheduler_instance - + # Save disabled state to database try: - db.execute(text(""" + db.execute( + text( + """ INSERT INTO system_settings (setting_key, setting_value, updated_at) VALUES ('scheduler_enabled', 'false', :updated_at) ON CONFLICT (setting_key) DO UPDATE SET setting_value = 'false', updated_at = :updated_at - """), { - "updated_at": datetime.utcnow() - }) + """ + ), + {"updated_at": datetime.utcnow()}, + ) db.commit() except Exception as db_error: logger.warning(f"Failed to save scheduler disabled state to database: {db_error}") - + if scheduler_instance and scheduler_instance.running: scheduler_instance.shutdown() scheduler_instance = None @@ -808,7 +922,7 @@ async def stop_scheduler( return {"message": "Scheduler stopped successfully"} else: return {"message": "Scheduler was not running"} - + except Exception as e: logger.error(f"Error stopping scheduler: {e}") raise HTTPException(status_code=500, detail="Failed to stop scheduler") @@ -819,59 +933,63 @@ async def stop_scheduler( async def update_scheduler_settings( request: SchedulerUpdateRequest, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """Update scheduler settings""" try: global scheduler_instance - + # Save new interval to database try: - db.execute(text(""" + db.execute( + text( + """ INSERT INTO system_settings (setting_key, setting_value, updated_at) VALUES ('scheduler_interval_minutes', :interval, :updated_at) ON CONFLICT (setting_key) DO UPDATE SET setting_value = :interval, updated_at = :updated_at - """), { - "interval": str(request.interval_minutes), - "updated_at": datetime.utcnow() - }) + """ + ), + {"interval": str(request.interval_minutes), "updated_at": datetime.utcnow()}, + ) db.commit() except Exception as db_error: logger.warning(f"Failed to save scheduler settings to database: {db_error}") - + # If scheduler is running, restart it with new interval was_running = scheduler_instance is not None and scheduler_instance.running - + if was_running: # Stop current scheduler scheduler_instance.shutdown() scheduler_instance = None - + # Start with new interval from apscheduler.schedulers.background import BackgroundScheduler from ..tasks.monitoring_tasks import periodic_host_monitoring import atexit - + scheduler_instance = BackgroundScheduler() scheduler_instance.add_job( func=periodic_host_monitoring, trigger="interval", minutes=request.interval_minutes, - id='host_monitoring', - name='Monitor host availability' + id="host_monitoring", + name="Monitor host availability", ) scheduler_instance.start() - + atexit.register(lambda: scheduler_instance.shutdown() if scheduler_instance else None) - logger.info(f"Scheduler restarted with new interval: {request.interval_minutes} minutes") - + logger.info( + f"Scheduler restarted with new interval: {request.interval_minutes} minutes" + ) + return SchedulerSettings( enabled=was_running, interval_minutes=request.interval_minutes, - status="running" if was_running else "stopped" + status="running" if was_running else "stopped", ) - + except Exception as e: logger.error(f"Error updating scheduler settings: {e}") raise HTTPException(status_code=500, detail="Failed to update scheduler settings") @@ -881,37 +999,43 @@ async def update_scheduler_settings( @router.get("/alerts", response_model=List[AlertSettingsResponse]) @require_permission(Permission.SYSTEM_CONFIG) async def list_alert_settings( - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """List all alert settings for the current user""" try: - user_id = current_user.get('id') - result = db.execute(text(""" + user_id = current_user.get("id") + result = db.execute( + text( + """ SELECT id, user_id, alert_type, enabled, email_enabled, email_addresses, webhook_url, webhook_enabled, created_at, updated_at FROM alert_settings WHERE user_id = :user_id ORDER BY alert_type - """), {"user_id": user_id}) - + """ + ), + {"user_id": user_id}, + ) + settings = [] for row in result: - settings.append(AlertSettingsResponse( - id=row.id, - user_id=row.user_id, - alert_type=row.alert_type, - enabled=row.enabled, - email_enabled=row.email_enabled, - email_addresses=row.email_addresses or [], - webhook_url=row.webhook_url, - webhook_enabled=row.webhook_enabled, - created_at=row.created_at.isoformat(), - updated_at=row.updated_at.isoformat() - )) - + settings.append( + AlertSettingsResponse( + id=row.id, + user_id=row.user_id, + alert_type=row.alert_type, + enabled=row.enabled, + email_enabled=row.email_enabled, + email_addresses=row.email_addresses or [], + webhook_url=row.webhook_url, + webhook_enabled=row.webhook_enabled, + created_at=row.created_at.isoformat(), + updated_at=row.updated_at.isoformat(), + ) + ) + return settings - + except Exception as e: logger.error(f"Error listing alert settings: {e}") raise HTTPException(status_code=500, detail="Failed to retrieve alert settings") @@ -922,15 +1046,17 @@ async def list_alert_settings( async def create_alert_settings( alert_settings: AlertSettingsCreate, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """Create new alert settings""" try: - user_id = current_user.get('id') + user_id = current_user.get("id") current_time = datetime.utcnow() - + # Insert or update alert settings - result = db.execute(text(""" + result = db.execute( + text( + """ INSERT INTO alert_settings (user_id, alert_type, enabled, email_enabled, email_addresses, webhook_url, webhook_enabled, created_at, updated_at) @@ -945,23 +1071,28 @@ async def create_alert_settings( webhook_enabled = :webhook_enabled, updated_at = :updated_at RETURNING id - """), { - "user_id": user_id, - "alert_type": alert_settings.alert_type, - "enabled": alert_settings.enabled, - "email_enabled": alert_settings.email_enabled, - "email_addresses": alert_settings.email_addresses, - "webhook_url": alert_settings.webhook_url, - "webhook_enabled": alert_settings.webhook_enabled, - "created_at": current_time, - "updated_at": current_time - }) - + """ + ), + { + "user_id": user_id, + "alert_type": alert_settings.alert_type, + "enabled": alert_settings.enabled, + "email_enabled": alert_settings.email_enabled, + "email_addresses": alert_settings.email_addresses, + "webhook_url": alert_settings.webhook_url, + "webhook_enabled": alert_settings.webhook_enabled, + "created_at": current_time, + "updated_at": current_time, + }, + ) + setting_id = result.fetchone().id db.commit() - - logger.info(f"Created/updated alert settings for {alert_settings.alert_type} (ID: {setting_id})") - + + logger.info( + f"Created/updated alert settings for {alert_settings.alert_type} (ID: {setting_id})" + ) + return AlertSettingsResponse( id=setting_id, user_id=user_id, @@ -972,9 +1103,9 @@ async def create_alert_settings( webhook_url=alert_settings.webhook_url, webhook_enabled=alert_settings.webhook_enabled, created_at=current_time.isoformat(), - updated_at=current_time.isoformat() + updated_at=current_time.isoformat(), ) - + except Exception as e: logger.error(f"Error creating alert settings: {e}") db.rollback() @@ -987,25 +1118,30 @@ async def update_alert_settings( alert_id: int, alert_settings: AlertSettingsUpdate, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """Update existing alert settings""" try: - user_id = current_user.get('id') - + user_id = current_user.get("id") + # Check if alert settings exist and belong to user - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT id FROM alert_settings WHERE id = :id AND user_id = :user_id - """), {"id": alert_id, "user_id": user_id}) - + """ + ), + {"id": alert_id, "user_id": user_id}, + ) + if not result.fetchone(): raise HTTPException(status_code=404, detail="Alert settings not found") - + # Build update query dynamically updates = [] params = {"id": alert_id, "updated_at": datetime.utcnow()} - + if alert_settings.enabled is not None: updates.append("enabled = :enabled") params["enabled"] = alert_settings.enabled @@ -1021,21 +1157,26 @@ async def update_alert_settings( if alert_settings.webhook_enabled is not None: updates.append("webhook_enabled = :webhook_enabled") params["webhook_enabled"] = alert_settings.webhook_enabled - + if updates: updates.append("updated_at = :updated_at") # Security Fix: Use safe string concatenation instead of f-string query = "UPDATE alert_settings SET " + ", ".join(updates) + " WHERE id = :id" db.execute(text(query), params) db.commit() - + # Return updated settings - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT id, user_id, alert_type, enabled, email_enabled, email_addresses, webhook_url, webhook_enabled, created_at, updated_at FROM alert_settings WHERE id = :id - """), {"id": alert_id}) - + """ + ), + {"id": alert_id}, + ) + row = result.fetchone() return AlertSettingsResponse( id=row.id, @@ -1047,12 +1188,12 @@ async def update_alert_settings( webhook_url=row.webhook_url, webhook_enabled=row.webhook_enabled, created_at=row.created_at.isoformat(), - updated_at=row.updated_at.isoformat() + updated_at=row.updated_at.isoformat(), ) - + except HTTPException: raise except Exception as e: logger.error(f"Error updating alert settings: {e}") db.rollback() - raise HTTPException(status_code=500, detail="Failed to update alert settings") \ No newline at end of file + raise HTTPException(status_code=500, detail="Failed to update alert settings") diff --git a/backend/app/routes/system_settings_unified.py b/backend/app/routes/system_settings_unified.py index d38d0933..2a7317b5 100644 --- a/backend/app/routes/system_settings_unified.py +++ b/backend/app/routes/system_settings_unified.py @@ -2,6 +2,7 @@ System Settings API Routes - Unified Credentials Version Updated to use the unified credentials system while maintaining API compatibility """ + from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.orm import Session from sqlalchemy import text @@ -15,7 +16,13 @@ from ..database import get_db from ..auth import get_current_user from ..rbac import require_permission, Permission -from ..services.auth_service import get_auth_service, CredentialData, CredentialMetadata, CredentialScope, AuthMethod +from ..services.auth_service import ( + get_auth_service, + CredentialData, + CredentialMetadata, + CredentialScope, + AuthMethod, +) from ..services.ssh_utils import validate_ssh_key, format_validation_message from ..services.ssh_key_service import extract_ssh_key_metadata from ..tasks.monitoring_tasks import setup_host_monitoring_scheduler @@ -67,20 +74,24 @@ class SystemCredentialsResponse(SystemCredentialsBase): def uuid_to_int(uuid_str) -> int: """Convert UUID to deterministic integer for frontend compatibility""" # Convert UUID object to string if needed - if hasattr(uuid_str, '__str__'): + if hasattr(uuid_str, "__str__"): uuid_str = str(uuid_str) # Use first 8 bytes of SHA256 hash as integer hash_bytes = hashlib.sha256(uuid_str.encode()).digest()[:8] - return int.from_bytes(hash_bytes, byteorder='big', signed=False) % (2**31) # Keep positive + return int.from_bytes(hash_bytes, byteorder="big", signed=False) % (2**31) # Keep positive def find_uuid_by_int(db: Session, target_int: int) -> Optional[str]: """Find UUID by matching the generated integer ID""" - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT id FROM unified_credentials WHERE scope = 'system' AND is_active = true - """)) - + """ + ) + ) + for row in result: if uuid_to_int(row.id) == target_int: return row.id @@ -90,45 +101,46 @@ def find_uuid_by_int(db: Session, target_int: int) -> Optional[str]: @router.get("/credentials", response_model=List[SystemCredentialsResponse]) @require_permission(Permission.SYSTEM_CREDENTIALS) async def list_system_credentials( - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """List all system credentials using unified credentials system""" try: auth_service = get_auth_service(db) - + # Get all system-scoped credentials from unified table credentials_list = auth_service.list_credentials(scope=CredentialScope.SYSTEM) - + response_list = [] for cred in credentials_list: # Convert UUID to integer for frontend compatibility external_id = uuid_to_int(cred["id"]) - - response_list.append(SystemCredentialsResponse( - id=external_id, - name=cred["name"], - description=cred["description"], - username=cred["username"], - auth_method=cred["auth_method"], - is_default=cred["is_default"], - is_active=True, # Only active credentials are returned by list_credentials - created_at=cred["created_at"], - updated_at=cred["updated_at"], - ssh_key_fingerprint=cred["ssh_key_fingerprint"], - ssh_key_type=cred["ssh_key_type"], - ssh_key_bits=cred["ssh_key_bits"], - ssh_key_comment=cred["ssh_key_comment"] - )) - + + response_list.append( + SystemCredentialsResponse( + id=external_id, + name=cred["name"], + description=cred["description"], + username=cred["username"], + auth_method=cred["auth_method"], + is_default=cred["is_default"], + is_active=True, # Only active credentials are returned by list_credentials + created_at=cred["created_at"], + updated_at=cred["updated_at"], + ssh_key_fingerprint=cred["ssh_key_fingerprint"], + ssh_key_type=cred["ssh_key_type"], + ssh_key_bits=cred["ssh_key_bits"], + ssh_key_comment=cred["ssh_key_comment"], + ) + ) + logger.info(f"Retrieved {len(response_list)} unified system credentials") return response_list - + except Exception as e: logger.error(f"Failed to list system credentials: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to retrieve system credentials" + detail="Failed to retrieve system credentials", ) @@ -137,7 +149,7 @@ async def list_system_credentials( async def create_system_credential( credential: SystemCredentialsCreate, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """Create new system credential using unified credentials system""" try: @@ -146,31 +158,31 @@ async def create_system_credential( if credential.auth_method not in valid_methods: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Invalid auth method. Must be one of: {valid_methods}" + detail=f"Invalid auth method. Must be one of: {valid_methods}", ) - + # Validate required fields based on auth method if credential.auth_method in ["password", "both"] and not credential.password: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Password is required for password authentication" + detail="Password is required for password authentication", ) - + if credential.auth_method in ["ssh_key", "both"] and not credential.private_key: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Private key is required for SSH key authentication" + detail="Private key is required for SSH key authentication", ) - + # Validate SSH key if provided if credential.private_key: validation_result = validate_ssh_key(credential.private_key) if not validation_result.is_valid: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Invalid SSH key: {validation_result.error_message}" + detail=f"Invalid SSH key: {validation_result.error_message}", ) - + # Create credential data credential_data = CredentialData( username=credential.username, @@ -178,9 +190,9 @@ async def create_system_credential( password=credential.password, private_key=credential.private_key, private_key_passphrase=credential.private_key_passphrase, - source="system_settings_api" + source="system_settings_api", ) - + # Create metadata metadata = CredentialMetadata( name=credential.name, @@ -188,36 +200,36 @@ async def create_system_credential( scope=CredentialScope.SYSTEM, target_id=None, is_default=credential.is_default, - is_active=True + is_active=True, ) - + # Store using unified credentials service auth_service = get_auth_service(db) # Convert integer user ID to UUID format for unified credentials user_uuid = f"00000000-0000-0000-0000-{current_user['id']:012d}" logger.info(f"Converting user ID {current_user['id']} to UUID format: {user_uuid}") credential_id = auth_service.store_credential( - credential_data=credential_data, - metadata=metadata, - created_by=user_uuid + credential_data=credential_data, metadata=metadata, created_by=user_uuid ) - + # Get the created credential for response logger.info(f"Looking for created credential with ID: {credential_id}") created_cred_list = auth_service.list_credentials(scope=CredentialScope.SYSTEM) logger.info(f"Found {len(created_cred_list)} system credentials") created_cred = next((c for c in created_cred_list if c["id"] == credential_id), None) - + if not created_cred: - logger.error(f"Failed to find credential {credential_id} in list of {len(created_cred_list)} credentials") + logger.error( + f"Failed to find credential {credential_id} in list of {len(created_cred_list)} credentials" + ) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to retrieve created credential" + detail="Failed to retrieve created credential", ) - + # Convert to response format external_id = uuid_to_int(created_cred["id"]) - + response = SystemCredentialsResponse( id=external_id, name=created_cred["name"], @@ -231,19 +243,21 @@ async def create_system_credential( ssh_key_fingerprint=created_cred["ssh_key_fingerprint"], ssh_key_type=created_cred["ssh_key_type"], ssh_key_bits=created_cred["ssh_key_bits"], - ssh_key_comment=created_cred["ssh_key_comment"] + ssh_key_comment=created_cred["ssh_key_comment"], + ) + + logger.info( + f"Created system credential '{credential.name}' with unified ID: {credential_id}" ) - - logger.info(f"Created system credential '{credential.name}' with unified ID: {credential_id}") return response - + except HTTPException: raise except Exception as e: logger.error(f"Failed to create system credential: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to create system credential" + detail="Failed to create system credential", ) @@ -252,7 +266,7 @@ async def create_system_credential( async def get_system_credential( credential_id: int, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """Get specific system credential by ID""" try: @@ -260,21 +274,19 @@ async def get_system_credential( uuid_id = find_uuid_by_int(db, credential_id) if not uuid_id: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Credential not found" + status_code=status.HTTP_404_NOT_FOUND, detail="Credential not found" ) - + # Get credential using unified service auth_service = get_auth_service(db) credentials_list = auth_service.list_credentials(scope=CredentialScope.SYSTEM) - + credential = next((c for c in credentials_list if c["id"] == uuid_id), None) if not credential: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Credential not found" + status_code=status.HTTP_404_NOT_FOUND, detail="Credential not found" ) - + return SystemCredentialsResponse( id=credential_id, # Use original external ID name=credential["name"], @@ -288,37 +300,36 @@ async def get_system_credential( ssh_key_fingerprint=credential["ssh_key_fingerprint"], ssh_key_type=credential["ssh_key_type"], ssh_key_bits=credential["ssh_key_bits"], - ssh_key_comment=credential["ssh_key_comment"] + ssh_key_comment=credential["ssh_key_comment"], ) - + except HTTPException: raise except Exception as e: logger.error(f"Failed to get system credential {credential_id}: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to retrieve system credential" + detail="Failed to retrieve system credential", ) @router.get("/credentials/default", response_model=Optional[SystemCredentialsResponse]) @require_permission(Permission.SYSTEM_CREDENTIALS) async def get_default_system_credential( - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """Get default system credential""" try: auth_service = get_auth_service(db) credentials_list = auth_service.list_credentials(scope=CredentialScope.SYSTEM) - + # Find default credential default_cred = next((c for c in credentials_list if c["is_default"]), None) if not default_cred: return None - + external_id = uuid_to_int(default_cred["id"]) - + return SystemCredentialsResponse( id=external_id, name=default_cred["name"], @@ -332,14 +343,14 @@ async def get_default_system_credential( ssh_key_fingerprint=default_cred["ssh_key_fingerprint"], ssh_key_type=default_cred["ssh_key_type"], ssh_key_bits=default_cred["ssh_key_bits"], - ssh_key_comment=default_cred["ssh_key_comment"] + ssh_key_comment=default_cred["ssh_key_comment"], ) - + except Exception as e: logger.error(f"Failed to get default system credential: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to retrieve default system credential" + detail="Failed to retrieve default system credential", ) @@ -349,7 +360,7 @@ async def update_system_credential( credential_id: int, credential_update: SystemCredentialsUpdate, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """Update system credential (Note: Currently creates new due to unified credentials architecture)""" try: @@ -357,62 +368,64 @@ async def update_system_credential( uuid_id = find_uuid_by_int(db, credential_id) if not uuid_id: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Credential not found" + status_code=status.HTTP_404_NOT_FOUND, detail="Credential not found" ) - + auth_service = get_auth_service(db) credentials_list = auth_service.list_credentials(scope=CredentialScope.SYSTEM) - + # Get existing credential existing_cred = next((c for c in credentials_list if c["id"] == uuid_id), None) if not existing_cred: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Credential not found" + status_code=status.HTTP_404_NOT_FOUND, detail="Credential not found" ) - + # For unified credentials, we need to create a new credential and deactivate old one # This is because the unified system doesn't support in-place updates yet - + # Merge updates with existing values updated_name = credential_update.name or existing_cred["name"] updated_description = credential_update.description or existing_cred["description"] updated_username = credential_update.username or existing_cred["username"] updated_auth_method = credential_update.auth_method or existing_cred["auth_method"] - updated_is_default = credential_update.is_default if credential_update.is_default is not None else existing_cred["is_default"] - + updated_is_default = ( + credential_update.is_default + if credential_update.is_default is not None + else existing_cred["is_default"] + ) + # Validate auth method valid_methods = ["ssh_key", "password", "both"] if updated_auth_method not in valid_methods: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Invalid auth method. Must be one of: {valid_methods}" + detail=f"Invalid auth method. Must be one of: {valid_methods}", ) - + # For updates, we need the credential data (this would normally be stored encrypted) # Since we can't easily decrypt existing credentials, require new credential data for updates if not credential_update.password and updated_auth_method in ["password", "both"]: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Password is required when updating password-based authentication" + detail="Password is required when updating password-based authentication", ) - + if not credential_update.private_key and updated_auth_method in ["ssh_key", "both"]: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Private key is required when updating SSH key authentication" + detail="Private key is required when updating SSH key authentication", ) - + # Validate SSH key if provided if credential_update.private_key: validation_result = validate_ssh_key(credential_update.private_key) if not validation_result.is_valid: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Invalid SSH key: {validation_result.error_message}" + detail=f"Invalid SSH key: {validation_result.error_message}", ) - + # Create new credential data credential_data = CredentialData( username=updated_username, @@ -420,9 +433,9 @@ async def update_system_credential( password=credential_update.password, private_key=credential_update.private_key, private_key_passphrase=credential_update.private_key_passphrase, - source="system_settings_api_update" + source="system_settings_api_update", ) - + # Create new metadata metadata = CredentialMetadata( name=updated_name, @@ -430,31 +443,29 @@ async def update_system_credential( scope=CredentialScope.SYSTEM, target_id=None, is_default=updated_is_default, - is_active=True + is_active=True, ) - + # Delete old credential auth_service.delete_credential(uuid_id) - + # Store new credential # Convert integer user ID to UUID format for unified credentials user_uuid = f"00000000-0000-0000-0000-{current_user['id']:012d}" new_credential_id = auth_service.store_credential( - credential_data=credential_data, - metadata=metadata, - created_by=user_uuid + credential_data=credential_data, metadata=metadata, created_by=user_uuid ) - + # Get the created credential for response updated_cred_list = auth_service.list_credentials(scope=CredentialScope.SYSTEM) updated_cred = next((c for c in updated_cred_list if c["id"] == new_credential_id), None) - + if not updated_cred: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to retrieve updated credential" + detail="Failed to retrieve updated credential", ) - + # Use same external ID for consistency response = SystemCredentialsResponse( id=credential_id, # Keep same external ID @@ -469,19 +480,21 @@ async def update_system_credential( ssh_key_fingerprint=updated_cred["ssh_key_fingerprint"], ssh_key_type=updated_cred["ssh_key_type"], ssh_key_bits=updated_cred["ssh_key_bits"], - ssh_key_comment=updated_cred["ssh_key_comment"] + ssh_key_comment=updated_cred["ssh_key_comment"], + ) + + logger.info( + f"Updated system credential '{updated_name}' with new unified ID: {new_credential_id}" ) - - logger.info(f"Updated system credential '{updated_name}' with new unified ID: {new_credential_id}") return response - + except HTTPException: raise except Exception as e: logger.error(f"Failed to update system credential {credential_id}: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to update system credential" + detail="Failed to update system credential", ) @@ -490,7 +503,7 @@ async def update_system_credential( async def delete_system_credential( credential_id: int, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """Delete system credential""" try: @@ -498,30 +511,31 @@ async def delete_system_credential( uuid_id = find_uuid_by_int(db, credential_id) if not uuid_id: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Credential not found" + status_code=status.HTTP_404_NOT_FOUND, detail="Credential not found" ) - + # Delete using unified service auth_service = get_auth_service(db) success = auth_service.delete_credential(uuid_id) - + if not success: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail="Credential not found or already deleted" + detail="Credential not found or already deleted", ) - - logger.info(f"Deleted system credential with external ID: {credential_id} (unified ID: {uuid_id})") + + logger.info( + f"Deleted system credential with external ID: {credential_id} (unified ID: {uuid_id})" + ) return {"message": "Credential deleted successfully"} - + except HTTPException: raise except Exception as e: logger.error(f"Failed to delete system credential {credential_id}: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to delete system credential" + detail="Failed to delete system credential", ) @@ -557,42 +571,44 @@ def get_scheduler(): @router.get("/scheduler", response_model=SchedulerStatus) @require_permission(Permission.SYSTEM_MAINTENANCE) -async def get_scheduler_status( - current_user: dict = Depends(get_current_user) -): +async def get_scheduler_status(current_user: dict = Depends(get_current_user)): """Get current scheduler status""" try: scheduler = get_scheduler() - + if scheduler is None: return SchedulerStatus( enabled=False, interval_minutes=_scheduler_interval, status="error", jobs=[], - uptime=None + uptime=None, ) - + # Get scheduler status if scheduler.running: jobs_info = [] try: for job in scheduler.get_jobs(): - jobs_info.append({ - "id": job.id, - "name": job.name, - "next_run": job.next_run_time.isoformat() if job.next_run_time else None, - "trigger": str(job.trigger) - }) + jobs_info.append( + { + "id": job.id, + "name": job.name, + "next_run": ( + job.next_run_time.isoformat() if job.next_run_time else None + ), + "trigger": str(job.trigger), + } + ) except Exception as e: logger.warning(f"Failed to get job info: {e}") - + return SchedulerStatus( enabled=True, interval_minutes=_scheduler_interval, status="running", jobs=jobs_info, - uptime="Running" + uptime="Running", ) else: return SchedulerStatus( @@ -600,9 +616,9 @@ async def get_scheduler_status( interval_minutes=_scheduler_interval, status="stopped", jobs=[], - uptime=None + uptime=None, ) - + except Exception as e: logger.error(f"Failed to get scheduler status: {e}") return SchedulerStatus( @@ -610,298 +626,329 @@ async def get_scheduler_status( interval_minutes=_scheduler_interval, status="error", jobs=[], - uptime=None + uptime=None, ) @router.post("/scheduler/start") @require_permission(Permission.SYSTEM_MAINTENANCE) async def start_scheduler( - request: SchedulerStartRequest, - current_user: dict = Depends(get_current_user) + request: SchedulerStartRequest, current_user: dict = Depends(get_current_user) ): """Start the monitoring scheduler""" try: global _scheduler, _scheduler_interval _scheduler_interval = request.interval_minutes scheduler = get_scheduler() - + if scheduler is None: # Try to create a new scheduler _scheduler = setup_host_monitoring_scheduler() scheduler = _scheduler - + if scheduler is None: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to create scheduler (APScheduler not available)" + detail="Failed to create scheduler (APScheduler not available)", ) - + if not scheduler.running: scheduler.start() - + # Configure the monitoring job with the requested interval from ..tasks.monitoring_tasks import periodic_host_monitoring - + # Remove any existing job first for job in scheduler.get_jobs(): if job.id == "host_monitoring": scheduler.remove_job(job.id) - + # Add the job with the specified interval scheduler.add_job( periodic_host_monitoring, - 'interval', + "interval", minutes=_scheduler_interval, - id='host_monitoring', - name='Host Monitoring Task', - replace_existing=True + id="host_monitoring", + name="Host Monitoring Task", + replace_existing=True, ) - + # Update database with start time and enabled status try: from ..database import get_db + db = next(get_db()) - db.execute(text(""" + db.execute( + text( + """ UPDATE scheduler_config SET enabled = TRUE, last_started = CURRENT_TIMESTAMP, interval_minutes = :interval, updated_at = CURRENT_TIMESTAMP WHERE service_name = 'host_monitoring' - """), {"interval": _scheduler_interval}) + """ + ), + {"interval": _scheduler_interval}, + ) db.commit() db.close() except Exception as db_error: logger.warning(f"Failed to update scheduler database state: {db_error}") - - logger.info(f"Host monitoring scheduler started with {_scheduler_interval} minute interval by user {current_user.get('username', 'unknown')}") - + + logger.info( + f"Host monitoring scheduler started with {_scheduler_interval} minute interval by user {current_user.get('username', 'unknown')}" + ) + return { "message": "Scheduler started successfully", "status": "running", - "interval_minutes": _scheduler_interval + "interval_minutes": _scheduler_interval, } else: - return { - "message": "Scheduler is already running", - "status": "running" - } - + return {"message": "Scheduler is already running", "status": "running"} + except Exception as e: logger.error(f"Failed to start scheduler: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to start scheduler: {str(e)}" + detail=f"Failed to start scheduler: {str(e)}", ) @router.post("/scheduler/stop") @require_permission(Permission.SYSTEM_MAINTENANCE) -async def stop_scheduler( - current_user: dict = Depends(get_current_user) -): +async def stop_scheduler(current_user: dict = Depends(get_current_user)): """Stop the monitoring scheduler""" try: scheduler = get_scheduler() - + if scheduler is None: - return { - "message": "Scheduler is not initialized", - "status": "stopped" - } - + return {"message": "Scheduler is not initialized", "status": "stopped"} + if scheduler.running: scheduler.pause() - + # Update database with stop time and disabled status try: from ..database import get_db + db = next(get_db()) - db.execute(text(""" + db.execute( + text( + """ UPDATE scheduler_config SET enabled = FALSE, last_stopped = CURRENT_TIMESTAMP, updated_at = CURRENT_TIMESTAMP WHERE service_name = 'host_monitoring' - """)) + """ + ) + ) db.commit() db.close() except Exception as db_error: logger.warning(f"Failed to update scheduler database state: {db_error}") - - logger.info(f"Host monitoring scheduler stopped by user {current_user.get('username', 'unknown')}") - - return { - "message": "Scheduler stopped successfully", - "status": "stopped" - } + + logger.info( + f"Host monitoring scheduler stopped by user {current_user.get('username', 'unknown')}" + ) + + return {"message": "Scheduler stopped successfully", "status": "stopped"} else: - return { - "message": "Scheduler is already stopped", - "status": "stopped" - } - + return {"message": "Scheduler is already stopped", "status": "stopped"} + except Exception as e: logger.error(f"Failed to stop scheduler: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to stop scheduler: {str(e)}" + detail=f"Failed to stop scheduler: {str(e)}", ) @router.put("/scheduler") @require_permission(Permission.SYSTEM_MAINTENANCE) async def update_scheduler( - request: SchedulerUpdateRequest, - current_user: dict = Depends(get_current_user) + request: SchedulerUpdateRequest, current_user: dict = Depends(get_current_user) ): """Update scheduler settings""" try: global _scheduler_interval _scheduler_interval = request.interval_minutes - + scheduler = get_scheduler() - + # Update database with new interval try: from ..database import get_db + db = next(get_db()) - db.execute(text(""" + db.execute( + text( + """ UPDATE scheduler_config SET interval_minutes = :interval, updated_at = CURRENT_TIMESTAMP WHERE service_name = 'host_monitoring' - """), {"interval": _scheduler_interval}) + """ + ), + {"interval": _scheduler_interval}, + ) db.commit() db.close() except Exception as db_error: logger.warning(f"Failed to update scheduler database interval: {db_error}") - + # If scheduler is running, we need to reschedule the job with the new interval if scheduler and scheduler.running: # Remove existing jobs for job in scheduler.get_jobs(): if job.id == "host_monitoring": scheduler.remove_job(job.id) - + # Add new job with updated interval from ..tasks.monitoring_tasks import periodic_host_monitoring + scheduler.add_job( periodic_host_monitoring, - 'interval', + "interval", minutes=_scheduler_interval, - id='host_monitoring', - name='Host Monitoring Task', - replace_existing=True + id="host_monitoring", + name="Host Monitoring Task", + replace_existing=True, + ) + + logger.info( + f"Scheduler interval updated to {_scheduler_interval} minutes by user {current_user.get('username', 'unknown')}" ) - - logger.info(f"Scheduler interval updated to {_scheduler_interval} minutes by user {current_user.get('username', 'unknown')}") - + return { "message": f"Scheduler interval updated to {_scheduler_interval} minutes", "interval_minutes": _scheduler_interval, - "status": "running" if scheduler and scheduler.running else "stopped" + "status": "running" if scheduler and scheduler.running else "stopped", } - + except Exception as e: logger.error(f"Failed to update scheduler: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to update scheduler: {str(e)}" + detail=f"Failed to update scheduler: {str(e)}", ) -async def restore_scheduler_state(): +def restore_scheduler_state(): """Restore scheduler state from database on startup""" logger.info("restore_scheduler_state() function called") try: global _scheduler, _scheduler_interval - + # Get database session from ..database import get_db + db = next(get_db()) - + try: # Read scheduler configuration from database - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT enabled, interval_minutes, auto_start FROM scheduler_config WHERE service_name = 'host_monitoring' - """)) - + """ + ) + ) + config = result.fetchone() logger.info(f"Database config found: {config if config else 'None'}") - + if config: _scheduler_interval = config.interval_minutes - logger.info(f"Setting global scheduler interval to {_scheduler_interval} minutes from database") - + logger.info( + f"Setting global scheduler interval to {_scheduler_interval} minutes from database" + ) + if config.enabled and config.auto_start: - logger.info(f"Auto-start enabled, initializing scheduler with {_scheduler_interval} minute interval") + logger.info( + f"Auto-start enabled, initializing scheduler with {_scheduler_interval} minute interval" + ) # Auto-start scheduler with database configuration scheduler = get_scheduler() if scheduler and not scheduler.running: scheduler.start() logger.info("Scheduler started successfully") - + # Configure the monitoring job with saved interval from ..tasks.monitoring_tasks import periodic_host_monitoring - + # Remove any existing job first (including the hardcoded one from setup) existing_jobs = scheduler.get_jobs() logger.info(f"Found {len(existing_jobs)} existing jobs to remove") for job in existing_jobs: logger.info(f"Removing existing job: {job.id} - {job.name}") scheduler.remove_job(job.id) - + # Add the job with the correct interval from database scheduler.add_job( periodic_host_monitoring, - 'interval', + "interval", minutes=_scheduler_interval, - id='host_monitoring', - name='Host Monitoring Task', - replace_existing=True + id="host_monitoring", + name="Host Monitoring Task", + replace_existing=True, + ) + logger.info( + f"Added new monitoring job with {_scheduler_interval} minute interval" ) - logger.info(f"Added new monitoring job with {_scheduler_interval} minute interval") - + # Update database with start time - db.execute(text(""" + db.execute( + text( + """ UPDATE scheduler_config SET last_started = CURRENT_TIMESTAMP, updated_at = CURRENT_TIMESTAMP WHERE service_name = 'host_monitoring' - """)) + """ + ) + ) db.commit() - - logger.info(f"Host monitoring scheduler auto-started with {_scheduler_interval} minute interval") + + logger.info( + f"Host monitoring scheduler auto-started with {_scheduler_interval} minute interval" + ) else: - logger.info("Scheduler initialized but not auto-started (already running or failed to create)") + logger.info( + "Scheduler initialized but not auto-started (already running or failed to create)" + ) else: logger.info("Scheduler configured but auto-start disabled or not enabled") else: # No configuration found, create default - db.execute(text(""" + db.execute( + text( + """ INSERT INTO scheduler_config ( service_name, enabled, interval_minutes, auto_start ) VALUES ( 'host_monitoring', TRUE, 15, TRUE ) - """)) + """ + ) + ) db.commit() logger.info("Created default scheduler configuration") - + except Exception as db_error: logger.warning(f"Database scheduler config not available, using defaults: {db_error}") # Fall back to basic initialization without auto-start scheduler = get_scheduler() if scheduler: logger.info("Scheduler initialized with defaults (not auto-started)") - + finally: db.close() - + except Exception as e: logger.error(f"Failed to restore scheduler state: {e}") - # Don't raise - scheduler can be started manually \ No newline at end of file + # Don't raise - scheduler can be started manually diff --git a/backend/app/routes/terminal.py b/backend/app/routes/terminal.py index bce54fbe..7f37d457 100644 --- a/backend/app/routes/terminal.py +++ b/backend/app/routes/terminal.py @@ -11,6 +11,7 @@ from ..database import get_db from ..services.terminal_service import terminal_service + # from ..auth import get_current_user # Optional for future authentication logger = logging.getLogger(__name__) @@ -25,28 +26,26 @@ def get_client_ip(request: Request) -> str: if forwarded_for: # Take the first IP if there are multiple return forwarded_for.split(",")[0].strip() - + # Check X-Real-IP header (nginx) real_ip = request.headers.get("x-real-ip") if real_ip: return real_ip.strip() - + # Fallback to direct client IP if request.client and request.client.host: return request.client.host - + return "unknown" @router.websocket("/api/hosts/{host_id}/terminal") async def host_terminal_websocket( - websocket: WebSocket, - host_id: str, - db: Session = Depends(get_db) + websocket: WebSocket, host_id: str, db: Session = Depends(get_db) ): """ WebSocket endpoint for SSH terminal access to a specific host - + Args: websocket: WebSocket connection host_id: UUID of the host to connect to @@ -63,19 +62,16 @@ async def host_terminal_websocket( client_ip = websocket.client.host except Exception: pass - + logger.info(f"Terminal WebSocket connection requested for host {host_id} from {client_ip}") - + # Note: WebSocket connections don't easily support standard HTTP auth middleware # For now, we'll accept connections and rely on network-level security # In production, consider implementing WebSocket-specific auth - + try: await terminal_service.handle_websocket_connection( - websocket=websocket, - host_id=host_id, - db=db, - client_ip=client_ip + websocket=websocket, host_id=host_id, db=db, client_ip=client_ip ) except WebSocketDisconnect: logger.info(f"WebSocket disconnected for host {host_id}") @@ -88,17 +84,14 @@ async def host_terminal_websocket( @router.get("/api/hosts/{host_id}/terminal/status") -async def get_terminal_status( - host_id: str, - db: Session = Depends(get_db) -): +async def get_terminal_status(host_id: str, db: Session = Depends(get_db)): """ Get terminal connection status for a host - + Args: host_id: UUID of the host db: Database session - + Returns: Terminal status information """ @@ -106,37 +99,32 @@ async def get_terminal_status( # Check if host exists using raw SQL query result = db.execute(text("SELECT * FROM hosts WHERE id = :host_id"), {"host_id": host_id}) host_data = result.fetchone() - + if not host_data: return {"error": "Host not found"} - + # Convert row to dict-like object host = { "id": str(host_data.id), "hostname": host_data.hostname, "ip_address": host_data.ip_address, - "auth_method": host_data.auth_method + "auth_method": host_data.auth_method, } - - + # Check for active sessions active_sessions = [ - key for key in terminal_service.active_sessions.keys() - if key.startswith(f"{host_id}_") + key for key in terminal_service.active_sessions.keys() if key.startswith(f"{host_id}_") ] - + return { "host_id": host_id, "hostname": host["hostname"], "ip_address": host["ip_address"], "active_sessions": len(active_sessions), "auth_method": host["auth_method"], - "terminal_available": True + "terminal_available": True, } - + except Exception as e: logger.error(f"Error getting terminal status for host {host_id}: {e}") - return { - "error": "Failed to get terminal status", - "details": str(e) - } \ No newline at end of file + return {"error": "Failed to get terminal status", "details": str(e)} diff --git a/backend/app/routes/users.py b/backend/app/routes/users.py index de2b7320..9f279321 100644 --- a/backend/app/routes/users.py +++ b/backend/app/routes/users.py @@ -2,6 +2,7 @@ User Management API Routes Handles user CRUD operations with role-based access control """ + from fastapi import APIRouter, Depends, HTTPException, status, Query from sqlalchemy.orm import Session from sqlalchemy import text @@ -13,8 +14,12 @@ from ..database import get_db from ..auth import get_current_user, pwd_context from ..rbac import ( - require_permission, require_super_admin, require_admin, - Permission, UserRole, RBACManager + require_permission, + require_super_admin, + require_admin, + Permission, + UserRole, + RBACManager, ) logger = logging.getLogger(__name__) @@ -48,7 +53,7 @@ class UserResponse(UserBase): last_login: Optional[str] = None failed_login_attempts: int locked_until: Optional[str] = None - + class Config: from_attributes = True @@ -73,17 +78,16 @@ class RoleInfo(BaseModel): @router.get("/roles", response_model=List[RoleInfo]) -async def list_roles( - current_user: dict = Depends(get_current_user), - db: Session = Depends(get_db) -): +async def list_roles(current_user: dict = Depends(get_current_user), db: Session = Depends(get_db)): """List all available roles (admin only)""" - user_role = UserRole(current_user.get('role', 'guest')) + user_role = UserRole(current_user.get("role", "guest")) if not RBACManager.has_permission(user_role, Permission.USER_READ): raise HTTPException(status_code=403, detail="Insufficient permissions") - + try: - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT name, display_name, description, permissions FROM roles WHERE is_active = true @@ -97,19 +101,23 @@ async def list_roles( WHEN 'guest' THEN 6 ELSE 7 END - """)) - + """ + ) + ) + roles = [] for row in result: - roles.append(RoleInfo( - name=row.name, - display_name=row.display_name, - description=row.description, - permissions=row.permissions if isinstance(row.permissions, list) else [] - )) - + roles.append( + RoleInfo( + name=row.name, + display_name=row.display_name, + description=row.description, + permissions=row.permissions if isinstance(row.permissions, list) else [], + ) + ) + return roles - + except Exception as e: logger.error(f"Error listing roles: {e}") raise HTTPException(status_code=500, detail="Failed to retrieve roles") @@ -123,72 +131,79 @@ async def list_users( role: Optional[UserRole] = Query(None), is_active: Optional[bool] = Query(None), current_user: dict = Depends(get_current_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """List users with pagination and filtering""" - user_role = UserRole(current_user.get('role', 'guest')) + user_role = UserRole(current_user.get("role", "guest")) if not RBACManager.has_permission(user_role, Permission.USER_READ): raise HTTPException(status_code=403, detail="Insufficient permissions") - + try: # Build query conditions conditions = ["1=1"] params = {} - + if search: conditions.append("(username ILIKE :search OR email ILIKE :search)") params["search"] = f"%{search}%" - + if role: conditions.append("role = :role") params["role"] = role.value - + if is_active is not None: conditions.append("is_active = :is_active") params["is_active"] = is_active - + where_clause = " AND ".join(conditions) - + # Get total count - count_result = db.execute(text(f""" + count_result = db.execute( + text( + f""" SELECT COUNT(*) as total FROM users WHERE {where_clause} - """), params) + """ + ), + params, + ) total = count_result.fetchone().total - + # Get paginated results offset = (page - 1) * page_size params.update({"limit": page_size, "offset": offset}) - - result = db.execute(text(f""" + + result = db.execute( + text( + f""" SELECT id, username, email, role, is_active, created_at, last_login, failed_login_attempts, locked_until FROM users WHERE {where_clause} ORDER BY created_at DESC LIMIT :limit OFFSET :offset - """), params) - + """ + ), + params, + ) + users = [] for row in result: - users.append(UserResponse( - id=row.id, - username=row.username, - email=row.email, - role=UserRole(row.role), - is_active=row.is_active, - created_at=row.created_at.isoformat(), - last_login=row.last_login.isoformat() if row.last_login else None, - failed_login_attempts=row.failed_login_attempts, - locked_until=row.locked_until.isoformat() if row.locked_until else None - )) - - return UserListResponse( - users=users, - total=total, - page=page, - page_size=page_size - ) - + users.append( + UserResponse( + id=row.id, + username=row.username, + email=row.email, + role=UserRole(row.role), + is_active=row.is_active, + created_at=row.created_at.isoformat(), + last_login=row.last_login.isoformat() if row.last_login else None, + failed_login_attempts=row.failed_login_attempts, + locked_until=row.locked_until.isoformat() if row.locked_until else None, + ) + ) + + return UserListResponse(users=users, total=total, page=page, page_size=page_size) + except Exception as e: logger.error(f"Error listing users: {e}") raise HTTPException(status_code=500, detail="Failed to retrieve users") @@ -199,39 +214,49 @@ async def list_users( async def create_user( user_data: UserCreate, current_user: dict = Depends(get_current_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Create a new user (super admin only)""" try: # Check if username or email already exists - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT id FROM users WHERE username = :username OR email = :email - """), {"username": user_data.username, "email": user_data.email}) - + """ + ), + {"username": user_data.username, "email": user_data.email}, + ) + if result.fetchone(): raise HTTPException(status_code=400, detail="Username or email already exists") - + # Hash password hashed_password = pwd_context.hash(user_data.password) - + # Create user - insert_result = db.execute(text(""" + insert_result = db.execute( + text( + """ INSERT INTO users (username, email, hashed_password, role, is_active, created_at, failed_login_attempts, mfa_enabled) VALUES (:username, :email, :password, :role, :is_active, CURRENT_TIMESTAMP, 0, false) RETURNING id, created_at - """), { - "username": user_data.username, - "email": user_data.email, - "password": hashed_password, - "role": user_data.role.value, - "is_active": user_data.is_active - }) - + """ + ), + { + "username": user_data.username, + "email": user_data.email, + "password": hashed_password, + "role": user_data.role.value, + "is_active": user_data.is_active, + }, + ) + row = insert_result.fetchone() db.commit() - + logger.info(f"User {user_data.username} created by {current_user.get('username')}") - + return UserResponse( id=row.id, username=user_data.username, @@ -241,9 +266,9 @@ async def create_user( created_at=row.created_at.isoformat(), last_login=None, failed_login_attempts=0, - locked_until=None + locked_until=None, ) - + except HTTPException: raise except Exception as e: @@ -255,22 +280,25 @@ async def create_user( @router.get("/{user_id}", response_model=UserResponse) @require_permission(Permission.USER_READ) async def get_user( - user_id: int, - current_user: dict = Depends(get_current_user), - db: Session = Depends(get_db) + user_id: int, current_user: dict = Depends(get_current_user), db: Session = Depends(get_db) ): """Get user by ID""" try: - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT id, username, email, role, is_active, created_at, last_login, failed_login_attempts, locked_until FROM users WHERE id = :user_id - """), {"user_id": user_id}) - + """ + ), + {"user_id": user_id}, + ) + row = result.fetchone() if not row: raise HTTPException(status_code=404, detail="User not found") - + return UserResponse( id=row.id, username=row.username, @@ -280,9 +308,9 @@ async def get_user( created_at=row.created_at.isoformat(), last_login=row.last_login.isoformat() if row.last_login else None, failed_login_attempts=row.failed_login_attempts, - locked_until=row.locked_until.isoformat() if row.locked_until else None + locked_until=row.locked_until.isoformat() if row.locked_until else None, ) - + except HTTPException: raise except Exception as e: @@ -296,72 +324,74 @@ async def update_user( user_id: int, user_data: UserUpdate, current_user: dict = Depends(get_current_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Update user (admin only, or users can update themselves)""" try: # Check if user exists - result = db.execute(text("SELECT id, role FROM users WHERE id = :user_id"), {"user_id": user_id}) + result = db.execute( + text("SELECT id, role FROM users WHERE id = :user_id"), {"user_id": user_id} + ) existing_user = result.fetchone() if not existing_user: raise HTTPException(status_code=404, detail="User not found") - + # Non-super admins can only update themselves (except for role changes) - current_user_role = UserRole(current_user.get('role', 'guest')) - is_self_update = current_user.get('id') == user_id - + current_user_role = UserRole(current_user.get("role", "guest")) + is_self_update = current_user.get("id") == user_id + if not RBACManager.has_permission(current_user_role, Permission.USER_MANAGE_ROLES): if not is_self_update: raise HTTPException(status_code=403, detail="Can only update your own profile") if user_data.role and user_data.role != UserRole(existing_user.role): raise HTTPException(status_code=403, detail="Cannot change your own role") - + # Build update query with secure column mapping updates = [] params = {"user_id": user_id} - + # Security Fix: Use explicit column mapping instead of f-string concatenation allowed_columns = { "username": "username = :username", - "email": "email = :email", + "email": "email = :email", "role": "role = :role", "is_active": "is_active = :is_active", - "password": "hashed_password = :password" + "password": "hashed_password = :password", } - + if user_data.username: updates.append(allowed_columns["username"]) params["username"] = user_data.username - + if user_data.email: updates.append(allowed_columns["email"]) params["email"] = user_data.email - + if user_data.role: updates.append(allowed_columns["role"]) params["role"] = user_data.role.value - + if user_data.is_active is not None: updates.append(allowed_columns["is_active"]) params["is_active"] = user_data.is_active - + if user_data.password: updates.append(allowed_columns["password"]) params["password"] = pwd_context.hash(user_data.password) - + if not updates: raise HTTPException(status_code=400, detail="No fields to update") - + updates.append("updated_at = CURRENT_TIMESTAMP") # Security Fix: Use parameterized query construction update_query = "UPDATE users SET " + ", ".join(updates) + " WHERE id = :user_id" - + db.execute(text(update_query), params) db.commit() - + # Return updated user return await get_user(user_id, current_user, db) - + except HTTPException: raise except Exception as e: @@ -371,36 +401,41 @@ async def update_user( @router.delete("/{user_id}") -@require_permission(Permission.USER_DELETE) +@require_permission(Permission.USER_DELETE) async def delete_user( - user_id: int, - current_user: dict = Depends(get_current_user), - db: Session = Depends(get_db) + user_id: int, current_user: dict = Depends(get_current_user), db: Session = Depends(get_db) ): """Delete user (super admin only)""" try: # Prevent self-deletion - if current_user.get('id') == user_id: + if current_user.get("id") == user_id: raise HTTPException(status_code=400, detail="Cannot delete your own account") - + # Check if user exists - result = db.execute(text("SELECT username FROM users WHERE id = :user_id"), {"user_id": user_id}) + result = db.execute( + text("SELECT username FROM users WHERE id = :user_id"), {"user_id": user_id} + ) user = result.fetchone() if not user: raise HTTPException(status_code=404, detail="User not found") - + # Soft delete (deactivate) instead of hard delete to preserve audit trails - db.execute(text(""" + db.execute( + text( + """ UPDATE users SET is_active = false, updated_at = CURRENT_TIMESTAMP WHERE id = :user_id - """), {"user_id": user_id}) - + """ + ), + {"user_id": user_id}, + ) + db.commit() - + logger.info(f"User {user.username} deactivated by {current_user.get('username')}") return {"message": "User deactivated successfully"} - + except HTTPException: raise except Exception as e: @@ -413,37 +448,44 @@ async def delete_user( async def change_password( password_data: PasswordChangeRequest, current_user: dict = Depends(get_current_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Change current user's password""" try: - user_id = current_user.get('id') - + user_id = current_user.get("id") + # Get current hashed password - result = db.execute(text("SELECT hashed_password FROM users WHERE id = :user_id"), {"user_id": user_id}) + result = db.execute( + text("SELECT hashed_password FROM users WHERE id = :user_id"), {"user_id": user_id} + ) user = result.fetchone() if not user: raise HTTPException(status_code=404, detail="User not found") - + # Verify current password if not pwd_context.verify(password_data.current_password, user.hashed_password): raise HTTPException(status_code=400, detail="Current password is incorrect") - + # Hash new password new_hashed = pwd_context.hash(password_data.new_password) - + # Update password - db.execute(text(""" + db.execute( + text( + """ UPDATE users SET hashed_password = :password, updated_at = CURRENT_TIMESTAMP WHERE id = :user_id - """), {"password": new_hashed, "user_id": user_id}) - + """ + ), + {"password": new_hashed, "user_id": user_id}, + ) + db.commit() - + logger.info(f"Password changed for user {current_user.get('username')}") return {"message": "Password changed successfully"} - + except HTTPException: raise except Exception as e: @@ -454,20 +496,19 @@ async def change_password( @router.get("/me/profile", response_model=UserResponse) async def get_my_profile( - current_user: dict = Depends(get_current_user), - db: Session = Depends(get_db) + current_user: dict = Depends(get_current_user), db: Session = Depends(get_db) ): """Get current user's profile""" - return await get_user(current_user.get('id'), current_user, db) + return await get_user(current_user.get("id"), current_user, db) @router.put("/me/profile", response_model=UserResponse) async def update_my_profile( user_data: UserUpdate, current_user: dict = Depends(get_current_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """Update current user's profile""" # Remove role from update data - users cannot change their own role user_data.role = None - return await update_user(current_user.get('id'), user_data, current_user, db) \ No newline at end of file + return await update_user(current_user.get("id"), user_data, current_user, db) diff --git a/backend/app/routes/v1/__init__.py b/backend/app/routes/v1/__init__.py index 277b3988..aebf0681 100644 --- a/backend/app/routes/v1/__init__.py +++ b/backend/app/routes/v1/__init__.py @@ -1,4 +1,4 @@ """ OpenWatch API v1 - Unified API Façade Versioned API endpoints for consistent interface across OSS and Enterprise -""" \ No newline at end of file +""" diff --git a/backend/app/routes/v1/api.py b/backend/app/routes/v1/api.py index 59977e9e..08e4fc59 100644 --- a/backend/app/routes/v1/api.py +++ b/backend/app/routes/v1/api.py @@ -2,6 +2,7 @@ OpenWatch API v1 - Main Router Unified API façade with versioned endpoints and capability-based routing """ + from fastapi import APIRouter, Depends, HTTPException, status from typing import Dict, Any import logging @@ -24,7 +25,9 @@ # Include v1 enhanced endpoints router.include_router(v1_hosts.router, prefix="/hosts", tags=["Host Management v1"]) router.include_router(v1_scans.router, prefix="/scans", tags=["Scan Management v1"]) -router.include_router(v1_remediation.router, prefix="/remediation", tags=["Remediation Provider v1"]) +router.include_router( + v1_remediation.router, prefix="/remediation", tags=["Remediation Provider v1"] +) router.include_router(v1_openapi.router, prefix="/docs", tags=["API Documentation v1"]) @@ -32,7 +35,7 @@ async def get_api_info(): """ Get API v1 information and available endpoints - + Returns comprehensive information about the v1 API including available endpoints, authentication requirements, and capabilities. """ @@ -45,25 +48,25 @@ async def get_api_info(): "authentication": { "type": "JWT Bearer Token", "login_endpoint": "/api/auth/login", - "refresh_endpoint": "/api/auth/refresh" + "refresh_endpoint": "/api/auth/refresh", }, "endpoints": { "capabilities": "/api/v1/capabilities", - "features": "/api/v1/features", + "features": "/api/v1/features", "hosts": "/api/v1/hosts", "scans": "/api/v1/scans", "remediation": "/api/v1/remediation", - "integrations": "/api/v1/health/integrations" + "integrations": "/api/v1/health/integrations", }, "rate_limits": { "default": "1000 requests per minute", - "authenticated": "5000 requests per minute" + "authenticated": "5000 requests per minute", }, "support": { "documentation": "https://docs.openwatch.io", "community": "https://github.com/hanalyx/openwatch/discussions", - "enterprise": "https://hanalyx.com/support" - } + "enterprise": "https://hanalyx.com/support", + }, } @@ -71,23 +74,19 @@ async def get_api_info(): async def get_api_health(): """ Get API v1 health status - + Returns the health status of the v1 API and its dependencies. """ return { "status": "healthy", "version": "v1", "timestamp": "2025-08-20T12:00:00Z", - "dependencies": { - "database": "healthy", - "redis": "healthy", - "plugins": "healthy" - }, + "dependencies": {"database": "healthy", "redis": "healthy", "plugins": "healthy"}, "metrics": { "requests_per_minute": 0, "average_response_time": "50ms", - "error_rate": "0.1%" - } + "error_rate": "0.1%", + }, } @@ -95,7 +94,7 @@ async def get_api_health(): async def get_openapi_spec(): """ Get OpenAPI specification for API v1 - + Returns the complete OpenAPI 3.0 specification for the v1 API. """ # This would typically return the actual OpenAPI spec @@ -109,31 +108,20 @@ async def get_openapi_spec(): "contact": { "name": "OpenWatch Team", "url": "https://github.com/hanalyx/openwatch", - "email": "support@hanalyx.com" + "email": "support@hanalyx.com", }, - "license": { - "name": "Apache 2.0", - "url": "https://opensource.org/licenses/Apache-2.0" - } + "license": {"name": "Apache 2.0", "url": "https://opensource.org/licenses/Apache-2.0"}, }, - "servers": [ - { - "url": "/api/v1", - "description": "OpenWatch API v1" - } - ], + "servers": [{"url": "/api/v1", "description": "OpenWatch API v1"}], "tags": [ { "name": "System Capabilities", - "description": "Feature discovery and capability management" + "description": "Feature discovery and capability management", }, { - "name": "Host Management v1", - "description": "Host inventory and management operations" + "name": "Host Management v1", + "description": "Host inventory and management operations", }, - { - "name": "Scan Management v1", - "description": "SCAP scanning operations and results" - } - ] - } \ No newline at end of file + {"name": "Scan Management v1", "description": "SCAP scanning operations and results"}, + ], + } diff --git a/backend/app/routes/v1/hosts.py b/backend/app/routes/v1/hosts.py index 2b16b38d..d5f83ada 100644 --- a/backend/app/routes/v1/hosts.py +++ b/backend/app/routes/v1/hosts.py @@ -2,6 +2,7 @@ OpenWatch API v1 - Host Management Versioned host management endpoints with enhanced capabilities """ + from fastapi import APIRouter, Depends, HTTPException, status, Query from typing import List, Optional from pydantic import BaseModel @@ -27,12 +28,10 @@ # Add v1-specific enhancements @router.get("/capabilities") -async def get_host_management_capabilities( - current_user: dict = Depends(get_current_user) -): +async def get_host_management_capabilities(current_user: dict = Depends(get_current_user)): """ Get host management capabilities for API v1 - + Returns information about available host management features, limits, and supported operations in the v1 API. """ @@ -44,32 +43,30 @@ async def get_host_management_capabilities( "host_groups": True, "ssh_key_management": True, "remote_scanning": True, - "monitoring": True + "monitoring": True, }, "limits": { "max_hosts_per_request": 100, "bulk_import_max_size": 10000, - "supported_os": ["linux", "unix", "rhel", "ubuntu", "debian", "centos"] + "supported_os": ["linux", "unix", "rhel", "ubuntu", "debian", "centos"], }, "endpoints": { "list_hosts": "GET /api/v1/hosts", - "create_host": "POST /api/v1/hosts", + "create_host": "POST /api/v1/hosts", "get_host": "GET /api/v1/hosts/{host_id}", "update_host": "PUT /api/v1/hosts/{host_id}", "delete_host": "DELETE /api/v1/hosts/{host_id}", "bulk_import": "POST /api/v1/hosts/bulk", - "capabilities": "GET /api/v1/hosts/capabilities" - } + "capabilities": "GET /api/v1/hosts/capabilities", + }, } @router.get("/summary") -async def get_hosts_summary( - current_user: dict = Depends(get_current_user) -): +async def get_hosts_summary(current_user: dict = Depends(get_current_user)): """ Get summary statistics for host management (v1 specific) - + Returns aggregate information about hosts, groups, and management status. """ # This would typically query the database for actual statistics @@ -78,15 +75,7 @@ async def get_hosts_summary( "active_hosts": 0, "groups": 0, "last_scan": None, - "compliance_summary": { - "compliant": 0, - "non_compliant": 0, - "unknown": 0 - }, + "compliance_summary": {"compliant": 0, "non_compliant": 0, "unknown": 0}, "os_distribution": {}, - "scan_status": { - "never_scanned": 0, - "recently_scanned": 0, - "outdated_scans": 0 - } - } \ No newline at end of file + "scan_status": {"never_scanned": 0, "recently_scanned": 0, "outdated_scans": 0}, + } diff --git a/backend/app/routes/v1/openapi.py b/backend/app/routes/v1/openapi.py index c9d5e293..b0d07d02 100644 --- a/backend/app/routes/v1/openapi.py +++ b/backend/app/routes/v1/openapi.py @@ -2,6 +2,7 @@ OpenWatch API v1 - OpenAPI Specification Export Enhanced OpenAPI documentation with comprehensive examples and descriptions """ + from fastapi import APIRouter, Depends, HTTPException, status, Query, Response from fastapi.openapi.utils import get_openapi from typing import Dict, Any, Optional @@ -10,6 +11,7 @@ import yaml from ...auth import get_current_user + # from ...main import app # Removed to avoid circular import logger = logging.getLogger(__name__) @@ -20,190 +22,189 @@ @router.get("/spec.json") async def get_openapi_spec_json( include_internal: bool = Query(False, description="Include internal endpoints"), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ) -> Dict[str, Any]: """ Get OpenAPI specification in JSON format - + Returns the complete OpenAPI 3.0 specification for the OpenWatch API v1 with comprehensive documentation, examples, and security schemes. """ try: # Generate enhanced OpenAPI spec openapi_spec = _generate_enhanced_openapi_spec(include_internal) - - logger.debug(f"OpenAPI JSON spec requested by user {current_user.get('user_id', 'unknown')}") - + + logger.debug( + f"OpenAPI JSON spec requested by user {current_user.get('user_id', 'unknown')}" + ) + return openapi_spec - + except Exception as e: logger.error(f"Error generating OpenAPI spec: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to generate OpenAPI specification" + detail="Failed to generate OpenAPI specification", ) @router.get("/spec.yaml") async def get_openapi_spec_yaml( include_internal: bool = Query(False, description="Include internal endpoints"), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """ Get OpenAPI specification in YAML format - + Returns the complete OpenAPI 3.0 specification in YAML format, suitable for use with API development tools and documentation generators. """ try: # Generate enhanced OpenAPI spec openapi_spec = _generate_enhanced_openapi_spec(include_internal) - + # Convert to YAML yaml_content = yaml.dump(openapi_spec, default_flow_style=False, sort_keys=False) - - logger.debug(f"OpenAPI YAML spec requested by user {current_user.get('user_id', 'unknown')}") - + + logger.debug( + f"OpenAPI YAML spec requested by user {current_user.get('user_id', 'unknown')}" + ) + return Response( content=yaml_content, media_type="application/x-yaml", - headers={"Content-Disposition": "attachment; filename=openwatch-api-v1.yaml"} + headers={"Content-Disposition": "attachment; filename=openwatch-api-v1.yaml"}, ) - + except Exception as e: logger.error(f"Error generating OpenAPI YAML spec: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to generate OpenAPI YAML specification" + detail="Failed to generate OpenAPI YAML specification", ) @router.get("/postman") -async def get_postman_collection( - current_user: dict = Depends(get_current_user) -) -> Dict[str, Any]: +async def get_postman_collection(current_user: dict = Depends(get_current_user)) -> Dict[str, Any]: """ Get Postman collection for OpenWatch API v1 - + Returns a Postman collection with pre-configured requests for all API endpoints, including authentication and example payloads. """ try: collection = _generate_postman_collection() - - logger.debug(f"Postman collection requested by user {current_user.get('user_id', 'unknown')}") - + + logger.debug( + f"Postman collection requested by user {current_user.get('user_id', 'unknown')}" + ) + return collection - + except Exception as e: logger.error(f"Error generating Postman collection: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to generate Postman collection" + detail="Failed to generate Postman collection", ) @router.get("/sdk/examples") async def get_sdk_examples( language: str = Query("python", description="Programming language for examples"), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ) -> Dict[str, Any]: """ Get SDK examples for different programming languages - + Returns code examples showing how to integrate with the OpenWatch API using various programming languages and HTTP clients. """ try: examples = _generate_sdk_examples(language) - - logger.debug(f"SDK examples ({language}) requested by user {current_user.get('user_id', 'unknown')}") - + + logger.debug( + f"SDK examples ({language}) requested by user {current_user.get('user_id', 'unknown')}" + ) + return examples - + except Exception as e: logger.error(f"Error generating SDK examples: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to generate SDK examples" + detail="Failed to generate SDK examples", ) def _generate_enhanced_openapi_spec(include_internal: bool = False) -> Dict[str, Any]: """Generate enhanced OpenAPI specification with comprehensive documentation""" - + # Generate a basic OpenAPI spec (without app.routes to avoid circular import) openapi_spec = { "openapi": "3.0.0", "info": { "title": "OpenWatch SCAP Compliance Scanner API", "version": "1.0.0", - "description": _get_api_description() + "description": _get_api_description(), }, "paths": {}, # Would be populated with actual routes - "tags": _get_openapi_tags() + "tags": _get_openapi_tags(), } - + # Enhance the specification - openapi_spec.update({ - "info": { - **openapi_spec["info"], - "contact": { - "name": "OpenWatch Team", - "url": "https://github.com/hanalyx/openwatch", - "email": "support@hanalyx.com" - }, - "license": { - "name": "Apache 2.0", - "url": "https://opensource.org/licenses/Apache-2.0" + openapi_spec.update( + { + "info": { + **openapi_spec["info"], + "contact": { + "name": "OpenWatch Team", + "url": "https://github.com/hanalyx/openwatch", + "email": "support@hanalyx.com", + }, + "license": { + "name": "Apache 2.0", + "url": "https://opensource.org/licenses/Apache-2.0", + }, + "termsOfService": "https://hanalyx.com/terms", }, - "termsOfService": "https://hanalyx.com/terms" - }, - "servers": [ - { - "url": "/api/v1", - "description": "OpenWatch API v1" + "servers": [ + {"url": "/api/v1", "description": "OpenWatch API v1"}, + {"url": "https://demo.openwatch.io/api/v1", "description": "Demo Environment"}, + ], + "externalDocs": { + "description": "OpenWatch Documentation", + "url": "https://docs.openwatch.io", }, - { - "url": "https://demo.openwatch.io/api/v1", - "description": "Demo Environment" - } - ], - "externalDocs": { - "description": "OpenWatch Documentation", - "url": "https://docs.openwatch.io" } - }) - + ) + # Add security schemes openapi_spec["components"]["securitySchemes"] = { "BearerAuth": { "type": "http", "scheme": "bearer", "bearerFormat": "JWT", - "description": "JWT Bearer token authentication. Obtain token from /api/auth/login" + "description": "JWT Bearer token authentication. Obtain token from /api/auth/login", }, "ApiKeyAuth": { "type": "apiKey", "in": "header", "name": "X-API-Key", - "description": "API key authentication for automated integrations" - } + "description": "API key authentication for automated integrations", + }, } - + # Add global security requirement - openapi_spec["security"] = [ - {"BearerAuth": []}, - {"ApiKeyAuth": []} - ] - + openapi_spec["security"] = [{"BearerAuth": []}, {"ApiKeyAuth": []}] + # Add enhanced examples and schemas _add_enhanced_examples(openapi_spec) - + # Filter internal endpoints if requested if not include_internal: _filter_internal_endpoints(openapi_spec) - + return openapi_spec @@ -244,120 +245,93 @@ def _get_openapi_tags() -> list: return [ { "name": "System Capabilities", - "description": "Feature discovery and capability management" - }, - { - "name": "Host Management v1", - "description": "Host inventory and management operations" - }, - { - "name": "Scan Management v1", - "description": "SCAP scanning operations and results" + "description": "Feature discovery and capability management", }, + {"name": "Host Management v1", "description": "Host inventory and management operations"}, + {"name": "Scan Management v1", "description": "SCAP scanning operations and results"}, { "name": "Remediation Provider v1", - "description": "Automated and manual remediation operations" + "description": "Automated and manual remediation operations", }, - { - "name": "Authentication", - "description": "User authentication and session management" - } + {"name": "Authentication", "description": "User authentication and session management"}, ] def _add_enhanced_examples(openapi_spec: Dict[str, Any]): """Add enhanced examples to the OpenAPI specification""" - + # Add common schema examples if "components" not in openapi_spec: openapi_spec["components"] = {} - + if "examples" not in openapi_spec["components"]: openapi_spec["components"]["examples"] = {} - + # Add example responses - openapi_spec["components"]["examples"].update({ - "CapabilitiesResponse": { - "summary": "System capabilities example", - "value": { - "version": "1.0.0", - "features": { - "scanning": True, - "reporting": True, - "remediation": False, - "ai_assistance": False + openapi_spec["components"]["examples"].update( + { + "CapabilitiesResponse": { + "summary": "System capabilities example", + "value": { + "version": "1.0.0", + "features": { + "scanning": True, + "reporting": True, + "remediation": False, + "ai_assistance": False, + }, + "limits": {"max_hosts": 50, "concurrent_scans": 5}, + "integrations": {"aegis_available": False, "container_runtime": "podman"}, }, - "limits": { - "max_hosts": 50, - "concurrent_scans": 5 + }, + "RemediationJobResponse": { + "summary": "Remediation job example", + "value": { + "job_id": "123e4567-e89b-12d3-a456-426614174000", + "scan_id": "456e7890-e89b-12d3-a456-426614174001", + "status": "pending", + "provider": "aegis", + "failed_rules": ["xccdf_rule_1", "xccdf_rule_2"], + "created_at": "2025-08-20T12:00:00Z", }, - "integrations": { - "aegis_available": False, - "container_runtime": "podman" - } - } - }, - "RemediationJobResponse": { - "summary": "Remediation job example", - "value": { - "job_id": "123e4567-e89b-12d3-a456-426614174000", - "scan_id": "456e7890-e89b-12d3-a456-426614174001", - "status": "pending", - "provider": "aegis", - "failed_rules": ["xccdf_rule_1", "xccdf_rule_2"], - "created_at": "2025-08-20T12:00:00Z" - } + }, } - }) + ) def _filter_internal_endpoints(openapi_spec: Dict[str, Any]): """Filter out internal endpoints from the specification""" - + # Remove internal paths internal_patterns = ["/health", "/metrics", "/debug"] - + if "paths" in openapi_spec: paths_to_remove = [] for path in openapi_spec["paths"]: if any(pattern in path for pattern in internal_patterns): paths_to_remove.append(path) - + for path in paths_to_remove: del openapi_spec["paths"][path] def _generate_postman_collection() -> Dict[str, Any]: """Generate Postman collection for the API""" - + return { "info": { "name": "OpenWatch API v1", "description": "OpenWatch SCAP Compliance Scanner API collection", "version": "1.0.0", - "schema": "https://schema.getpostman.com/json/collection/v2.1.0/collection.json" + "schema": "https://schema.getpostman.com/json/collection/v2.1.0/collection.json", }, "auth": { "type": "bearer", - "bearer": [ - { - "key": "token", - "value": "{{auth_token}}", - "type": "string" - } - ] + "bearer": [{"key": "token", "value": "{{auth_token}}", "type": "string"}], }, "variable": [ - { - "key": "base_url", - "value": "http://localhost:8000/api/v1", - "type": "string" - }, - { - "key": "auth_token", - "value": "", - "type": "string" - } + {"key": "base_url", "value": "http://localhost:8000/api/v1", "type": "string"}, + {"key": "auth_token", "value": "", "type": "string"}, ], "item": [ { @@ -367,27 +341,21 @@ def _generate_postman_collection() -> Dict[str, Any]: "name": "Login", "request": { "method": "POST", - "header": [ - { - "key": "Content-Type", - "value": "application/json" - } - ], + "header": [{"key": "Content-Type", "value": "application/json"}], "body": { "mode": "raw", - "raw": json.dumps({ - "username": "admin", - "password": "your_password" - }) + "raw": json.dumps( + {"username": "admin", "password": "your_password"} + ), }, "url": { "raw": "{{base_url}}/../auth/login", "host": ["{{base_url}}"], - "path": ["..", "auth", "login"] - } - } + "path": ["..", "auth", "login"], + }, + }, } - ] + ], }, { "name": "Capabilities", @@ -399,11 +367,11 @@ def _generate_postman_collection() -> Dict[str, Any]: "url": { "raw": "{{base_url}}/capabilities", "host": ["{{base_url}}"], - "path": ["capabilities"] - } - } + "path": ["capabilities"], + }, + }, } - ] + ], }, { "name": "Hosts", @@ -415,19 +383,19 @@ def _generate_postman_collection() -> Dict[str, Any]: "url": { "raw": "{{base_url}}/hosts", "host": ["{{base_url}}"], - "path": ["hosts"] - } - } + "path": ["hosts"], + }, + }, } - ] - } - ] + ], + }, + ], } def _generate_sdk_examples(language: str) -> Dict[str, Any]: """Generate SDK examples for different languages""" - + examples = { "python": { "authentication": """ @@ -459,7 +427,7 @@ def _generate_sdk_examples(language: str) -> Dict[str, Any]: # Check scan status status = requests.get(f"http://localhost:8000/api/v1/scans/{scan_id}", headers=headers) print(status.json()) - """ + """, }, "curl": { "authentication": """ @@ -478,7 +446,7 @@ def _generate_sdk_examples(language: str) -> Dict[str, Any]: -H "Authorization: Bearer YOUR_TOKEN_HERE" \\ -H "Content-Type: application/json" \\ -d '{"host_id": "host-123", "profile_id": "stig-rhel8"}' - """ + """, }, "javascript": { "authentication": """ @@ -496,11 +464,11 @@ def _generate_sdk_examples(language: str) -> Dict[str, Any]: }); console.log(await capabilities.json()); """ - } + }, } - + return { "language": language, "examples": examples.get(language, {}), - "available_languages": list(examples.keys()) - } \ No newline at end of file + "available_languages": list(examples.keys()), + } diff --git a/backend/app/routes/v1/remediation.py b/backend/app/routes/v1/remediation.py index cbabd668..341eab83 100644 --- a/backend/app/routes/v1/remediation.py +++ b/backend/app/routes/v1/remediation.py @@ -2,6 +2,7 @@ OpenWatch API v1 - Remediation Provider Interface Enhanced remediation interface for AEGIS integration and other remediation providers """ + from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks, Query from pydantic import BaseModel, Field, UUID4 from typing import List, Optional, Dict, Any @@ -23,6 +24,7 @@ class RemediationRequest(BaseModel): """Request to start remediation for scan results""" + scan_id: UUID4 host_id: UUID4 failed_rules: List[str] = Field(..., min_items=1) @@ -34,6 +36,7 @@ class RemediationRequest(BaseModel): class RemediationJob(BaseModel): """Remediation job status and information""" + job_id: UUID4 scan_id: UUID4 host_id: UUID4 @@ -52,6 +55,7 @@ class RemediationJob(BaseModel): class RemediationProvider(BaseModel): """Information about a remediation provider""" + name: str version: str status: str # 'available', 'unavailable', 'degraded' @@ -63,6 +67,7 @@ class RemediationProvider(BaseModel): class RemediationSummary(BaseModel): """Summary of remediation activities""" + total_jobs: int active_jobs: int completed_jobs: int @@ -78,11 +83,11 @@ async def start_remediation( request: RemediationRequest, background_tasks: BackgroundTasks, current_user: dict = Depends(get_current_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ) -> RemediationJob: """ Start remediation for failed scan rules - + Initiates a remediation job for the specified failed rules using the configured remediation provider (AEGIS, Ansible, etc.). """ @@ -90,29 +95,23 @@ async def start_remediation( # Verify scan exists and user has access scan = db.query(Scan).filter(Scan.id == str(request.scan_id)).first() if not scan: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Scan not found" - ) - + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Scan not found") + # Verify host exists host = db.query(Host).filter(Host.id == str(request.host_id)).first() if not host: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Host not found" - ) - + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Host not found") + # Check if remediation is already in progress if scan.remediation_status in ["pending", "running"]: raise HTTPException( status_code=status.HTTP_409_CONFLICT, - detail="Remediation already in progress for this scan" + detail="Remediation already in progress for this scan", ) - + # Generate job ID job_id = uuid.uuid4() - + # Create remediation job job = RemediationJob( job_id=job_id, @@ -126,23 +125,23 @@ async def start_remediation( metadata={ "user_id": current_user.get("user_id"), "options": request.options, - "rule_count": len(request.failed_rules) - } + "rule_count": len(request.failed_rules), + }, ) - + # Update scan status scan.remediation_requested = True scan.remediation_status = "pending" scan.aegis_remediation_id = str(job_id) - + # Store job information in scan metadata if not scan.metadata: scan.metadata = {} - + scan.metadata["remediation_job"] = job.dict() - + db.commit() - + # Log audit event await log_audit_event( db=db, @@ -154,11 +153,11 @@ async def start_remediation( "job_id": str(job_id), "provider": request.provider, "rule_count": len(request.failed_rules), - "priority": request.priority + "priority": request.priority, }, - ip_address="127.0.0.1" + ip_address="127.0.0.1", ) - + # Start remediation in background background_tasks.add_task( _execute_remediation_job, @@ -167,103 +166,92 @@ async def start_remediation( scan_id=request.scan_id, host_id=request.host_id, failed_rules=request.failed_rules, - options=request.options + options=request.options, ) - + logger.info(f"Remediation job {job_id} started for scan {scan.id}") - + return job - + except HTTPException: raise except Exception as e: logger.error(f"Error starting remediation: {e}") raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to start remediation" + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to start remediation" ) @router.get("/job/{job_id}", response_model=RemediationJob) async def get_remediation_job( - job_id: UUID4, - current_user: dict = Depends(get_current_user), - db: Session = Depends(get_db) + job_id: UUID4, current_user: dict = Depends(get_current_user), db: Session = Depends(get_db) ) -> RemediationJob: """ Get remediation job status and details - + Returns detailed information about a specific remediation job including progress, results, and current status. """ try: # Find scan with this remediation job ID - scan = db.query(Scan).filter( - Scan.aegis_remediation_id == str(job_id) - ).first() - + scan = db.query(Scan).filter(Scan.aegis_remediation_id == str(job_id)).first() + if not scan or not scan.metadata or "remediation_job" not in scan.metadata: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Remediation job not found" + status_code=status.HTTP_404_NOT_FOUND, detail="Remediation job not found" ) - + job_data = scan.metadata["remediation_job"] job = RemediationJob(**job_data) - + # Update with latest status from scan job.status = scan.remediation_status or "unknown" if scan.remediation_completed_at: job.completed_at = scan.remediation_completed_at - + return job - + except HTTPException: raise except Exception as e: logger.error(f"Error getting remediation job: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to retrieve remediation job" + detail="Failed to retrieve remediation job", ) @router.delete("/job/{job_id}") async def cancel_remediation_job( - job_id: UUID4, - current_user: dict = Depends(get_current_user), - db: Session = Depends(get_db) + job_id: UUID4, current_user: dict = Depends(get_current_user), db: Session = Depends(get_db) ): """ Cancel a running remediation job - + Attempts to cancel a remediation job that is currently pending or running. """ try: # Find scan with this remediation job ID - scan = db.query(Scan).filter( - Scan.aegis_remediation_id == str(job_id) - ).first() - + scan = db.query(Scan).filter(Scan.aegis_remediation_id == str(job_id)).first() + if not scan: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Remediation job not found" + status_code=status.HTTP_404_NOT_FOUND, detail="Remediation job not found" ) - + if scan.remediation_status not in ["pending", "running"]: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Cannot cancel job in current status" + detail="Cannot cancel job in current status", ) - + # Update status scan.remediation_status = "cancelled" - + # TODO: Send cancellation request to remediation provider - + db.commit() - + # Log audit event await log_audit_event( db=db, @@ -272,20 +260,20 @@ async def cancel_remediation_job( resource_type="scan", resource_id=str(scan.id), details={"job_id": str(job_id)}, - ip_address="127.0.0.1" + ip_address="127.0.0.1", ) - + logger.info(f"Remediation job {job_id} cancelled") - + return {"status": "cancelled", "job_id": str(job_id)} - + except HTTPException: raise except Exception as e: logger.error(f"Error cancelling remediation job: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to cancel remediation job" + detail="Failed to cancel remediation job", ) @@ -294,196 +282,194 @@ async def retry_remediation_job( job_id: UUID4, failed_rules_only: bool = Query(True, description="Retry only failed rules"), current_user: dict = Depends(get_current_user), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """ Retry a failed remediation job - + Creates a new remediation job based on a previously failed job, optionally retrying only the rules that failed. """ try: # Find original scan - scan = db.query(Scan).filter( - Scan.aegis_remediation_id == str(job_id) - ).first() - + scan = db.query(Scan).filter(Scan.aegis_remediation_id == str(job_id)).first() + if not scan: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Original remediation job not found" + status_code=status.HTTP_404_NOT_FOUND, detail="Original remediation job not found" ) - + if scan.remediation_status not in ["failed", "partial"]: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Can only retry failed or partial remediation jobs" + detail="Can only retry failed or partial remediation jobs", ) - + # Get failed rules from original job original_job = scan.metadata.get("remediation_job", {}) if failed_rules_only and "results" in scan.metadata.get("remediation", {}): # Extract rules that failed failed_rules = [ - r["rule_id"] for r in scan.metadata["remediation"]["results"] + r["rule_id"] + for r in scan.metadata["remediation"]["results"] if r["status"] == "failed" ] else: # Retry all original rules failed_rules = original_job.get("failed_rules", []) - + if not failed_rules: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="No failed rules to retry" + status_code=status.HTTP_400_BAD_REQUEST, detail="No failed rules to retry" ) - + # Create new retry request retry_request = RemediationRequest( scan_id=UUID4(str(scan.id)), host_id=UUID4(str(scan.host_id)), failed_rules=failed_rules, provider=original_job.get("provider", "aegis"), - priority=original_job.get("priority", "medium") + priority=original_job.get("priority", "medium"), ) - + # Start new remediation job new_job = await start_remediation(retry_request, BackgroundTasks(), current_user, db) - + logger.info(f"Retry remediation job {new_job.job_id} created for original job {job_id}") - + return { "status": "retry_started", "original_job_id": str(job_id), "new_job_id": str(new_job.job_id), - "rules_to_retry": len(failed_rules) + "rules_to_retry": len(failed_rules), } - + except HTTPException: raise except Exception as e: logger.error(f"Error retrying remediation job: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to retry remediation job" + detail="Failed to retry remediation job", ) @router.get("/providers", response_model=List[RemediationProvider]) async def get_remediation_providers( - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ) -> List[RemediationProvider]: """ Get available remediation providers - + Returns information about all configured remediation providers including their status, capabilities, and configuration. """ try: providers = [] - + # AEGIS Provider aegis_status = await _check_aegis_status() - providers.append(RemediationProvider( - name="aegis", - version="1.0.0", - status=aegis_status["status"], - capabilities=[ - "automated_remediation", - "rule_based_fixes", - "rollback_support", - "verification_scans" - ], - supported_os=[ - "rhel8", "rhel9", "ubuntu20.04", "ubuntu22.04", - "centos8", "debian11" - ], - supported_frameworks=["STIG", "CIS", "PCI-DSS"], - configuration=aegis_status["config"] - )) - + providers.append( + RemediationProvider( + name="aegis", + version="1.0.0", + status=aegis_status["status"], + capabilities=[ + "automated_remediation", + "rule_based_fixes", + "rollback_support", + "verification_scans", + ], + supported_os=[ + "rhel8", + "rhel9", + "ubuntu20.04", + "ubuntu22.04", + "centos8", + "debian11", + ], + supported_frameworks=["STIG", "CIS", "PCI-DSS"], + configuration=aegis_status["config"], + ) + ) + # Ansible Provider (if configured) ansible_status = await _check_ansible_status() if ansible_status["available"]: - providers.append(RemediationProvider( - name="ansible", - version=ansible_status.get("version", "unknown"), + providers.append( + RemediationProvider( + name="ansible", + version=ansible_status.get("version", "unknown"), + status="available", + capabilities=[ + "playbook_execution", + "idempotent_operations", + "multi_host_support", + ], + supported_os=["linux", "unix"], + supported_frameworks=["custom"], + configuration=ansible_status["config"], + ) + ) + + # Manual Provider (always available) + providers.append( + RemediationProvider( + name="manual", + version="1.0.0", status="available", capabilities=[ - "playbook_execution", - "idempotent_operations", - "multi_host_support" + "guided_remediation", + "documentation_generation", + "compliance_tracking", ], - supported_os=["linux", "unix"], - supported_frameworks=["custom"], - configuration=ansible_status["config"] - )) - - # Manual Provider (always available) - providers.append(RemediationProvider( - name="manual", - version="1.0.0", - status="available", - capabilities=[ - "guided_remediation", - "documentation_generation", - "compliance_tracking" - ], - supported_os=["all"], - supported_frameworks=["all"], - configuration={"type": "manual"} - )) - + supported_os=["all"], + supported_frameworks=["all"], + configuration={"type": "manual"}, + ) + ) + return providers - + except Exception as e: logger.error(f"Error getting remediation providers: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to retrieve remediation providers" + detail="Failed to retrieve remediation providers", ) @router.get("/summary", response_model=RemediationSummary) async def get_remediation_summary( - current_user: dict = Depends(get_current_user), - db: Session = Depends(get_db) + current_user: dict = Depends(get_current_user), db: Session = Depends(get_db) ) -> RemediationSummary: """ Get remediation activity summary - + Returns aggregate statistics about remediation jobs and activities. """ try: # Query remediation jobs from scan metadata # This is a simplified implementation - in production you'd have a dedicated table - + # Count scans with remediation data - total_jobs = db.query(Scan).filter( - Scan.remediation_requested == True - ).count() - - active_jobs = db.query(Scan).filter( - Scan.remediation_status.in_(["pending", "running"]) - ).count() - - completed_jobs = db.query(Scan).filter( - Scan.remediation_status == "completed" - ).count() - - failed_jobs = db.query(Scan).filter( - Scan.remediation_status == "failed" - ).count() - - pending_jobs = db.query(Scan).filter( - Scan.remediation_status == "pending" - ).count() - + total_jobs = db.query(Scan).filter(Scan.remediation_requested == True).count() + + active_jobs = ( + db.query(Scan).filter(Scan.remediation_status.in_(["pending", "running"])).count() + ) + + completed_jobs = db.query(Scan).filter(Scan.remediation_status == "completed").count() + + failed_jobs = db.query(Scan).filter(Scan.remediation_status == "failed").count() + + pending_jobs = db.query(Scan).filter(Scan.remediation_status == "pending").count() + # Calculate success rate success_rate = 0.0 if total_jobs > 0: success_rate = (completed_jobs / total_jobs) * 100 - + return RemediationSummary( total_jobs=total_jobs, active_jobs=active_jobs, @@ -492,18 +478,14 @@ async def get_remediation_summary( pending_jobs=pending_jobs, success_rate=success_rate, average_duration_minutes=None, # Would calculate from actual data - last_24h={ - "jobs_started": 0, - "jobs_completed": 0, - "rules_fixed": 0 - } + last_24h={"jobs_started": 0, "jobs_completed": 0, "rules_fixed": 0}, ) - + except Exception as e: logger.error(f"Error getting remediation summary: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Failed to retrieve remediation summary" + detail="Failed to retrieve remediation summary", ) @@ -514,12 +496,12 @@ async def _execute_remediation_job( scan_id: UUID4, host_id: UUID4, failed_rules: List[str], - options: Dict[str, Any] + options: Dict[str, Any], ): """Execute remediation job in background""" try: logger.info(f"Starting remediation job {job_id} with provider {provider}") - + if provider == "aegis": await _execute_aegis_remediation(job_id, scan_id, host_id, failed_rules, options) elif provider == "ansible": @@ -528,7 +510,7 @@ async def _execute_remediation_job( await _execute_manual_remediation(job_id, scan_id, host_id, failed_rules, options) else: logger.error(f"Unknown remediation provider: {provider}") - + except Exception as e: logger.error(f"Error executing remediation job {job_id}: {e}") @@ -556,25 +538,18 @@ async def _execute_manual_remediation(job_id, scan_id, host_id, failed_rules, op logger.info(f"Manual remediation job {job_id} completed (simulated)") -async def _check_aegis_status(): +def _check_aegis_status(): """Check AEGIS provider status""" settings = get_settings() - aegis_url = getattr(settings, 'aegis_url', None) - + aegis_url = getattr(settings, "aegis_url", None) + if not aegis_url: - return { - "status": "unavailable", - "config": {"error": "AEGIS_URL not configured"} - } - + return {"status": "unavailable", "config": {"error": "AEGIS_URL not configured"}} + # Would check actual AEGIS connectivity here return { "status": "available", - "config": { - "url": aegis_url, - "webhook_configured": True, - "api_version": "v1" - } + "config": {"url": aegis_url, "webhook_configured": True, "api_version": "v1"}, } @@ -583,20 +558,14 @@ async def _check_ansible_status(): try: # Check if ansible is installed process = await asyncio.create_subprocess_exec( - 'ansible', '--version', - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + "ansible", "--version", stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE ) stdout, stderr = await process.communicate() - + if process.returncode == 0: - version = stdout.decode().split('\n')[0] - return { - "available": True, - "version": version, - "config": {"type": "ansible"} - } + version = stdout.decode().split("\n")[0] + return {"available": True, "version": version, "config": {"type": "ansible"}} except: pass - - return {"available": False} \ No newline at end of file + + return {"available": False} diff --git a/backend/app/routes/v1/scans.py b/backend/app/routes/v1/scans.py index 5749405b..ddaaadbd 100644 --- a/backend/app/routes/v1/scans.py +++ b/backend/app/routes/v1/scans.py @@ -2,6 +2,7 @@ OpenWatch API v1 - Scan Management Versioned scan management endpoints with enhanced capabilities """ + from fastapi import APIRouter, Depends, HTTPException, status, Query from typing import List, Optional from pydantic import BaseModel @@ -25,12 +26,10 @@ # Add v1-specific enhancements @router.get("/capabilities") -async def get_scan_capabilities( - current_user: dict = Depends(get_current_user) -): +async def get_scan_capabilities(current_user: dict = Depends(get_current_user)): """ Get scanning capabilities for API v1 - + Returns information about available scanning features, supported profiles, and scan limits. """ @@ -42,25 +41,25 @@ async def get_scan_capabilities( "custom_profiles": True, "scheduled_scanning": True, "bulk_scanning": True, - "real_time_progress": True + "real_time_progress": True, }, "limits": { "max_parallel_scans": 100, "max_hosts_per_scan": 1000, "scan_timeout_minutes": 60, - "max_scan_history": 10000 + "max_scan_history": 10000, }, "supported_formats": { "input": ["xml", "zip", "datastream"], - "output": ["xml", "html", "json", "arf"] + "output": ["xml", "html", "json", "arf"], }, "supported_profiles": [ "stig-rhel8", - "stig-rhel9", + "stig-rhel9", "cis-ubuntu-20.04", "cis-ubuntu-22.04", "pci-dss", - "custom" + "custom", ], "endpoints": { "list_scans": "GET /api/v1/scans", @@ -69,18 +68,16 @@ async def get_scan_capabilities( "cancel_scan": "DELETE /api/v1/scans/{scan_id}", "get_results": "GET /api/v1/scans/{scan_id}/results", "bulk_scan": "POST /api/v1/scans/bulk", - "capabilities": "GET /api/v1/scans/capabilities" - } + "capabilities": "GET /api/v1/scans/capabilities", + }, } @router.get("/summary") -async def get_scans_summary( - current_user: dict = Depends(get_current_user) -): +async def get_scans_summary(current_user: dict = Depends(get_current_user)): """ Get summary statistics for scan management (v1 specific) - + Returns aggregate information about scans, results, and compliance trends. """ return { @@ -88,28 +85,18 @@ async def get_scans_summary( "recent_scans": 0, "active_scans": 0, "failed_scans": 0, - "compliance_trend": { - "improving": 0, - "declining": 0, - "stable": 0 - }, + "compliance_trend": {"improving": 0, "declining": 0, "stable": 0}, "profile_usage": {}, "average_scan_time": None, - "last_24h": { - "scans_completed": 0, - "hosts_scanned": 0, - "critical_findings": 0 - } + "last_24h": {"scans_completed": 0, "hosts_scanned": 0, "critical_findings": 0}, } @router.get("/profiles") -async def get_available_profiles( - current_user: dict = Depends(get_current_user) -): +async def get_available_profiles(current_user: dict = Depends(get_current_user)): """ Get available SCAP profiles for scanning (v1 specific) - + Returns list of available profiles with metadata and compatibility info. """ return { @@ -122,11 +109,7 @@ async def get_available_profiles( "rules_count": 335, "supported_os": ["rhel8", "centos8"], "compliance_frameworks": ["STIG", "NIST"], - "severity_distribution": { - "high": 45, - "medium": 180, - "low": 110 - } + "severity_distribution": {"high": 45, "medium": 180, "low": 110}, }, { "id": "cis-ubuntu-20.04", @@ -136,13 +119,9 @@ async def get_available_profiles( "rules_count": 267, "supported_os": ["ubuntu20.04"], "compliance_frameworks": ["CIS"], - "severity_distribution": { - "high": 38, - "medium": 156, - "low": 73 - } - } + "severity_distribution": {"high": 38, "medium": 156, "low": 73}, + }, ], "total_profiles": 2, - "custom_profiles_supported": True - } \ No newline at end of file + "custom_profiles_supported": True, + } diff --git a/backend/app/routes/webhooks.py b/backend/app/routes/webhooks.py index 40011769..c42dc1df 100644 --- a/backend/app/routes/webhooks.py +++ b/backend/app/routes/webhooks.py @@ -2,6 +2,7 @@ Webhook Management API Routes Handles webhook endpoint registration and delivery tracking for AEGIS integration """ + import uuid import hashlib import hmac @@ -28,24 +29,24 @@ class WebhookEndpointCreate(BaseModel): url: str event_types: List[str] secret: str - - @validator('event_types') + + @validator("event_types") def validate_event_types(cls, v): valid_events = [ - 'scan.completed', - 'scan.failed', - 'remediation.completed', - 'remediation.failed' + "scan.completed", + "scan.failed", + "remediation.completed", + "remediation.failed", ] for event in v: if event not in valid_events: raise ValueError(f"Invalid event type: {event}. Must be one of: {valid_events}") return v - - @validator('url') + + @validator("url") def validate_url(cls, v): - if not v.startswith(('http://', 'https://')): - raise ValueError('URL must start with http:// or https://') + if not v.startswith(("http://", "https://")): + raise ValueError("URL must start with http:// or https://") return v @@ -55,26 +56,26 @@ class WebhookEndpointUpdate(BaseModel): event_types: Optional[List[str]] = None secret: Optional[str] = None is_active: Optional[bool] = None - - @validator('event_types') + + @validator("event_types") def validate_event_types(cls, v): if v is None: return v valid_events = [ - 'scan.completed', - 'scan.failed', - 'remediation.completed', - 'remediation.failed' + "scan.completed", + "scan.failed", + "remediation.completed", + "remediation.failed", ] for event in v: if event not in valid_events: raise ValueError(f"Invalid event type: {event}. Must be one of: {valid_events}") return v - - @validator('url') + + @validator("url") def validate_url(cls, v): - if v and not v.startswith(('http://', 'https://')): - raise ValueError('URL must start with http:// or https://') + if v and not v.startswith(("http://", "https://")): + raise ValueError("URL must start with http:// or https://") return v @@ -85,25 +86,25 @@ async def list_webhook_endpoints( limit: int = 50, offset: int = 0, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """List webhook endpoints with optional filtering""" try: # Build query conditions where_conditions = [] params = {"limit": limit, "offset": offset} - + if is_active is not None: where_conditions.append("is_active = :is_active") params["is_active"] = is_active - + if event_type: # Use JSON contains operator for PostgreSQL where_conditions.append("event_types::jsonb ? :event_type") params["event_type"] = event_type - + where_clause = "WHERE " + " AND ".join(where_conditions) if where_conditions else "" - + query = f""" SELECT id, name, url, event_types, is_active, created_by, created_at, updated_at FROM webhook_endpoints @@ -111,23 +112,27 @@ async def list_webhook_endpoints( ORDER BY created_at DESC LIMIT :limit OFFSET :offset """ - + result = db.execute(text(query), params) - + webhooks = [] for row in result: webhook_data = { "id": str(row.id), "name": row.name, "url": row.url, - "event_types": json.loads(row.event_types) if isinstance(row.event_types, str) else row.event_types, + "event_types": ( + json.loads(row.event_types) + if isinstance(row.event_types, str) + else row.event_types + ), "is_active": row.is_active, "created_by": row.created_by, "created_at": row.created_at.isoformat() if row.created_at else None, - "updated_at": row.updated_at.isoformat() if row.updated_at else None + "updated_at": row.updated_at.isoformat() if row.updated_at else None, } webhooks.append(webhook_data) - + # Get total count count_query = f""" SELECT COUNT(*) as total @@ -135,14 +140,9 @@ async def list_webhook_endpoints( {where_clause} """ total_result = db.execute(text(count_query), params).fetchone() - - return { - "webhooks": webhooks, - "total": total_result.total, - "limit": limit, - "offset": offset - } - + + return {"webhooks": webhooks, "total": total_result.total, "limit": limit, "offset": offset} + except Exception as e: logger.error(f"Error listing webhook endpoints: {e}") raise HTTPException(status_code=500, detail="Failed to retrieve webhook endpoints") @@ -152,45 +152,50 @@ async def list_webhook_endpoints( async def create_webhook_endpoint( webhook_request: WebhookEndpointCreate, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """Create a new webhook endpoint""" try: # Hash the secret for secure storage secret_hash = hashlib.sha256(webhook_request.secret.encode()).hexdigest() - + # Create webhook endpoint record - result = db.execute(text(""" + result = db.execute( + text( + """ INSERT INTO webhook_endpoints (id, name, url, event_types, secret_hash, is_active, created_by, created_at, updated_at) VALUES (:id, :name, :url, :event_types, :secret_hash, :is_active, :created_by, :created_at, :updated_at) RETURNING id - """), { - "id": str(uuid.uuid4()), - "name": webhook_request.name, - "url": webhook_request.url, - "event_types": json.dumps(webhook_request.event_types), - "secret_hash": secret_hash, - "is_active": True, - "created_by": current_user["id"], - "created_at": datetime.utcnow(), - "updated_at": datetime.utcnow() - }) - + """ + ), + { + "id": str(uuid.uuid4()), + "name": webhook_request.name, + "url": webhook_request.url, + "event_types": json.dumps(webhook_request.event_types), + "secret_hash": secret_hash, + "is_active": True, + "created_by": current_user["id"], + "created_at": datetime.utcnow(), + "updated_at": datetime.utcnow(), + }, + ) + webhook_id = result.fetchone().id db.commit() - + logger.info(f"Webhook endpoint created: {webhook_id}") - + return { "id": webhook_id, "name": webhook_request.name, "url": webhook_request.url, "event_types": webhook_request.event_types, "is_active": True, - "message": "Webhook endpoint created successfully" + "message": "Webhook endpoint created successfully", } - + except Exception as e: logger.error(f"Error creating webhook endpoint: {e}") raise HTTPException(status_code=500, detail="Failed to create webhook endpoint") @@ -198,31 +203,38 @@ async def create_webhook_endpoint( @router.get("/{webhook_id}") async def get_webhook_endpoint( - webhook_id: str, - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + webhook_id: str, db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """Get webhook endpoint details""" try: - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT id, name, url, event_types, is_active, created_by, created_at, updated_at FROM webhook_endpoints WHERE id = :id - """), {"id": webhook_id}).fetchone() - + """ + ), + {"id": webhook_id}, + ).fetchone() + if not result: raise HTTPException(status_code=404, detail="Webhook endpoint not found") - + return { "id": str(result.id), "name": result.name, "url": result.url, - "event_types": json.loads(result.event_types) if isinstance(result.event_types, str) else result.event_types, + "event_types": ( + json.loads(result.event_types) + if isinstance(result.event_types, str) + else result.event_types + ), "is_active": result.is_active, "created_by": result.created_by, "created_at": result.created_at.isoformat() if result.created_at else None, - "updated_at": result.updated_at.isoformat() if result.updated_at else None + "updated_at": result.updated_at.isoformat() if result.updated_at else None, } - + except HTTPException: raise except Exception as e: @@ -235,62 +247,67 @@ async def update_webhook_endpoint( webhook_id: str, webhook_update: WebhookEndpointUpdate, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """Update webhook endpoint""" try: # Check if webhook exists - existing = db.execute(text(""" + existing = db.execute( + text( + """ SELECT id FROM webhook_endpoints WHERE id = :id - """), {"id": webhook_id}).fetchone() - + """ + ), + {"id": webhook_id}, + ).fetchone() + if not existing: raise HTTPException(status_code=404, detail="Webhook endpoint not found") - + # Build update query with secure column mapping updates = [] params = {"id": webhook_id, "updated_at": datetime.utcnow()} - + # Security Fix: Use explicit column mapping instead of f-string concatenation allowed_updates = { "name": "name = :name", - "url": "url = :url", + "url": "url = :url", "event_types": "event_types = :event_types", "is_active": "is_active = :is_active", "secret": "secret_hash = :secret_hash", - "updated_at": "updated_at = :updated_at" + "updated_at": "updated_at = :updated_at", } - + if webhook_update.name is not None: updates.append(allowed_updates["name"]) params["name"] = webhook_update.name - + if webhook_update.url is not None: updates.append(allowed_updates["url"]) params["url"] = webhook_update.url - + if webhook_update.event_types is not None: updates.append(allowed_updates["event_types"]) params["event_types"] = json.dumps(webhook_update.event_types) - + if webhook_update.is_active is not None: updates.append(allowed_updates["is_active"]) params["is_active"] = webhook_update.is_active - + if webhook_update.secret is not None: updates.append(allowed_updates["secret"]) params["secret_hash"] = hashlib.sha256(webhook_update.secret.encode()).hexdigest() - + updates.append(allowed_updates["updated_at"]) - + if updates: # Security Fix: Use safe string concatenation instead of f-string query = "UPDATE webhook_endpoints SET " + ", ".join(updates) + " WHERE id = :id" db.execute(text(query), params) db.commit() - + return {"message": "Webhook endpoint updated successfully"} - + except HTTPException: raise except Exception as e: @@ -300,35 +317,48 @@ async def update_webhook_endpoint( @router.delete("/{webhook_id}") async def delete_webhook_endpoint( - webhook_id: str, - db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + webhook_id: str, db: Session = Depends(get_db), current_user: dict = Depends(get_current_user) ): """Delete webhook endpoint and its delivery history""" try: # Check if webhook exists - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT id FROM webhook_endpoints WHERE id = :id - """), {"id": webhook_id}).fetchone() - + """ + ), + {"id": webhook_id}, + ).fetchone() + if not result: raise HTTPException(status_code=404, detail="Webhook endpoint not found") - + # Delete webhook deliveries first (foreign key constraint) - db.execute(text(""" + db.execute( + text( + """ DELETE FROM webhook_deliveries WHERE webhook_id = :webhook_id - """), {"webhook_id": webhook_id}) - + """ + ), + {"webhook_id": webhook_id}, + ) + # Delete webhook endpoint - db.execute(text(""" + db.execute( + text( + """ DELETE FROM webhook_endpoints WHERE id = :id - """), {"id": webhook_id}) - + """ + ), + {"id": webhook_id}, + ) + db.commit() - + logger.info(f"Webhook endpoint deleted: {webhook_id}") return {"message": "Webhook endpoint deleted successfully"} - + except HTTPException: raise except Exception as e: @@ -343,28 +373,33 @@ async def get_webhook_deliveries( limit: int = 50, offset: int = 0, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """Get webhook delivery history""" try: # Verify webhook exists - webhook_result = db.execute(text(""" + webhook_result = db.execute( + text( + """ SELECT id FROM webhook_endpoints WHERE id = :id - """), {"id": webhook_id}).fetchone() - + """ + ), + {"id": webhook_id}, + ).fetchone() + if not webhook_result: raise HTTPException(status_code=404, detail="Webhook endpoint not found") - + # Build query conditions where_conditions = ["webhook_id = :webhook_id"] params = {"webhook_id": webhook_id, "limit": limit, "offset": offset} - + if delivery_status: where_conditions.append("delivery_status = :delivery_status") params["delivery_status"] = delivery_status - + where_clause = "WHERE " + " AND ".join(where_conditions) - + query = f""" SELECT id, event_type, event_data, delivery_status, http_status_code, response_body, error_message, created_at, delivered_at @@ -373,24 +408,28 @@ async def get_webhook_deliveries( ORDER BY created_at DESC LIMIT :limit OFFSET :offset """ - + result = db.execute(text(query), params) - + deliveries = [] for row in result: delivery_data = { "id": str(row.id), "event_type": row.event_type, - "event_data": json.loads(row.event_data) if isinstance(row.event_data, str) else row.event_data, + "event_data": ( + json.loads(row.event_data) + if isinstance(row.event_data, str) + else row.event_data + ), "delivery_status": row.delivery_status, "http_status_code": row.http_status_code, "response_body": row.response_body, "error_message": row.error_message, "created_at": row.created_at.isoformat() if row.created_at else None, - "delivered_at": row.delivered_at.isoformat() if row.delivered_at else None + "delivered_at": row.delivered_at.isoformat() if row.delivered_at else None, } deliveries.append(delivery_data) - + # Get total count count_query = f""" SELECT COUNT(*) as total @@ -398,14 +437,14 @@ async def get_webhook_deliveries( {where_clause} """ total_result = db.execute(text(count_query), params).fetchone() - + return { "deliveries": deliveries, "total": total_result.total, "limit": limit, - "offset": offset + "offset": offset, } - + except HTTPException: raise except Exception as e: @@ -418,19 +457,24 @@ async def test_webhook_endpoint( webhook_id: str, background_tasks: BackgroundTasks, db: Session = Depends(get_db), - current_user: dict = Depends(get_current_user) + current_user: dict = Depends(get_current_user), ): """Send a test webhook to verify connectivity""" try: # Get webhook details - webhook_result = db.execute(text(""" + webhook_result = db.execute( + text( + """ SELECT id, name, url, secret_hash FROM webhook_endpoints WHERE id = :id AND is_active = true - """), {"id": webhook_id}).fetchone() - + """ + ), + {"id": webhook_id}, + ).fetchone() + if not webhook_result: raise HTTPException(status_code=404, detail="Webhook endpoint not found or inactive") - + # Create test event data test_event = { "event_type": "test.webhook", @@ -438,28 +482,25 @@ async def test_webhook_endpoint( "webhook_id": webhook_id, "test_data": { "message": "This is a test webhook delivery", - "triggered_by": current_user.get("username", "system") - } + "triggered_by": current_user.get("username", "system"), + }, } - + # Queue webhook delivery as background task from ..tasks.webhook_tasks import deliver_webhook # Import here to avoid circular imports + background_tasks.add_task( - deliver_webhook, - webhook_result.url, - webhook_result.secret_hash, - test_event, - webhook_id + deliver_webhook, webhook_result.url, webhook_result.secret_hash, test_event, webhook_id ) - + return { "message": "Test webhook queued for delivery", "webhook_id": webhook_id, - "url": webhook_result.url + "url": webhook_result.url, } - + except HTTPException: raise except Exception as e: logger.error(f"Error testing webhook endpoint: {e}") - raise HTTPException(status_code=500, detail="Failed to test webhook endpoint") \ No newline at end of file + raise HTTPException(status_code=500, detail="Failed to test webhook endpoint") diff --git a/backend/app/services/__init__.py b/backend/app/services/__init__.py index c2ada237..1ac52eef 100644 --- a/backend/app/services/__init__.py +++ b/backend/app/services/__init__.py @@ -1 +1 @@ -# OpenWatch Services Module \ No newline at end of file +# OpenWatch Services Module diff --git a/backend/app/services/auth_service.py b/backend/app/services/auth_service.py index 53cb1cc6..2a2af8fe 100644 --- a/backend/app/services/auth_service.py +++ b/backend/app/services/auth_service.py @@ -3,6 +3,7 @@ Provides unified credential storage, encryption, and validation for OpenWatch. Replaces the dual-system approach with a single, consistent authentication layer. """ + import uuid import json import base64 @@ -24,6 +25,7 @@ class CredentialScope(str, Enum): """Credential scope types""" + SYSTEM = "system" HOST = "host" GROUP = "group" @@ -31,6 +33,7 @@ class CredentialScope(str, Enum): class AuthMethod(str, Enum): """Authentication method types""" + SSH_KEY = "ssh_key" PASSWORD = "password" BOTH = "both" @@ -38,6 +41,7 @@ class AuthMethod(str, Enum): class CredentialData(BaseModel): """Unified credential data structure""" + username: str auth_method: AuthMethod private_key: Optional[str] = None @@ -48,6 +52,7 @@ class CredentialData(BaseModel): class CredentialMetadata(BaseModel): """Credential metadata for storage""" + id: str = Field(default_factory=lambda: str(uuid.uuid4())) name: str description: Optional[str] = None @@ -62,24 +67,25 @@ class CentralizedAuthService: Centralized authentication service that provides unified credential management. Solves the issue where system credentials use AES encryption but host credentials only use base64. """ - + def __init__(self, db: Session): self.db = db - - def store_credential(self, credential_data: CredentialData, metadata: CredentialMetadata, - created_by: str) -> str: + + def store_credential( + self, credential_data: CredentialData, metadata: CredentialMetadata, created_by: str + ) -> str: """ Store credential with unified encryption and validation. All credentials use AES-256-GCM regardless of scope. - + Args: credential_data: The credential information to store metadata: Metadata about the credential (scope, target, etc.) created_by: User ID who is creating the credential - + Returns: str: The credential ID - + Raises: ValueError: If credential validation fails Exception: If storage fails @@ -89,33 +95,36 @@ def store_credential(self, credential_data: CredentialData, metadata: Credential validation_result = self.validate_credential(credential_data) if not validation_result[0]: raise ValueError(f"Credential validation failed: {validation_result[1]}") - + # Extract SSH key metadata if provided ssh_metadata = {} if credential_data.private_key: - ssh_metadata = self._extract_ssh_key_metadata(credential_data.private_key, - credential_data.private_key_passphrase) - + ssh_metadata = self._extract_ssh_key_metadata( + credential_data.private_key, credential_data.private_key_passphrase + ) + # If setting as default, unset other defaults in same scope if metadata.is_default: self._unset_default_credentials(metadata.scope, metadata.target_id) - + # Encrypt sensitive data using unified AES-256-GCM encrypted_password = None encrypted_private_key = None encrypted_passphrase = None - + if credential_data.password: encrypted_password = encrypt_data(credential_data.password.encode()) if credential_data.private_key: encrypted_private_key = encrypt_data(credential_data.private_key.encode()) if credential_data.private_key_passphrase: encrypted_passphrase = encrypt_data(credential_data.private_key_passphrase.encode()) - + # Store in unified credentials table current_time = datetime.utcnow() - - self.db.execute(text(""" + + self.db.execute( + text( + """ INSERT INTO unified_credentials (id, name, description, scope, target_id, username, auth_method, encrypted_password, encrypted_private_key, encrypted_passphrase, @@ -125,65 +134,75 @@ def store_credential(self, credential_data: CredentialData, metadata: Credential :encrypted_password, :encrypted_private_key, :encrypted_passphrase, :ssh_key_fingerprint, :ssh_key_type, :ssh_key_bits, :ssh_key_comment, :is_default, :is_active, :created_by, :created_at, :updated_at) - """), { - "id": metadata.id, - "name": metadata.name, - "description": metadata.description, - "scope": metadata.scope.value, - "target_id": metadata.target_id, - "username": credential_data.username, - "auth_method": credential_data.auth_method.value, - "encrypted_password": encrypted_password, - "encrypted_private_key": encrypted_private_key, - "encrypted_passphrase": encrypted_passphrase, - "ssh_key_fingerprint": ssh_metadata.get('fingerprint'), - "ssh_key_type": ssh_metadata.get('key_type'), - "ssh_key_bits": ssh_metadata.get('key_bits'), - "ssh_key_comment": ssh_metadata.get('key_comment'), - "is_default": metadata.is_default, - "is_active": metadata.is_active, - "created_by": created_by, - "created_at": current_time, - "updated_at": current_time - }) - + """ + ), + { + "id": metadata.id, + "name": metadata.name, + "description": metadata.description, + "scope": metadata.scope.value, + "target_id": metadata.target_id, + "username": credential_data.username, + "auth_method": credential_data.auth_method.value, + "encrypted_password": encrypted_password, + "encrypted_private_key": encrypted_private_key, + "encrypted_passphrase": encrypted_passphrase, + "ssh_key_fingerprint": ssh_metadata.get("fingerprint"), + "ssh_key_type": ssh_metadata.get("key_type"), + "ssh_key_bits": ssh_metadata.get("key_bits"), + "ssh_key_comment": ssh_metadata.get("key_comment"), + "is_default": metadata.is_default, + "is_active": metadata.is_active, + "created_by": created_by, + "created_at": current_time, + "updated_at": current_time, + }, + ) + self.db.commit() - - logger.info(f"Stored {metadata.scope.value} credential '{metadata.name}' (ID: {metadata.id})") + + logger.info( + f"Stored {metadata.scope.value} credential '{metadata.name}' (ID: {metadata.id})" + ) return metadata.id - + except Exception as e: logger.error(f"Failed to store credential: {e}") self.db.rollback() raise - + def get_credential(self, credential_id: str) -> Optional[CredentialData]: """ Retrieve and decrypt a specific credential by ID. - + Args: credential_id: The credential ID to retrieve - + Returns: CredentialData: The decrypted credential data, or None if not found """ try: - result = self.db.execute(text(""" + result = self.db.execute( + text( + """ SELECT username, auth_method, encrypted_password, encrypted_private_key, encrypted_passphrase, scope, target_id FROM unified_credentials WHERE id = :id AND is_active = true - """), {"id": credential_id}) - + """ + ), + {"id": credential_id}, + ) + row = result.fetchone() if not row: return None - + # Decrypt credential data password = None private_key = None passphrase = None - + if row.encrypted_password: # Handle both string and memoryview from database encrypted_data = row.encrypted_password @@ -191,12 +210,13 @@ def get_credential(self, credential_id: str) -> Optional[CredentialData]: # memoryview contains base64-encoded bytes - decode then decrypt import base64 from .encryption import get_encryption_service + decoded_bytes = base64.b64decode(bytes(encrypted_data)) password = get_encryption_service().decrypt(decoded_bytes).decode() else: # String data is base64 encoded - use decrypt_data password = decrypt_data(encrypted_data).decode() - + if row.encrypted_private_key: # Handle both string and memoryview from database encrypted_data = row.encrypted_private_key @@ -204,12 +224,13 @@ def get_credential(self, credential_id: str) -> Optional[CredentialData]: # memoryview contains base64-encoded bytes - decode then decrypt import base64 from .encryption import get_encryption_service + decoded_bytes = base64.b64decode(bytes(encrypted_data)) private_key = get_encryption_service().decrypt(decoded_bytes).decode() else: # String data is base64 encoded - use decrypt_data private_key = decrypt_data(encrypted_data).decode() - + if row.encrypted_passphrase: # Handle both string and memoryview from database encrypted_data = row.encrypted_passphrase @@ -217,84 +238,94 @@ def get_credential(self, credential_id: str) -> Optional[CredentialData]: # memoryview contains base64-encoded bytes - decode then decrypt import base64 from .encryption import get_encryption_service + decoded_bytes = base64.b64decode(bytes(encrypted_data)) passphrase = get_encryption_service().decrypt(decoded_bytes).decode() else: # String data is base64 encoded - use decrypt_data passphrase = decrypt_data(encrypted_data).decode() - + return CredentialData( username=row.username, auth_method=AuthMethod(row.auth_method), password=password, private_key=private_key, private_key_passphrase=passphrase, - source=f"{row.scope}:{row.target_id}" if row.target_id else row.scope + source=f"{row.scope}:{row.target_id}" if row.target_id else row.scope, ) - + except Exception as e: import traceback + logger.error(f"Failed to get credential {credential_id}: {e}") logger.error(f"Traceback: {traceback.format_exc()}") return None - - def resolve_credential(self, target_id: str = None, use_default: bool = False) -> Optional[CredentialData]: + + def resolve_credential( + self, target_id: str = None, use_default: bool = False + ) -> Optional[CredentialData]: """ Resolve effective credentials using inheritance logic. TEMPORARY: Using legacy system_credentials table until unified migration is complete. - + Resolution order: 1. If use_default=True -> legacy system default credential 2. If target_id provided -> target-specific credential (not implemented yet) 3. If target has no credential -> fallback to legacy system default - + Args: target_id: Target ID (host_id, group_id) to resolve credentials for use_default: Force use of system default credentials - + Returns: CredentialData: Resolved credential, or None if none available """ try: # Use unified credentials system (migration is now complete) - + if use_default or not target_id: logger.info(f"Using unified_credentials table for credential resolution") return self._get_system_default() - + # For now, host-specific credentials are not supported via unified system # Fall back to legacy system default - logger.info(f"No host-specific unified credentials supported yet, using legacy system default") + logger.info( + f"No host-specific unified credentials supported yet, using legacy system default" + ) return self._get_legacy_system_default() - + except Exception as e: logger.error(f"Failed to resolve credential: {e}") return None - + def _get_legacy_system_default(self) -> Optional[CredentialData]: """Get system default credential from legacy system_credentials table""" try: logger.info("Getting legacy system default credential from system_credentials table") - result = self.db.execute(text(""" + result = self.db.execute( + text( + """ SELECT id, username, auth_method, encrypted_password, encrypted_private_key, private_key_passphrase FROM system_credentials WHERE is_default = true AND is_active = true LIMIT 1 - """)) - + """ + ) + ) + row = result.fetchone() if row: logger.info("Found legacy system default credential, decrypting...") # Import decryption function for legacy credentials from .encryption import decrypt_data import base64 - + # Decrypt legacy credential data password = None private_key = None passphrase = None - + if row.encrypted_password: try: encrypted_data = row.encrypted_password @@ -302,6 +333,7 @@ def _get_legacy_system_default(self) -> Optional[CredentialData]: # memoryview contains base64-encoded bytes - decode then decrypt import base64 from .encryption import get_encryption_service + decoded_bytes = base64.b64decode(bytes(encrypted_data)) password = get_encryption_service().decrypt(decoded_bytes).decode() else: @@ -310,7 +342,7 @@ def _get_legacy_system_default(self) -> Optional[CredentialData]: logger.info("Successfully decrypted legacy password") except Exception as e: logger.warning(f"Failed to decrypt legacy password: {e}") - + if row.encrypted_private_key: try: encrypted_data = row.encrypted_private_key @@ -318,6 +350,7 @@ def _get_legacy_system_default(self) -> Optional[CredentialData]: # memoryview contains base64-encoded bytes - decode then decrypt import base64 from .encryption import get_encryption_service + decoded_bytes = base64.b64decode(bytes(encrypted_data)) private_key = get_encryption_service().decrypt(decoded_bytes).decode() else: @@ -326,7 +359,7 @@ def _get_legacy_system_default(self) -> Optional[CredentialData]: logger.info("Successfully decrypted legacy private key") except Exception as e: logger.warning(f"Failed to decrypt legacy private key: {e}") - + if row.private_key_passphrase: try: encrypted_data = row.private_key_passphrase @@ -334,6 +367,7 @@ def _get_legacy_system_default(self) -> Optional[CredentialData]: # memoryview contains base64-encoded bytes - decode then decrypt import base64 from .encryption import get_encryption_service + decoded_bytes = base64.b64decode(bytes(encrypted_data)) passphrase = get_encryption_service().decrypt(decoded_bytes).decode() else: @@ -342,65 +376,77 @@ def _get_legacy_system_default(self) -> Optional[CredentialData]: logger.info("Successfully decrypted legacy passphrase") except Exception as e: logger.warning(f"Failed to decrypt legacy passphrase: {e}") - + credential = CredentialData( username=row.username, auth_method=AuthMethod(row.auth_method), password=password, private_key=private_key, private_key_passphrase=passphrase, - source="legacy_system_default" + source="legacy_system_default", + ) + + logger.info( + f"Successfully resolved legacy system default credential for user: {row.username}" ) - - logger.info(f"Successfully resolved legacy system default credential for user: {row.username}") return credential - + logger.warning("No legacy system default credential found in system_credentials table") return None - + except Exception as e: logger.error(f"Failed to get legacy system default credential: {e}") return None - + def _get_system_default(self) -> Optional[CredentialData]: """Get system default credential with fallback to legacy system_credentials table""" try: # First try unified_credentials table (new system) - result = self.db.execute(text(""" + result = self.db.execute( + text( + """ SELECT id FROM unified_credentials WHERE scope = 'system' AND is_default = true AND is_active = true LIMIT 1 - """)) - + """ + ) + ) + row = result.fetchone() if row: credential = self.get_credential(row.id) if credential: credential.source = "system_default" return credential - + # Fallback to legacy system_credentials table - logger.warning("No unified system credentials found, checking legacy system_credentials table") - result = self.db.execute(text(""" + logger.warning( + "No unified system credentials found, checking legacy system_credentials table" + ) + result = self.db.execute( + text( + """ SELECT id, username, auth_method, encrypted_password, encrypted_private_key, private_key_passphrase FROM system_credentials WHERE is_default = true AND is_active = true LIMIT 1 - """)) - + """ + ) + ) + row = result.fetchone() if row: logger.warning("Found legacy system default credential, using it") # Import decryption function for legacy credentials from .encryption import decrypt_data import base64 - + # Decrypt legacy credential data password = None private_key = None passphrase = None - + if row.encrypted_password: try: encrypted_data = row.encrypted_password @@ -408,6 +454,7 @@ def _get_system_default(self) -> Optional[CredentialData]: # memoryview contains base64-encoded bytes - decode then decrypt import base64 from .encryption import get_encryption_service + decoded_bytes = base64.b64decode(bytes(encrypted_data)) password = get_encryption_service().decrypt(decoded_bytes).decode() else: @@ -415,7 +462,7 @@ def _get_system_default(self) -> Optional[CredentialData]: password = decrypt_data(encrypted_data).decode() except Exception as e: logger.warning(f"Failed to decrypt legacy password: {e}") - + if row.encrypted_private_key: try: encrypted_data = row.encrypted_private_key @@ -423,6 +470,7 @@ def _get_system_default(self) -> Optional[CredentialData]: # memoryview contains base64-encoded bytes - decode then decrypt import base64 from .encryption import get_encryption_service + decoded_bytes = base64.b64decode(bytes(encrypted_data)) private_key = get_encryption_service().decrypt(decoded_bytes).decode() else: @@ -430,7 +478,7 @@ def _get_system_default(self) -> Optional[CredentialData]: private_key = decrypt_data(encrypted_data).decode() except Exception as e: logger.warning(f"Failed to decrypt legacy private key: {e}") - + if row.private_key_passphrase: try: encrypted_data = row.private_key_passphrase @@ -438,6 +486,7 @@ def _get_system_default(self) -> Optional[CredentialData]: # memoryview contains base64-encoded bytes - decode then decrypt import base64 from .encryption import get_encryption_service + decoded_bytes = base64.b64decode(bytes(encrypted_data)) passphrase = get_encryption_service().decrypt(decoded_bytes).decode() else: @@ -445,32 +494,35 @@ def _get_system_default(self) -> Optional[CredentialData]: passphrase = decrypt_data(encrypted_data).decode() except Exception as e: logger.warning(f"Failed to decrypt legacy passphrase: {e}") - + return CredentialData( username=row.username, auth_method=AuthMethod(row.auth_method), password=password, private_key=private_key, private_key_passphrase=passphrase, - source="legacy_system_default" + source="legacy_system_default", ) - - logger.warning("No system default credential found in either unified_credentials or system_credentials") + + logger.warning( + "No system default credential found in either unified_credentials or system_credentials" + ) return None - + except Exception as e: logger.error(f"Failed to get system default credential: {e}") return None - - def validate_credential(self, credential_data: CredentialData, - strict_mode: bool = True) -> Tuple[bool, str]: + + def validate_credential( + self, credential_data: CredentialData, strict_mode: bool = True + ) -> Tuple[bool, str]: """ Validate credential data with strict security policy enforcement. - + Args: credential_data: The credential to validate strict_mode: Whether to enforce strict security policies (default: True) - + Returns: Tuple[bool, str]: (is_valid, error_message) """ @@ -478,15 +530,15 @@ def validate_credential(self, credential_data: CredentialData, # Basic format validation if not credential_data.username: return False, "Username is required" - + if credential_data.auth_method in [AuthMethod.PASSWORD, AuthMethod.BOTH]: if not credential_data.password: return False, "Password is required for password authentication" - + if credential_data.auth_method in [AuthMethod.SSH_KEY, AuthMethod.BOTH]: if not credential_data.private_key: return False, "SSH private key is required for key authentication" - + # Use strict validation by default (Security Fix 4) if strict_mode: policy_level = SecurityPolicyLevel.STRICT @@ -495,11 +547,13 @@ def validate_credential(self, credential_data: CredentialData, auth_method=credential_data.auth_method.value, private_key=credential_data.private_key, password=credential_data.password, - policy_level=policy_level + policy_level=policy_level, ) - + if not is_valid: - logger.warning(f"Credential rejected by strict security policy: {error_message}") + logger.warning( + f"Credential rejected by strict security policy: {error_message}" + ) return False, error_message else: # Legacy validation (only for compatibility) @@ -507,128 +561,150 @@ def validate_credential(self, credential_data: CredentialData, validation_result = validate_ssh_key(credential_data.private_key) if not validation_result.is_valid: return False, f"Invalid SSH key: {validation_result.error_message}" - + return True, "" - + except Exception as e: logger.error(f"Credential validation error: {e}") return False, f"Validation error: {str(e)}" - + def _extract_ssh_key_metadata(self, private_key: str, passphrase: str = None) -> Dict: """Extract SSH key metadata for storage""" try: metadata = extract_ssh_key_metadata(private_key, passphrase) return { - 'fingerprint': metadata.get('fingerprint'), - 'key_type': metadata.get('key_type'), - 'key_bits': int(metadata.get('key_bits')) if metadata.get('key_bits') else None, - 'key_comment': metadata.get('key_comment') + "fingerprint": metadata.get("fingerprint"), + "key_type": metadata.get("key_type"), + "key_bits": int(metadata.get("key_bits")) if metadata.get("key_bits") else None, + "key_comment": metadata.get("key_comment"), } except Exception as e: logger.warning(f"Failed to extract SSH key metadata: {e}") return {} - + def _unset_default_credentials(self, scope: CredentialScope, target_id: str = None): """Unset existing default credentials in the same scope""" try: if scope == CredentialScope.SYSTEM: - self.db.execute(text(""" + self.db.execute( + text( + """ UPDATE unified_credentials SET is_default = false WHERE scope = 'system' AND is_default = true - """)) + """ + ) + ) else: - self.db.execute(text(""" + self.db.execute( + text( + """ UPDATE unified_credentials SET is_default = false WHERE scope = :scope AND target_id = :target_id AND is_default = true - """), {"scope": scope.value, "target_id": target_id}) - + """ + ), + {"scope": scope.value, "target_id": target_id}, + ) + except Exception as e: logger.error(f"Failed to unset default credentials: {e}") - - def list_credentials(self, scope: CredentialScope = None, target_id: str = None, - user_id: str = None) -> List[Dict]: + + def list_credentials( + self, scope: CredentialScope = None, target_id: str = None, user_id: str = None + ) -> List[Dict]: """ List credentials with filtering options. - + Args: scope: Filter by credential scope target_id: Filter by target ID user_id: Filter by user (for access control) - + Returns: List[Dict]: List of credential metadata (no sensitive data) """ try: conditions = ["is_active = true"] params = {} - + if scope: conditions.append("scope = :scope") params["scope"] = scope.value - + if target_id: conditions.append("target_id = :target_id") params["target_id"] = target_id - + if user_id: conditions.append("created_by = :user_id") params["user_id"] = user_id - + where_clause = " AND ".join(conditions) - - result = self.db.execute(text(f""" + + result = self.db.execute( + text( + f""" SELECT id, name, description, scope, target_id, username, auth_method, ssh_key_fingerprint, ssh_key_type, ssh_key_bits, ssh_key_comment, is_default, created_at, updated_at FROM unified_credentials WHERE {where_clause} ORDER BY scope, is_default DESC, name - """), params) - + """ + ), + params, + ) + credentials = [] for row in result: - credentials.append({ - "id": row.id, - "name": row.name, - "description": row.description, - "scope": row.scope, - "target_id": row.target_id, - "username": row.username, - "auth_method": row.auth_method, - "ssh_key_fingerprint": row.ssh_key_fingerprint, - "ssh_key_type": row.ssh_key_type, - "ssh_key_bits": row.ssh_key_bits, - "ssh_key_comment": row.ssh_key_comment, - "is_default": row.is_default, - "created_at": row.created_at.isoformat(), - "updated_at": row.updated_at.isoformat() - }) - + credentials.append( + { + "id": row.id, + "name": row.name, + "description": row.description, + "scope": row.scope, + "target_id": row.target_id, + "username": row.username, + "auth_method": row.auth_method, + "ssh_key_fingerprint": row.ssh_key_fingerprint, + "ssh_key_type": row.ssh_key_type, + "ssh_key_bits": row.ssh_key_bits, + "ssh_key_comment": row.ssh_key_comment, + "is_default": row.is_default, + "created_at": row.created_at.isoformat(), + "updated_at": row.updated_at.isoformat(), + } + ) + return credentials - + except Exception as e: logger.error(f"Failed to list credentials: {e}") return [] - + def delete_credential(self, credential_id: str) -> bool: """ Soft delete a credential by marking it inactive. - + Args: credential_id: The credential ID to delete - + Returns: bool: True if successful, False otherwise """ try: - result = self.db.execute(text(""" + result = self.db.execute( + text( + """ UPDATE unified_credentials SET is_active = false, updated_at = :updated_at WHERE id = :id - """), {"id": credential_id, "updated_at": datetime.utcnow()}) - + """ + ), + {"id": credential_id, "updated_at": datetime.utcnow()}, + ) + if result.rowcount > 0: self.db.commit() logger.info(f"Deleted credential {credential_id}") @@ -636,7 +712,7 @@ def delete_credential(self, credential_id: str) -> bool: else: logger.warning(f"Credential {credential_id} not found for deletion") return False - + except Exception as e: logger.error(f"Failed to delete credential {credential_id}: {e}") self.db.rollback() @@ -646,4 +722,4 @@ def delete_credential(self, credential_id: str) -> bool: # Factory function for service creation def get_auth_service(db: Session) -> CentralizedAuthService: """Factory function to create CentralizedAuthService instance""" - return CentralizedAuthService(db) \ No newline at end of file + return CentralizedAuthService(db) diff --git a/backend/app/services/authorization_service.py b/backend/app/services/authorization_service.py index 1341bf8b..df4221a0 100644 --- a/backend/app/services/authorization_service.py +++ b/backend/app/services/authorization_service.py @@ -16,6 +16,7 @@ Design by Emily (Security Engineer) & Implementation by Daniel (Backend Engineer) """ + import logging import asyncio import time @@ -26,11 +27,22 @@ from concurrent.futures import ThreadPoolExecutor from ..models.authorization_models import ( - ResourceType, ActionType, PermissionEffect, PermissionScope, - AuthorizationDecision, ResourceIdentifier, AuthorizationContext, - AuthorizationResult, BulkAuthorizationRequest, BulkAuthorizationResult, - HostPermission, HostGroupPermission, AuthorizationAuditEvent, - PolicyConflictResolution, AuthorizationConfiguration, PermissionCache + ResourceType, + ActionType, + PermissionEffect, + PermissionScope, + AuthorizationDecision, + ResourceIdentifier, + AuthorizationContext, + AuthorizationResult, + BulkAuthorizationRequest, + BulkAuthorizationResult, + HostPermission, + HostGroupPermission, + AuthorizationAuditEvent, + PolicyConflictResolution, + AuthorizationConfiguration, + PermissionCache, ) from ..rbac import UserRole, RBACManager, Permission @@ -40,132 +52,142 @@ class AuthorizationService: """ Core authorization service implementing Zero Trust principles - + SECURITY FEATURES: 1. Resource-Based Access Control (ReBAC) - Permission validation per resource - 2. Zero Trust Architecture - Verify at every operation boundary + 2. Zero Trust Architecture - Verify at every operation boundary 3. Authorization Audit Trail - Complete logging of all access decisions 4. Least Privilege Enforcement - Users access only explicitly permitted resources 5. Cross-Host Validation - Prevents privilege escalation through bulk operations """ - + def __init__(self, db: Session, config: Optional[AuthorizationConfiguration] = None): self.db = db self.config = config or AuthorizationConfiguration() self.permission_cache = PermissionCache( - ttl_seconds=self.config.cache_ttl_seconds, - max_size=self.config.max_cache_size + ttl_seconds=self.config.cache_ttl_seconds, max_size=self.config.max_cache_size ) self.executor = ThreadPoolExecutor(max_workers=10) - + logger.info(f"Authorization service initialized with config: {self.config}") - + async def check_permission( self, user_id: str, resource: ResourceIdentifier, action: ActionType, - context: Optional[AuthorizationContext] = None + context: Optional[AuthorizationContext] = None, ) -> AuthorizationResult: """ Check if a user has permission to perform an action on a resource. - + ZERO TRUST PRINCIPLE: Every resource access must be explicitly validated. No assumptions or inheritance without verification. - + Args: user_id: User requesting access - resource: Resource being accessed + resource: Resource being accessed action: Action being performed context: Additional context for decision - + Returns: AuthorizationResult: Detailed authorization decision """ start_time = time.time() - + try: # Create default context if not provided if context is None: context = await self._build_user_context(user_id) - + # Check cache first for performance if self.config.cache_ttl_seconds > 0: cached_result = self.permission_cache.get(user_id, resource, action) if cached_result: - logger.debug(f"Cache hit for {user_id}:{resource.resource_type}:{resource.resource_id}:{action}") + logger.debug( + f"Cache hit for {user_id}:{resource.resource_type}:{resource.resource_id}:{action}" + ) return cached_result - + # Perform authorization evaluation result = await self._evaluate_permission(user_id, resource, action, context) - + # Calculate evaluation time evaluation_time = int((time.time() - start_time) * 1000) result.evaluation_time_ms = evaluation_time - + # Cache positive results if caching enabled if self.config.cache_ttl_seconds > 0 and result.decision == AuthorizationDecision.ALLOW: self.permission_cache.put(user_id, resource, action, result) - + # Audit log the decision if self.config.enable_audit_logging: await self._audit_authorization_decision(result, context) - + # Log security-relevant decisions if result.decision == AuthorizationDecision.DENY: - logger.warning(f"ACCESS DENIED: User {user_id} denied {action} on {resource.resource_type}:{resource.resource_id} - {result.reason}") + logger.warning( + f"ACCESS DENIED: User {user_id} denied {action} on {resource.resource_type}:{resource.resource_id} - {result.reason}" + ) else: - logger.debug(f"ACCESS GRANTED: User {user_id} allowed {action} on {resource.resource_type}:{resource.resource_id}") - + logger.debug( + f"ACCESS GRANTED: User {user_id} allowed {action} on {resource.resource_type}:{resource.resource_id}" + ) + return result - + except Exception as e: logger.error(f"Authorization check failed for user {user_id}: {e}") - + # Fail securely - deny access on error return AuthorizationResult( decision=AuthorizationDecision.DENY, resource=resource, action=action, - context=context or AuthorizationContext(user_id=user_id, user_roles=[], user_groups=[]), + context=context + or AuthorizationContext(user_id=user_id, user_roles=[], user_groups=[]), applied_policies=[], reason=f"Authorization system error: {str(e)}", confidence_score=0.0, - evaluation_time_ms=int((time.time() - start_time) * 1000) + evaluation_time_ms=int((time.time() - start_time) * 1000), ) - + async def check_bulk_permissions( - self, - request: BulkAuthorizationRequest + self, request: BulkAuthorizationRequest ) -> BulkAuthorizationResult: """ Check permissions for multiple resources in bulk. - + CRITICAL SECURITY FIX: This method prevents the vulnerability where bulk operations bypass per-host authorization checks. - + Each resource is individually validated to ensure users cannot access systems outside their permission scope. - + Args: request: Bulk authorization request - + Returns: BulkAuthorizationResult: Results for all resources """ start_time = time.time() - - logger.info(f"Bulk authorization check for user {request.user_id}: {len(request.resources)} resources, action {request.action}") - + + logger.info( + f"Bulk authorization check for user {request.user_id}: {len(request.resources)} resources, action {request.action}" + ) + try: individual_results = [] denied_resources = [] allowed_resources = [] cached_count = 0 fresh_count = 0 - + # Process resources based on configuration - if request.parallel_evaluation and len(request.resources) >= self.config.parallel_evaluation_threshold: + if ( + request.parallel_evaluation + and len(request.resources) >= self.config.parallel_evaluation_threshold + ): # Parallel evaluation for large requests individual_results = await self._evaluate_parallel_permissions( request.user_id, request.resources, request.action, request.context @@ -177,47 +199,59 @@ async def check_bulk_permissions( request.user_id, resource, request.action, request.context ) individual_results.append(result) - + if result.cached: cached_count += 1 else: fresh_count += 1 - + # Fail fast if configured and we hit a deny if request.fail_fast and result.decision == AuthorizationDecision.DENY: - logger.info(f"Fail-fast triggered: Access denied for resource {resource.resource_id}") + logger.info( + f"Fail-fast triggered: Access denied for resource {resource.resource_id}" + ) # Still need to create placeholder results for remaining resources - remaining_resources = request.resources[len(individual_results):] + remaining_resources = request.resources[len(individual_results) :] for remaining_resource in remaining_resources: - individual_results.append(AuthorizationResult( - decision=AuthorizationDecision.DENY, - resource=remaining_resource, - action=request.action, - context=request.context, - applied_policies=[], - reason="Bulk operation failed fast on previous denial", - confidence_score=1.0 - )) + individual_results.append( + AuthorizationResult( + decision=AuthorizationDecision.DENY, + resource=remaining_resource, + action=request.action, + context=request.context, + applied_policies=[], + reason="Bulk operation failed fast on previous denial", + confidence_score=1.0, + ) + ) break - + # Categorize results for result in individual_results: if result.decision == AuthorizationDecision.ALLOW: allowed_resources.append(result.resource) else: denied_resources.append(result.resource) - + # Determine overall decision - overall_decision = AuthorizationDecision.ALLOW if len(denied_resources) == 0 else AuthorizationDecision.DENY - + overall_decision = ( + AuthorizationDecision.ALLOW + if len(denied_resources) == 0 + else AuthorizationDecision.DENY + ) + total_time = int((time.time() - start_time) * 1000) - + # Audit bulk authorization attempt if self.config.enable_audit_logging: await self._audit_bulk_authorization( - request, overall_decision, len(allowed_resources), len(denied_resources), total_time + request, + overall_decision, + len(allowed_resources), + len(denied_resources), + total_time, ) - + result = BulkAuthorizationResult( overall_decision=overall_decision, individual_results=individual_results, @@ -225,18 +259,20 @@ async def check_bulk_permissions( allowed_resources=allowed_resources, total_evaluation_time_ms=total_time, cached_results=cached_count, - fresh_evaluations=fresh_count + fresh_evaluations=fresh_count, + ) + + logger.info( + f"Bulk authorization completed: {overall_decision.value} " + f"({len(allowed_resources)} allowed, {len(denied_resources)} denied) " + f"in {total_time}ms" ) - - logger.info(f"Bulk authorization completed: {overall_decision.value} " - f"({len(allowed_resources)} allowed, {len(denied_resources)} denied) " - f"in {total_time}ms") - + return result - + except Exception as e: logger.error(f"Bulk authorization failed: {e}") - + # Fail securely return BulkAuthorizationResult( overall_decision=AuthorizationDecision.DENY, @@ -248,7 +284,7 @@ async def check_bulk_permissions( context=request.context, applied_policies=[], reason=f"Bulk authorization system error: {str(e)}", - confidence_score=0.0 + confidence_score=0.0, ) for resource in request.resources ], @@ -256,21 +292,21 @@ async def check_bulk_permissions( allowed_resources=[], total_evaluation_time_ms=int((time.time() - start_time) * 1000), cached_results=0, - fresh_evaluations=0 + fresh_evaluations=0, ) - + async def _evaluate_permission( self, user_id: str, resource: ResourceIdentifier, action: ActionType, - context: AuthorizationContext + context: AuthorizationContext, ) -> AuthorizationResult: """ Core permission evaluation logic implementing Zero Trust principles """ applied_policies = [] - + try: # Step 1: Check if user exists and is active user_valid = await self._validate_user(user_id) @@ -281,25 +317,26 @@ async def _evaluate_permission( action=action, context=context, applied_policies=[], - reason="User not found or inactive" + reason="User not found or inactive", ) - + # Step 2: Get all applicable policies for this request policies = await self._get_applicable_policies(user_id, resource, action, context) - + # Step 3: Evaluate policies using conflict resolution strategy decision, reason = self._evaluate_policies(policies) applied_policies = policies - + # Step 4: Apply role-based permissions as additional validation - role_decision = await self._evaluate_role_permissions(user_id, resource, action, context) - + role_decision = await self._evaluate_role_permissions( + user_id, resource, action, context + ) + # Step 5: Combine policy and role decisions final_decision, final_reason = self._combine_decisions( - policy_decision=(decision, reason), - role_decision=role_decision + policy_decision=(decision, reason), role_decision=role_decision ) - + return AuthorizationResult( decision=final_decision, resource=resource, @@ -307,12 +344,12 @@ async def _evaluate_permission( context=context, applied_policies=applied_policies, reason=final_reason, - confidence_score=1.0 + confidence_score=1.0, ) - + except Exception as e: logger.error(f"Permission evaluation error: {e}") - + return AuthorizationResult( decision=AuthorizationDecision.DENY, resource=resource, @@ -320,22 +357,23 @@ async def _evaluate_permission( context=context, applied_policies=applied_policies, reason=f"Evaluation error: {str(e)}", - confidence_score=0.0 + confidence_score=0.0, ) - - async def _get_applicable_policies( + + def _get_applicable_policies( self, user_id: str, resource: ResourceIdentifier, action: ActionType, - context: AuthorizationContext + context: AuthorizationContext, ) -> List[Dict]: """ Get all policies that apply to this permission check """ try: # Build query to find applicable policies - query = text(""" + query = text( + """ SELECT hp.id, hp.user_id, hp.group_id, hp.role_name, hp.host_id, hp.actions, hp.effect, hp.conditions, hp.granted_by, @@ -387,82 +425,105 @@ async def _get_applicable_policies( ) ORDER BY granted_at DESC - """) - + """ + ) + # Convert user groups and roles to tuples for SQL IN clause user_groups = tuple(context.user_groups) if context.user_groups else (None,) user_roles = tuple(context.user_roles) if context.user_roles else (None,) - - result = self.db.execute(query, { - 'user_id': user_id, - 'resource_id': resource.resource_id, - 'user_groups': user_groups, - 'user_roles': user_roles, - 'now': datetime.utcnow() - }) - + + result = self.db.execute( + query, + { + "user_id": user_id, + "resource_id": resource.resource_id, + "user_groups": user_groups, + "user_roles": user_roles, + "now": datetime.utcnow(), + }, + ) + policies = [] for row in result: # Parse actions from JSON/string format try: import json - actions = json.loads(row.actions) if isinstance(row.actions, str) else row.actions + + actions = ( + json.loads(row.actions) if isinstance(row.actions, str) else row.actions + ) except: actions = [row.actions] if row.actions else [] - + # Check if this policy applies to the requested action - if action.value in actions or 'all' in actions: - policies.append({ - 'id': row.id, - 'user_id': row.user_id, - 'group_id': row.group_id, - 'role_name': row.role_name, - 'resource_id': row.host_id, - 'actions': actions, - 'effect': row.effect, - 'conditions': row.conditions, - 'policy_type': row.policy_type, - 'granted_by': row.granted_by, - 'granted_at': row.granted_at - }) - - logger.debug(f"Found {len(policies)} applicable policies for user {user_id} on resource {resource.resource_id}") + if action.value in actions or "all" in actions: + policies.append( + { + "id": row.id, + "user_id": row.user_id, + "group_id": row.group_id, + "role_name": row.role_name, + "resource_id": row.host_id, + "actions": actions, + "effect": row.effect, + "conditions": row.conditions, + "policy_type": row.policy_type, + "granted_by": row.granted_by, + "granted_at": row.granted_at, + } + ) + + logger.debug( + f"Found {len(policies)} applicable policies for user {user_id} on resource {resource.resource_id}" + ) return policies - + except Exception as e: logger.error(f"Error getting applicable policies: {e}") return [] - + def _evaluate_policies(self, policies: List[Dict]) -> Tuple[AuthorizationDecision, str]: """ Evaluate policies based on conflict resolution strategy """ if not policies: return AuthorizationDecision.DENY, "No applicable policies found" - - allow_policies = [p for p in policies if p['effect'] == 'allow'] - deny_policies = [p for p in policies if p['effect'] == 'deny'] - + + allow_policies = [p for p in policies if p["effect"] == "allow"] + deny_policies = [p for p in policies if p["effect"] == "deny"] + if self.config.conflict_resolution == PolicyConflictResolution.DENY_OVERRIDES: if deny_policies: - return AuthorizationDecision.DENY, f"Access explicitly denied by {len(deny_policies)} deny policies" + return ( + AuthorizationDecision.DENY, + f"Access explicitly denied by {len(deny_policies)} deny policies", + ) elif allow_policies: - return AuthorizationDecision.ALLOW, f"Access granted by {len(allow_policies)} allow policies" - + return ( + AuthorizationDecision.ALLOW, + f"Access granted by {len(allow_policies)} allow policies", + ) + elif self.config.conflict_resolution == PolicyConflictResolution.ALLOW_OVERRIDES: if allow_policies: - return AuthorizationDecision.ALLOW, f"Access granted by {len(allow_policies)} allow policies" + return ( + AuthorizationDecision.ALLOW, + f"Access granted by {len(allow_policies)} allow policies", + ) elif deny_policies: - return AuthorizationDecision.DENY, f"Access denied by {len(deny_policies)} deny policies" - + return ( + AuthorizationDecision.DENY, + f"Access denied by {len(deny_policies)} deny policies", + ) + return self.config.default_decision, "Applied default decision" - - async def _evaluate_role_permissions( + + def _evaluate_role_permissions( self, user_id: str, resource: ResourceIdentifier, action: ActionType, - context: AuthorizationContext + context: AuthorizationContext, ) -> Tuple[AuthorizationDecision, str]: """ Evaluate role-based permissions as additional validation layer @@ -475,59 +536,73 @@ async def _evaluate_role_permissions( ActionType.EXECUTE: Permission.SCAN_EXECUTE, ActionType.WRITE: Permission.HOST_UPDATE, ActionType.DELETE: Permission.HOST_DELETE, - ActionType.MANAGE: Permission.HOST_MANAGE_ACCESS + ActionType.MANAGE: Permission.HOST_MANAGE_ACCESS, } - + required_permission = action_permission_map.get(action) if not required_permission: return AuthorizationDecision.DENY, f"No role permission mapping for action {action}" - + # Check if any user role has the required permission for role_name in context.user_roles: try: user_role = UserRole(role_name) if RBACManager.has_permission(user_role, required_permission): - return AuthorizationDecision.ALLOW, f"Role {role_name} has permission {required_permission.value}" + return ( + AuthorizationDecision.ALLOW, + f"Role {role_name} has permission {required_permission.value}", + ) except ValueError: logger.warning(f"Unknown role: {role_name}") continue - - return AuthorizationDecision.DENY, f"No user role has permission {required_permission.value}" - + + return ( + AuthorizationDecision.DENY, + f"No user role has permission {required_permission.value}", + ) + except Exception as e: logger.error(f"Role permission evaluation error: {e}") return AuthorizationDecision.DENY, f"Role evaluation error: {str(e)}" - + def _combine_decisions( self, policy_decision: Tuple[AuthorizationDecision, str], - role_decision: Tuple[AuthorizationDecision, str] + role_decision: Tuple[AuthorizationDecision, str], ) -> Tuple[AuthorizationDecision, str]: """ Combine policy-based and role-based authorization decisions """ policy_allow = policy_decision[0] == AuthorizationDecision.ALLOW role_allow = role_decision[0] == AuthorizationDecision.ALLOW - + # Both must allow for final allow decision (Zero Trust principle) if policy_allow and role_allow: - return AuthorizationDecision.ALLOW, f"Both policy and role checks passed: {policy_decision[1]} AND {role_decision[1]}" - + return ( + AuthorizationDecision.ALLOW, + f"Both policy and role checks passed: {policy_decision[1]} AND {role_decision[1]}", + ) + # If either denies, deny access if not policy_allow and not role_allow: - return AuthorizationDecision.DENY, f"Both policy and role checks failed: {policy_decision[1]} AND {role_decision[1]}" + return ( + AuthorizationDecision.DENY, + f"Both policy and role checks failed: {policy_decision[1]} AND {role_decision[1]}", + ) elif not policy_allow: return AuthorizationDecision.DENY, f"Policy check failed: {policy_decision[1]}" else: return AuthorizationDecision.DENY, f"Role check failed: {role_decision[1]}" - - async def _build_user_context(self, user_id: str) -> AuthorizationContext: + + def _build_user_context(self, user_id: str) -> AuthorizationContext: """ Build authorization context for a user """ try: # Get user information including roles and groups - result = self.db.execute(text(""" + result = self.db.execute( + text( + """ SELECT u.id, u.username, u.role, COALESCE( JSON_AGG(DISTINCT ug.name) FILTER (WHERE ug.name IS NOT NULL), @@ -538,52 +613,49 @@ async def _build_user_context(self, user_id: str) -> AuthorizationContext: LEFT JOIN user_groups ug ON ugm.group_id = ug.id WHERE u.id = :user_id AND u.is_active = true GROUP BY u.id, u.username, u.role - """), {"user_id": user_id}) - + """ + ), + {"user_id": user_id}, + ) + row = result.fetchone() if not row: - return AuthorizationContext( - user_id=user_id, - user_roles=[], - user_groups=[] - ) - + return AuthorizationContext(user_id=user_id, user_roles=[], user_groups=[]) + import json + user_groups = json.loads(row.user_groups) if row.user_groups else [] - + return AuthorizationContext( - user_id=user_id, - user_roles=[row.role] if row.role else [], - user_groups=user_groups + user_id=user_id, user_roles=[row.role] if row.role else [], user_groups=user_groups ) - + except Exception as e: logger.error(f"Error building user context for {user_id}: {e}") - return AuthorizationContext( - user_id=user_id, - user_roles=[], - user_groups=[] - ) - - async def _validate_user(self, user_id: str) -> bool: + return AuthorizationContext(user_id=user_id, user_roles=[], user_groups=[]) + + def _validate_user(self, user_id: str) -> bool: """ Validate user exists and is active """ try: - result = self.db.execute(text(""" + result = self.db.execute( + text( + """ SELECT id FROM users WHERE id = :user_id AND is_active = true - """), {"user_id": user_id}) - + """ + ), + {"user_id": user_id}, + ) + return result.fetchone() is not None - + except Exception as e: logger.error(f"User validation error for {user_id}: {e}") return False - - async def _audit_authorization_decision( - self, - result: AuthorizationResult, - context: AuthorizationContext + + def _audit_authorization_decision( + self, result: AuthorizationResult, context: AuthorizationContext ): """ Audit authorization decisions for security monitoring @@ -596,24 +668,26 @@ async def _audit_authorization_decision( resource_id=result.resource.resource_id, action=result.action, decision=result.decision, - policies_evaluated=[p.get('id', 'unknown') for p in result.applied_policies], + policies_evaluated=[p.get("id", "unknown") for p in result.applied_policies], context={ - 'user_roles': context.user_roles, - 'user_groups': context.user_groups, - 'ip_address': context.ip_address, - 'user_agent': context.user_agent, - 'session_id': context.session_id + "user_roles": context.user_roles, + "user_groups": context.user_groups, + "ip_address": context.ip_address, + "user_agent": context.user_agent, + "session_id": context.session_id, }, ip_address=context.ip_address, user_agent=context.user_agent, session_id=context.session_id, evaluation_time_ms=result.evaluation_time_ms, reason=result.reason, - risk_score=self._calculate_risk_score(result) + risk_score=self._calculate_risk_score(result), ) - + # Store audit event in database - self.db.execute(text(""" + self.db.execute( + text( + """ INSERT INTO authorization_audit_log (id, event_type, user_id, resource_type, resource_id, action, decision, policies_evaluated, context, ip_address, user_agent, session_id, @@ -621,37 +695,40 @@ async def _audit_authorization_decision( VALUES (:id, :event_type, :user_id, :resource_type, :resource_id, :action, :decision, :policies_evaluated, :context, :ip_address, :user_agent, :session_id, :evaluation_time_ms, :reason, :risk_score, :timestamp) - """), { - 'id': audit_event.id, - 'event_type': audit_event.event_type, - 'user_id': audit_event.user_id, - 'resource_type': audit_event.resource_type.value, - 'resource_id': audit_event.resource_id, - 'action': audit_event.action.value, - 'decision': audit_event.decision.value, - 'policies_evaluated': ','.join(audit_event.policies_evaluated), - 'context': str(audit_event.context), - 'ip_address': audit_event.ip_address, - 'user_agent': audit_event.user_agent, - 'session_id': audit_event.session_id, - 'evaluation_time_ms': audit_event.evaluation_time_ms, - 'reason': audit_event.reason, - 'risk_score': audit_event.risk_score, - 'timestamp': audit_event.timestamp - }) - + """ + ), + { + "id": audit_event.id, + "event_type": audit_event.event_type, + "user_id": audit_event.user_id, + "resource_type": audit_event.resource_type.value, + "resource_id": audit_event.resource_id, + "action": audit_event.action.value, + "decision": audit_event.decision.value, + "policies_evaluated": ",".join(audit_event.policies_evaluated), + "context": str(audit_event.context), + "ip_address": audit_event.ip_address, + "user_agent": audit_event.user_agent, + "session_id": audit_event.session_id, + "evaluation_time_ms": audit_event.evaluation_time_ms, + "reason": audit_event.reason, + "risk_score": audit_event.risk_score, + "timestamp": audit_event.timestamp, + }, + ) + self.db.commit() - + except Exception as e: logger.error(f"Failed to audit authorization decision: {e}") - - async def _audit_bulk_authorization( + + def _audit_bulk_authorization( self, request: BulkAuthorizationRequest, decision: AuthorizationDecision, allowed_count: int, denied_count: int, - evaluation_time_ms: int + evaluation_time_ms: int, ): """ Audit bulk authorization attempts @@ -666,22 +743,24 @@ async def _audit_bulk_authorization( decision=decision, policies_evaluated=[], context={ - 'resource_count': len(request.resources), - 'allowed_count': allowed_count, - 'denied_count': denied_count, - 'fail_fast': request.fail_fast, - 'parallel_evaluation': request.parallel_evaluation + "resource_count": len(request.resources), + "allowed_count": allowed_count, + "denied_count": denied_count, + "fail_fast": request.fail_fast, + "parallel_evaluation": request.parallel_evaluation, }, ip_address=request.context.ip_address, user_agent=request.context.user_agent, session_id=request.context.session_id, evaluation_time_ms=evaluation_time_ms, reason=f"Bulk authorization: {allowed_count} allowed, {denied_count} denied", - risk_score=self._calculate_bulk_risk_score(denied_count, len(request.resources)) + risk_score=self._calculate_bulk_risk_score(denied_count, len(request.resources)), ) - + # Store bulk audit event - self.db.execute(text(""" + self.db.execute( + text( + """ INSERT INTO authorization_audit_log (id, event_type, user_id, resource_type, resource_id, action, decision, policies_evaluated, context, ip_address, user_agent, session_id, @@ -689,66 +768,69 @@ async def _audit_bulk_authorization( VALUES (:id, :event_type, :user_id, :resource_type, :resource_id, :action, :decision, :policies_evaluated, :context, :ip_address, :user_agent, :session_id, :evaluation_time_ms, :reason, :risk_score, :timestamp) - """), { - 'id': audit_event.id, - 'event_type': audit_event.event_type, - 'user_id': audit_event.user_id, - 'resource_type': audit_event.resource_type.value, - 'resource_id': audit_event.resource_id, - 'action': audit_event.action.value, - 'decision': audit_event.decision.value, - 'policies_evaluated': '', - 'context': str(audit_event.context), - 'ip_address': audit_event.ip_address, - 'user_agent': audit_event.user_agent, - 'session_id': audit_event.session_id, - 'evaluation_time_ms': audit_event.evaluation_time_ms, - 'reason': audit_event.reason, - 'risk_score': audit_event.risk_score, - 'timestamp': audit_event.timestamp - }) - + """ + ), + { + "id": audit_event.id, + "event_type": audit_event.event_type, + "user_id": audit_event.user_id, + "resource_type": audit_event.resource_type.value, + "resource_id": audit_event.resource_id, + "action": audit_event.action.value, + "decision": audit_event.decision.value, + "policies_evaluated": "", + "context": str(audit_event.context), + "ip_address": audit_event.ip_address, + "user_agent": audit_event.user_agent, + "session_id": audit_event.session_id, + "evaluation_time_ms": audit_event.evaluation_time_ms, + "reason": audit_event.reason, + "risk_score": audit_event.risk_score, + "timestamp": audit_event.timestamp, + }, + ) + self.db.commit() - + except Exception as e: logger.error(f"Failed to audit bulk authorization: {e}") - + def _calculate_risk_score(self, result: AuthorizationResult) -> float: """ Calculate risk score for authorization decision """ if not self.config.enable_risk_scoring: return 0.0 - + risk_score = 0.0 - + # Higher risk for denied access attempts if result.decision == AuthorizationDecision.DENY: risk_score += 0.3 - + # Higher risk for sensitive actions if result.action in [ActionType.DELETE, ActionType.MANAGE]: risk_score += 0.2 - + # Higher risk for system-level resources if result.resource.resource_type == ResourceType.SYSTEM: risk_score += 0.3 - + # Higher risk for long evaluation times (possible attack) if result.evaluation_time_ms > 500: risk_score += 0.2 - + return min(1.0, risk_score) - + def _calculate_bulk_risk_score(self, denied_count: int, total_count: int) -> float: """ Calculate risk score for bulk operations """ if total_count == 0: return 0.0 - + denial_ratio = denied_count / total_count - + # High denial ratio indicates possible unauthorized access attempt if denial_ratio > 0.5: return 0.8 @@ -758,13 +840,13 @@ def _calculate_bulk_risk_score(self, denied_count: int, total_count: int) -> flo return 0.3 else: return 0.1 - + async def _evaluate_parallel_permissions( self, user_id: str, resources: List[ResourceIdentifier], action: ActionType, - context: AuthorizationContext + context: AuthorizationContext, ) -> List[AuthorizationResult]: """ Evaluate permissions for multiple resources in parallel @@ -774,27 +856,29 @@ async def _evaluate_parallel_permissions( for resource in resources: task = self.check_permission(user_id, resource, action, context) tasks.append(task) - + results = await asyncio.gather(*tasks, return_exceptions=True) - + valid_results = [] for i, result in enumerate(results): if isinstance(result, Exception): # Handle exceptions by creating deny result - valid_results.append(AuthorizationResult( - decision=AuthorizationDecision.DENY, - resource=resources[i], - action=action, - context=context, - applied_policies=[], - reason=f"Parallel evaluation error: {str(result)}", - confidence_score=0.0 - )) + valid_results.append( + AuthorizationResult( + decision=AuthorizationDecision.DENY, + resource=resources[i], + action=action, + context=context, + applied_policies=[], + reason=f"Parallel evaluation error: {str(result)}", + confidence_score=0.0, + ) + ) else: valid_results.append(result) - + return valid_results - + except Exception as e: logger.error(f"Parallel permission evaluation failed: {e}") # Return deny results for all resources @@ -806,14 +890,14 @@ async def _evaluate_parallel_permissions( context=context, applied_policies=[], reason=f"Parallel evaluation system error: {str(e)}", - confidence_score=0.0 + confidence_score=0.0, ) for resource in resources ] - + # Permission Management Methods - - async def grant_host_permission( + + def grant_host_permission( self, user_id: Optional[str], group_id: Optional[str], @@ -822,7 +906,7 @@ async def grant_host_permission( actions: Set[ActionType], granted_by: str, expires_at: Optional[datetime] = None, - conditions: Optional[Dict[str, Any]] = None + conditions: Optional[Dict[str, Any]] = None, ) -> str: """ Grant permission for a specific host @@ -836,70 +920,86 @@ async def grant_host_permission( actions=actions, granted_by=granted_by, expires_at=expires_at, - conditions=conditions or {} + conditions=conditions or {}, ) - + # Store in database import json - self.db.execute(text(""" + + self.db.execute( + text( + """ INSERT INTO host_permissions (id, user_id, group_id, role_name, host_id, actions, effect, conditions, granted_by, granted_at, expires_at, is_active) VALUES (:id, :user_id, :group_id, :role_name, :host_id, :actions, :effect, :conditions, :granted_by, :granted_at, :expires_at, :is_active) - """), { - 'id': permission.id, - 'user_id': permission.user_id, - 'group_id': permission.group_id, - 'role_name': permission.role_name, - 'host_id': permission.host_id, - 'actions': json.dumps(list(actions)), - 'effect': permission.effect.value, - 'conditions': json.dumps(permission.conditions), - 'granted_by': permission.granted_by, - 'granted_at': permission.granted_at, - 'expires_at': permission.expires_at, - 'is_active': permission.is_active - }) - + """ + ), + { + "id": permission.id, + "user_id": permission.user_id, + "group_id": permission.group_id, + "role_name": permission.role_name, + "host_id": permission.host_id, + "actions": json.dumps(list(actions)), + "effect": permission.effect.value, + "conditions": json.dumps(permission.conditions), + "granted_by": permission.granted_by, + "granted_at": permission.granted_at, + "expires_at": permission.expires_at, + "is_active": permission.is_active, + }, + ) + self.db.commit() - + # Invalidate cache for affected user/resource if user_id: self.permission_cache.invalidate_user(user_id) - + resource = ResourceIdentifier(ResourceType.HOST, host_id) self.permission_cache.invalidate_resource(resource) - + logger.info(f"Granted host permission {permission.id} for host {host_id}") return permission.id - + except Exception as e: logger.error(f"Failed to grant host permission: {e}") self.db.rollback() raise - - async def revoke_permission(self, permission_id: str) -> bool: + + def revoke_permission(self, permission_id: str) -> bool: """ Revoke a specific permission """ try: - result = self.db.execute(text(""" + result = self.db.execute( + text( + """ UPDATE host_permissions SET is_active = false, updated_at = :now WHERE id = :permission_id - """), {'permission_id': permission_id, 'now': datetime.utcnow()}) - + """ + ), + {"permission_id": permission_id, "now": datetime.utcnow()}, + ) + if result.rowcount == 0: # Try host group permissions - result = self.db.execute(text(""" + result = self.db.execute( + text( + """ UPDATE host_group_permissions SET is_active = false, updated_at = :now WHERE id = :permission_id - """), {'permission_id': permission_id, 'now': datetime.utcnow()}) - + """ + ), + {"permission_id": permission_id, "now": datetime.utcnow()}, + ) + self.db.commit() - + if result.rowcount > 0: # Clear entire cache since we don't know which users/resources were affected self.permission_cache.clear() @@ -908,7 +1008,7 @@ async def revoke_permission(self, permission_id: str) -> bool: else: logger.warning(f"Permission {permission_id} not found for revocation") return False - + except Exception as e: logger.error(f"Failed to revoke permission {permission_id}: {e}") self.db.rollback() @@ -916,6 +1016,8 @@ async def revoke_permission(self, permission_id: str) -> bool: # Factory function -def get_authorization_service(db: Session, config: Optional[AuthorizationConfiguration] = None) -> AuthorizationService: +def get_authorization_service( + db: Session, config: Optional[AuthorizationConfiguration] = None +) -> AuthorizationService: """Factory function to create AuthorizationService instance""" - return AuthorizationService(db, config) \ No newline at end of file + return AuthorizationService(db, config) diff --git a/backend/app/services/bulk_scan_orchestrator.py b/backend/app/services/bulk_scan_orchestrator.py index 45bc707a..2105f285 100644 --- a/backend/app/services/bulk_scan_orchestrator.py +++ b/backend/app/services/bulk_scan_orchestrator.py @@ -15,6 +15,7 @@ Design by Emily (Security Engineer) & Implementation by Daniel (Backend Engineer) """ + import logging import asyncio import uuid @@ -30,8 +31,12 @@ from .scan_intelligence import ScanIntelligenceService, HostInfo from .authorization_service import get_authorization_service from ..models.authorization_models import ( - ResourceType, ActionType, ResourceIdentifier, AuthorizationContext, - AuthorizationDecision, BulkAuthorizationRequest + ResourceType, + ActionType, + ResourceIdentifier, + AuthorizationContext, + AuthorizationDecision, + BulkAuthorizationRequest, ) from ..tasks.scan_tasks import execute_scan_task @@ -49,6 +54,7 @@ class ScanSessionStatus(Enum): @dataclass class ScanBatch: """A batch of scans to execute together""" + id: str hosts: List[HostInfo] content_id: int @@ -56,11 +62,12 @@ class ScanBatch: priority: int # Lower number = higher priority estimated_time: float # minutes max_parallel: int = 3 - - + + @dataclass class ScanSession: """Tracks a bulk scanning session with authorization metadata""" + id: str name: str total_hosts: int @@ -79,7 +86,7 @@ class ScanSession: authorized_hosts: int = 0 unauthorized_hosts: int = 0 authorization_failures: List[Dict] = None - + def __post_init__(self): if self.authorization_failures is None: self.authorization_failures = [] @@ -88,12 +95,13 @@ def __post_init__(self): @dataclass class AuthorizationFailure: """Represents an authorization failure for a specific host""" + host_id: str hostname: str reason: str user_id: str timestamp: datetime = None - + def __post_init__(self): if self.timestamp is None: self.timestamp = datetime.utcnow() @@ -102,37 +110,37 @@ def __post_init__(self): class BulkScanOrchestrator: """ Orchestrates bulk scanning operations across multiple hosts with authorization validation - + SECURITY ENHANCEMENTS: - Per-host authorization validation before scan creation - Comprehensive audit logging of authorization decisions - Separation of authorized and unauthorized hosts - Zero Trust implementation preventing privilege escalation """ - + def __init__(self, db: Session): self.db = db self.intelligence_service = ScanIntelligenceService(db) self.authorization_service = get_authorization_service(db) - + async def create_bulk_scan_session( - self, - host_ids: List[str], + self, + host_ids: List[str], template_id: str = "auto", name_prefix: str = "Bulk Scan", priority: str = "normal", user_id: str = None, stagger_delay: int = 30, - auth_context: Optional[AuthorizationContext] = None + auth_context: Optional[AuthorizationContext] = None, ) -> ScanSession: """ Create a bulk scan session with intelligent batching and per-host authorization validation - + CRITICAL SECURITY FIX: This method now validates user permissions for each host individually before creating any scans, preventing unauthorized access to systems outside the user's permission scope. - + Args: host_ids: List of host IDs to scan template_id: Scan template/profile to use @@ -141,44 +149,56 @@ async def create_bulk_scan_session( user_id: User creating the bulk scan stagger_delay: Delay between scan starts in seconds auth_context: Authorization context (will be built if not provided) - + Returns: ScanSession: Session with authorization results and only authorized scans - + Raises: ValueError: If no hosts are authorized for scanning PermissionError: If user lacks bulk scan permissions """ try: session_id = str(uuid.uuid4()) - logger.info(f"Creating bulk scan session {session_id} for {len(host_ids)} hosts by user {user_id}") - + logger.info( + f"Creating bulk scan session {session_id} for {len(host_ids)} hosts by user {user_id}" + ) + # SECURITY CHECK 1: Validate user exists and is active if not user_id: raise ValueError("User ID is required for bulk scan operations") - + # SECURITY CHECK 2: Per-host authorization validation logger.info(f"Performing authorization checks for {len(host_ids)} hosts") authorized_hosts, authorization_failures = await self._validate_bulk_scan_authorization( user_id, host_ids, auth_context ) - + if not authorized_hosts: - logger.warning(f"User {user_id} has no authorized hosts for bulk scan. {len(authorization_failures)} failures.") - raise PermissionError(f"No hosts authorized for scanning. Access denied to all {len(host_ids)} requested hosts.") - + logger.warning( + f"User {user_id} has no authorized hosts for bulk scan. {len(authorization_failures)} failures." + ) + raise PermissionError( + f"No hosts authorized for scanning. Access denied to all {len(host_ids)} requested hosts." + ) + # Log authorization summary - logger.info(f"Authorization complete: {len(authorized_hosts)} authorized, {len(authorization_failures)} denied") - + logger.info( + f"Authorization complete: {len(authorized_hosts)} authorized, {len(authorization_failures)} denied" + ) + # Filter host_ids to only include authorized hosts authorized_host_ids = [host.host_id for host in authorized_hosts] - + # Analyze bulk scan feasibility for authorized hosts only - feasibility = await self.intelligence_service.analyze_bulk_scan_feasibility(authorized_host_ids) + feasibility = await self.intelligence_service.analyze_bulk_scan_feasibility( + authorized_host_ids + ) if not feasibility["feasible"]: - logger.warning(f"Bulk scan not feasible for authorized hosts: {feasibility['reason']}") + logger.warning( + f"Bulk scan not feasible for authorized hosts: {feasibility['reason']}" + ) raise ValueError(f"Bulk scan not feasible: {feasibility['reason']}") - + # Create scan session record with authorization metadata session = ScanSession( id=session_id, @@ -191,23 +211,29 @@ async def create_bulk_scan_session( created_by=user_id, created_at=datetime.utcnow(), scan_ids=[], - estimated_completion=datetime.utcnow() + timedelta(minutes=feasibility.get("estimated_time_minutes", 60)), + estimated_completion=datetime.utcnow() + + timedelta(minutes=feasibility.get("estimated_time_minutes", 60)), authorized_hosts=len(authorized_hosts), unauthorized_hosts=len(authorization_failures), - authorization_failures=[{ - 'host_id': failure.host_id, - 'hostname': failure.hostname, - 'reason': failure.reason, - 'timestamp': failure.timestamp.isoformat() - } for failure in authorization_failures] + authorization_failures=[ + { + "host_id": failure.host_id, + "hostname": failure.hostname, + "reason": failure.reason, + "timestamp": failure.timestamp.isoformat(), + } + for failure in authorization_failures + ], ) - + # Store session in database with authorization metadata await self._store_scan_session(session) - + # Plan scan execution for authorized hosts only - scan_plan = await self._plan_bulk_scan(authorized_host_ids, template_id, session_id, priority) - + scan_plan = await self._plan_bulk_scan( + authorized_host_ids, template_id, session_id, priority + ) + # Store individual scans for authorized hosts scan_ids = [] for batch in scan_plan: @@ -216,25 +242,29 @@ async def create_bulk_scan_session( batch, session_id, name_prefix, user_id, stagger_delay, authorized_hosts ) scan_ids.extend(batch_scan_ids) - + # Update session with scan IDs session.scan_ids = scan_ids await self._update_scan_session(session) - + # Log final results if authorization_failures: - logger.warning(f"Bulk scan session {session_id} created with {len(scan_ids)} scans. " - f"Authorization denied for {len(authorization_failures)} hosts.") + logger.warning( + f"Bulk scan session {session_id} created with {len(scan_ids)} scans. " + f"Authorization denied for {len(authorization_failures)} hosts." + ) else: - logger.info(f"Bulk scan session {session_id} created successfully with {len(scan_ids)} scans. " - f"All {len(authorized_hosts)} hosts authorized.") - + logger.info( + f"Bulk scan session {session_id} created successfully with {len(scan_ids)} scans. " + f"All {len(authorized_hosts)} hosts authorized." + ) + return session - + except Exception as e: logger.error(f"Error creating bulk scan session: {e}") raise - + async def start_bulk_scan_session(self, session_id: str) -> Dict: """Start executing a bulk scan session""" try: @@ -242,30 +272,30 @@ async def start_bulk_scan_session(self, session_id: str) -> Dict: session = await self._get_scan_session(session_id) if not session: raise ValueError(f"Session {session_id} not found") - + if session.status != ScanSessionStatus.PENDING: raise ValueError(f"Session {session_id} is not in pending status") - + # Update session status session.status = ScanSessionStatus.RUNNING session.started_at = datetime.utcnow() await self._update_scan_session(session) - + # Start scans with staggered execution started_scans = await self._execute_staggered_scans(session.scan_ids) - + logger.info(f"Started bulk scan session {session_id} with {len(started_scans)} scans") return { "session_id": session_id, "status": "started", "started_scans": len(started_scans), - "total_scans": len(session.scan_ids) + "total_scans": len(session.scan_ids), } - + except Exception as e: logger.error(f"Error starting bulk scan session {session_id}: {e}") raise - + async def get_bulk_scan_progress(self, session_id: str) -> Dict: """Get real-time progress of a bulk scan session""" try: @@ -273,33 +303,33 @@ async def get_bulk_scan_progress(self, session_id: str) -> Dict: session = await self._get_scan_session(session_id) if not session: raise ValueError(f"Session {session_id} not found") - + # Get individual scan statuses scan_statuses = await self._get_scans_status(session.scan_ids) - + # Calculate progress metrics total_scans = len(session.scan_ids) completed = sum(1 for s in scan_statuses if s["status"] == "completed") failed = sum(1 for s in scan_statuses if s["status"] == "failed") running = sum(1 for s in scan_statuses if s["status"] in ["pending", "running"]) - + # Update session stats session.completed_hosts = completed session.failed_hosts = failed session.running_hosts = running - + # Determine overall session status if completed + failed == total_scans: session.status = ScanSessionStatus.COMPLETED session.completed_at = datetime.utcnow() elif failed > 0 and running == 0: session.status = ScanSessionStatus.FAILED - + await self._update_scan_session(session) - + # Calculate progress percentage progress_percent = int((completed / total_scans) * 100) if total_scans > 0 else 0 - + return { "session_id": session_id, "session_name": session.name, @@ -310,15 +340,21 @@ async def get_bulk_scan_progress(self, session_id: str) -> Dict: "failed_hosts": failed, "running_hosts": running, "started_at": session.started_at.isoformat() if session.started_at else None, - "estimated_completion": session.estimated_completion.isoformat() if session.estimated_completion else None, - "individual_scans": scan_statuses + "estimated_completion": ( + session.estimated_completion.isoformat() + if session.estimated_completion + else None + ), + "individual_scans": scan_statuses, } - + except Exception as e: logger.error(f"Error getting bulk scan progress for {session_id}: {e}") raise - - async def _plan_bulk_scan(self, host_ids: List[str], template_id: str, session_id: str, priority: str) -> List[ScanBatch]: + + async def _plan_bulk_scan( + self, host_ids: List[str], template_id: str, session_id: str, priority: str + ) -> List[ScanBatch]: """Plan the execution of bulk scans with intelligent batching""" try: # Get host information @@ -327,10 +363,10 @@ async def _plan_bulk_scan(self, host_ids: List[str], template_id: str, session_i host_info = await self.intelligence_service._get_host_info(host_id) if host_info: host_infos.append(host_info) - + if not host_infos: raise ValueError("No valid hosts found for bulk scan") - + # Group hosts by OS family for content optimization os_groups = {} for host in host_infos: @@ -338,29 +374,31 @@ async def _plan_bulk_scan(self, host_ids: List[str], template_id: str, session_i if os_family not in os_groups: os_groups[os_family] = [] os_groups[os_family].append(host) - + # Create scan batches batches = [] batch_priority = 1 - + for os_family, hosts in os_groups.items(): # Find best content and profile for this OS group - content_id, profile_id = await self._find_optimal_content_profile(hosts, template_id) - + content_id, profile_id = await self._find_optimal_content_profile( + hosts, template_id + ) + # Split large groups into smaller batches (max 10 hosts per batch) max_batch_size = 10 for i in range(0, len(hosts), max_batch_size): - batch_hosts = hosts[i:i + max_batch_size] - + batch_hosts = hosts[i : i + max_batch_size] + # Calculate estimated time based on host count and profile complexity estimated_time = len(batch_hosts) * 10 # Base 10 minutes per host - + # Adjust for profile complexity if "stig" in profile_id.lower(): estimated_time *= 1.5 # STIG scans take longer elif "cis" in profile_id.lower(): estimated_time *= 1.2 # CIS scans slightly longer - + batch = ScanBatch( id=str(uuid.uuid4()), hosts=batch_hosts, @@ -368,22 +406,22 @@ async def _plan_bulk_scan(self, host_ids: List[str], template_id: str, session_i profile_id=profile_id, priority=batch_priority, estimated_time=estimated_time, - max_parallel=min(3, len(batch_hosts)) + max_parallel=min(3, len(batch_hosts)), ) batches.append(batch) - + batch_priority += 1 - + # Sort batches by priority (production hosts first, then by estimated time) batches.sort(key=lambda b: (b.priority, b.estimated_time)) - + logger.info(f"Created {len(batches)} scan batches for session {session_id}") return batches - + except Exception as e: logger.error(f"Error planning bulk scan: {e}") raise - + def _extract_os_family(self, operating_system: str) -> str: """Extract OS family from operating system string""" os_lower = operating_system.lower() @@ -397,8 +435,10 @@ def _extract_os_family(self, operating_system: str) -> str: return "windows" else: return "unknown" - - async def _find_optimal_content_profile(self, hosts: List[HostInfo], template_id: str) -> Tuple[int, str]: + + async def _find_optimal_content_profile( + self, hosts: List[HostInfo], template_id: str + ) -> Tuple[int, str]: """Find the optimal SCAP content and profile for a group of hosts""" # For now, use the intelligence service to suggest for the first host # In a more sophisticated implementation, this would analyze all hosts @@ -407,91 +447,109 @@ async def _find_optimal_content_profile(self, hosts: List[HostInfo], template_id return suggestion.content_id, suggestion.profile_id else: # Use default content and specified template - return 1, template_id if template_id != "auto" else "xccdf_org.ssgproject.content_profile_cui" - - async def _create_batch_scans(self, batch: ScanBatch, session_id: str, name_prefix: str, user_id: str, stagger_delay: int) -> List[str]: + return 1, ( + template_id if template_id != "auto" else "xccdf_org.ssgproject.content_profile_cui" + ) + + def _create_batch_scans( + self, batch: ScanBatch, session_id: str, name_prefix: str, user_id: str, stagger_delay: int + ) -> List[str]: """Create individual scan records for a batch""" try: scan_ids = [] - + for i, host in enumerate(batch.hosts): scan_id = str(uuid.uuid4()) scan_name = f"{name_prefix} - {host.hostname}" - + # Calculate staggered start time start_delay = i * stagger_delay # seconds - + # Create scan record - self.db.execute(text(""" + self.db.execute( + text( + """ INSERT INTO scans (id, name, host_id, content_id, profile_id, status, progress, scan_options, started_by, started_at, remediation_requested, verification_scan) VALUES (:id, :name, :host_id, :content_id, :profile_id, :status, :progress, :scan_options, :started_by, :started_at, :remediation_requested, :verification_scan) - """), { - "id": scan_id, - "name": scan_name, - "host_id": host.id, - "content_id": batch.content_id, - "profile_id": batch.profile_id, - "status": "pending", - "progress": 0, - "scan_options": json.dumps({ - "bulk_scan": True, - "session_id": session_id, - "batch_id": batch.id, - "start_delay": start_delay - }), - "started_by": user_id, - "started_at": datetime.utcnow(), - "remediation_requested": False, - "verification_scan": False - }) - + """ + ), + { + "id": scan_id, + "name": scan_name, + "host_id": host.id, + "content_id": batch.content_id, + "profile_id": batch.profile_id, + "status": "pending", + "progress": 0, + "scan_options": json.dumps( + { + "bulk_scan": True, + "session_id": session_id, + "batch_id": batch.id, + "start_delay": start_delay, + } + ), + "started_by": user_id, + "started_at": datetime.utcnow(), + "remediation_requested": False, + "verification_scan": False, + }, + ) + scan_ids.append(scan_id) - + self.db.commit() return scan_ids - + except Exception as e: logger.error(f"Error creating batch scans: {e}") raise - - async def _store_scan_session(self, session: ScanSession): + + def _store_scan_session(self, session: ScanSession): """Store scan session in database""" try: # Create a scan sessions table record (you'll need to create this table) - self.db.execute(text(""" + self.db.execute( + text( + """ INSERT INTO scan_sessions (id, name, total_hosts, completed_hosts, failed_hosts, running_hosts, status, created_by, created_at, started_at, completed_at, estimated_completion, scan_ids, error_message) VALUES (:id, :name, :total_hosts, :completed_hosts, :failed_hosts, :running_hosts, :status, :created_by, :created_at, :started_at, :completed_at, :estimated_completion, :scan_ids, :error_message) - """), { - "id": session.id, - "name": session.name, - "total_hosts": session.total_hosts, - "completed_hosts": session.completed_hosts, - "failed_hosts": session.failed_hosts, - "running_hosts": session.running_hosts, - "status": session.status.value, - "created_by": session.created_by, - "created_at": session.created_at, - "started_at": session.started_at, - "completed_at": session.completed_at, - "estimated_completion": session.estimated_completion, - "scan_ids": json.dumps(session.scan_ids or []), - "error_message": session.error_message - }) + """ + ), + { + "id": session.id, + "name": session.name, + "total_hosts": session.total_hosts, + "completed_hosts": session.completed_hosts, + "failed_hosts": session.failed_hosts, + "running_hosts": session.running_hosts, + "status": session.status.value, + "created_by": session.created_by, + "created_at": session.created_at, + "started_at": session.started_at, + "completed_at": session.completed_at, + "estimated_completion": session.estimated_completion, + "scan_ids": json.dumps(session.scan_ids or []), + "error_message": session.error_message, + }, + ) self.db.commit() except Exception as e: logger.error(f"Error storing scan session: {e}") raise - - async def _update_scan_session(self, session: ScanSession): + + def _update_scan_session(self, session: ScanSession): """Update scan session in database""" try: - self.db.execute(text(""" + self.db.execute( + text( + """ UPDATE scan_sessions SET completed_hosts = :completed_hosts, failed_hosts = :failed_hosts, @@ -502,34 +560,42 @@ async def _update_scan_session(self, session: ScanSession): scan_ids = :scan_ids, error_message = :error_message WHERE id = :id - """), { - "id": session.id, - "completed_hosts": session.completed_hosts, - "failed_hosts": session.failed_hosts, - "running_hosts": session.running_hosts, - "status": session.status.value, - "started_at": session.started_at, - "completed_at": session.completed_at, - "scan_ids": json.dumps(session.scan_ids or []), - "error_message": session.error_message - }) + """ + ), + { + "id": session.id, + "completed_hosts": session.completed_hosts, + "failed_hosts": session.failed_hosts, + "running_hosts": session.running_hosts, + "status": session.status.value, + "started_at": session.started_at, + "completed_at": session.completed_at, + "scan_ids": json.dumps(session.scan_ids or []), + "error_message": session.error_message, + }, + ) self.db.commit() except Exception as e: logger.error(f"Error updating scan session: {e}") raise - - async def _get_scan_session(self, session_id: str) -> Optional[ScanSession]: + + def _get_scan_session(self, session_id: str) -> Optional[ScanSession]: """Retrieve scan session from database""" try: - result = self.db.execute(text(""" + result = self.db.execute( + text( + """ SELECT id, name, total_hosts, completed_hosts, failed_hosts, running_hosts, status, created_by, created_at, started_at, completed_at, estimated_completion, scan_ids, error_message FROM scan_sessions WHERE id = :id - """), {"id": session_id}).fetchone() - + """ + ), + {"id": session_id}, + ).fetchone() + if not result: return None - + return ScanSession( id=result.id, name=result.name, @@ -544,22 +610,24 @@ async def _get_scan_session(self, session_id: str) -> Optional[ScanSession]: completed_at=result.completed_at, estimated_completion=result.estimated_completion, scan_ids=json.loads(result.scan_ids or "[]"), - error_message=result.error_message + error_message=result.error_message, ) except Exception as e: logger.error(f"Error getting scan session: {e}") return None - - async def _get_scans_status(self, scan_ids: List[str]) -> List[Dict]: + + def _get_scans_status(self, scan_ids: List[str]) -> List[Dict]: """Get status of multiple scans""" if not scan_ids: return [] - + try: # Create placeholders for the IN clause - placeholders = ','.join([f"'{scan_id}'" for scan_id in scan_ids]) - - result = self.db.execute(text(f""" + placeholders = ",".join([f"'{scan_id}'" for scan_id in scan_ids]) + + result = self.db.execute( + text( + f""" SELECT s.id, s.name, s.status, s.progress, s.started_at, s.completed_at, h.hostname, h.display_name, sr.score, sr.failed_rules, sr.total_rules @@ -568,78 +636,84 @@ async def _get_scans_status(self, scan_ids: List[str]) -> List[Dict]: LEFT JOIN scan_results sr ON sr.scan_id = s.id WHERE s.id IN ({placeholders}) ORDER BY s.started_at - """)).fetchall() - + """ + ) + ).fetchall() + scan_statuses = [] for row in result: - scan_statuses.append({ - "scan_id": row.id, - "scan_name": row.name, - "hostname": row.hostname, - "display_name": row.display_name, - "status": row.status, - "progress": row.progress, - "started_at": row.started_at.isoformat() if row.started_at else None, - "completed_at": row.completed_at.isoformat() if row.completed_at else None, - "compliance_score": row.score, - "failed_rules": row.failed_rules or 0, - "total_rules": row.total_rules or 0 - }) - + scan_statuses.append( + { + "scan_id": row.id, + "scan_name": row.name, + "hostname": row.hostname, + "display_name": row.display_name, + "status": row.status, + "progress": row.progress, + "started_at": row.started_at.isoformat() if row.started_at else None, + "completed_at": row.completed_at.isoformat() if row.completed_at else None, + "compliance_score": row.score, + "failed_rules": row.failed_rules or 0, + "total_rules": row.total_rules or 0, + } + ) + return scan_statuses - + except Exception as e: logger.error(f"Error getting scans status: {e}") return [] - - async def _execute_staggered_scans(self, scan_ids: List[str]) -> List[str]: + + def _execute_staggered_scans(self, scan_ids: List[str]) -> List[str]: """Execute scans with staggered start times""" # For now, just update all scans to running status # In a production system, this would integrate with Celery or similar try: if not scan_ids: return [] - - placeholders = ','.join([f"'{scan_id}'" for scan_id in scan_ids]) - + + placeholders = ",".join([f"'{scan_id}'" for scan_id in scan_ids]) + # Update scan status to running - self.db.execute(text(f""" + self.db.execute( + text( + f""" UPDATE scans SET status = 'running', started_at = :started_at WHERE id IN ({placeholders}) AND status = 'pending' - """), {"started_at": datetime.utcnow()}) - + """ + ), + {"started_at": datetime.utcnow()}, + ) + self.db.commit() - + # In a real implementation, you would start the actual scan tasks here # For now, return the scan IDs that were updated return scan_ids - + except Exception as e: logger.error(f"Error executing staggered scans: {e}") return [] - + # AUTHORIZATION METHODS - CRITICAL SECURITY IMPLEMENTATION - + async def _validate_bulk_scan_authorization( - self, - user_id: str, - host_ids: List[str], - auth_context: Optional[AuthorizationContext] = None - ) -> Tuple[List['AuthorizedHost'], List[AuthorizationFailure]]: + self, user_id: str, host_ids: List[str], auth_context: Optional[AuthorizationContext] = None + ) -> Tuple[List["AuthorizedHost"], List[AuthorizationFailure]]: """ Validate user authorization for each host in bulk scan request - + ZERO TRUST IMPLEMENTATION: - Individual validation for each host - No implicit permissions or inheritance - Comprehensive audit trail - Fail-secure behavior - + Args: user_id: User requesting bulk scan host_ids: List of host IDs to validate auth_context: Optional authorization context - + Returns: Tuple of (authorized_hosts, authorization_failures) """ @@ -647,16 +721,13 @@ async def _validate_bulk_scan_authorization( # Build authorization context if not provided if auth_context is None: auth_context = await self._build_user_authorization_context(user_id) - + # Create resource identifiers for all hosts resources = [ - ResourceIdentifier( - resource_type=ResourceType.HOST, - resource_id=host_id - ) + ResourceIdentifier(resource_type=ResourceType.HOST, resource_id=host_id) for host_id in host_ids ] - + # Perform bulk authorization check bulk_request = BulkAuthorizationRequest( user_id=user_id, @@ -664,74 +735,86 @@ async def _validate_bulk_scan_authorization( action=ActionType.SCAN, context=auth_context, fail_fast=False, # Check all hosts to provide complete results - parallel_evaluation=True # Enable parallel processing + parallel_evaluation=True, # Enable parallel processing ) - + logger.debug(f"Performing bulk authorization check for {len(resources)} hosts") auth_result = await self.authorization_service.check_bulk_permissions(bulk_request) - + # Get host details for results host_details = await self._get_host_details(host_ids) - host_lookup = {h['id']: h for h in host_details} - + host_lookup = {h["id"]: h for h in host_details} + # Process authorization results authorized_hosts = [] authorization_failures = [] - + for result in auth_result.individual_results: host_id = result.resource.resource_id - host_detail = host_lookup.get(host_id, {'hostname': 'unknown', 'display_name': 'unknown'}) - + host_detail = host_lookup.get( + host_id, {"hostname": "unknown", "display_name": "unknown"} + ) + if result.decision == AuthorizationDecision.ALLOW: - authorized_hosts.append(AuthorizedHost( - host_id=host_id, - hostname=host_detail.get('hostname', 'unknown'), - display_name=host_detail.get('display_name', 'unknown'), - authorization_reason=result.reason - )) + authorized_hosts.append( + AuthorizedHost( + host_id=host_id, + hostname=host_detail.get("hostname", "unknown"), + display_name=host_detail.get("display_name", "unknown"), + authorization_reason=result.reason, + ) + ) else: - authorization_failures.append(AuthorizationFailure( - host_id=host_id, - hostname=host_detail.get('hostname', 'unknown'), - reason=result.reason, - user_id=user_id - )) - + authorization_failures.append( + AuthorizationFailure( + host_id=host_id, + hostname=host_detail.get("hostname", "unknown"), + reason=result.reason, + user_id=user_id, + ) + ) + # Log authorization summary for security monitoring - logger.info(f"Bulk authorization results for user {user_id}: " - f"{len(authorized_hosts)} authorized, {len(authorization_failures)} denied, " - f"evaluation time: {auth_result.total_evaluation_time_ms}ms") - + logger.info( + f"Bulk authorization results for user {user_id}: " + f"{len(authorized_hosts)} authorized, {len(authorization_failures)} denied, " + f"evaluation time: {auth_result.total_evaluation_time_ms}ms" + ) + # Log denied hosts for security audit if authorization_failures: denied_host_ids = [f.host_id for f in authorization_failures] - logger.warning(f"Authorization denied for user {user_id} on hosts: {denied_host_ids}") - + logger.warning( + f"Authorization denied for user {user_id} on hosts: {denied_host_ids}" + ) + return authorized_hosts, authorization_failures - + except Exception as e: logger.error(f"Bulk authorization validation failed: {e}") - + # Fail securely - treat all hosts as unauthorized host_details = await self._get_host_details(host_ids) authorization_failures = [ AuthorizationFailure( - host_id=host_detail['id'], - hostname=host_detail.get('hostname', 'unknown'), + host_id=host_detail["id"], + hostname=host_detail.get("hostname", "unknown"), reason=f"Authorization system error: {str(e)}", - user_id=user_id + user_id=user_id, ) for host_detail in host_details ] - + return [], authorization_failures - - async def _build_user_authorization_context(self, user_id: str) -> AuthorizationContext: + + def _build_user_authorization_context(self, user_id: str) -> AuthorizationContext: """ Build authorization context for a user including roles and groups """ try: - result = self.db.execute(text(""" + result = self.db.execute( + text( + """ SELECT u.id, u.username, u.role, COALESCE( JSON_AGG(DISTINCT ug.name) FILTER (WHERE ug.name IS NOT NULL), @@ -742,134 +825,143 @@ async def _build_user_authorization_context(self, user_id: str) -> Authorization LEFT JOIN user_groups ug ON ugm.group_id = ug.id WHERE u.id = :user_id AND u.is_active = true GROUP BY u.id, u.username, u.role - """), {"user_id": user_id}) - + """ + ), + {"user_id": user_id}, + ) + row = result.fetchone() if not row: logger.warning(f"User {user_id} not found or inactive") - return AuthorizationContext( - user_id=user_id, - user_roles=[], - user_groups=[] - ) - + return AuthorizationContext(user_id=user_id, user_roles=[], user_groups=[]) + user_groups = json.loads(row.user_groups) if row.user_groups else [] - + return AuthorizationContext( - user_id=user_id, - user_roles=[row.role] if row.role else [], - user_groups=user_groups + user_id=user_id, user_roles=[row.role] if row.role else [], user_groups=user_groups ) - + except Exception as e: logger.error(f"Error building authorization context for user {user_id}: {e}") - return AuthorizationContext( - user_id=user_id, - user_roles=[], - user_groups=[] - ) - - async def _get_host_details(self, host_ids: List[str]) -> List[Dict]: + return AuthorizationContext(user_id=user_id, user_roles=[], user_groups=[]) + + def _get_host_details(self, host_ids: List[str]) -> List[Dict]: """ Get host details for authorization results """ try: if not host_ids: return [] - + # Create placeholders for the IN clause - placeholders = ','.join([f"'{host_id}'" for host_id in host_ids]) - - result = self.db.execute(text(f""" + placeholders = ",".join([f"'{host_id}'" for host_id in host_ids]) + + result = self.db.execute( + text( + f""" SELECT id, hostname, display_name, ip_address, status FROM hosts WHERE id IN ({placeholders}) - """)) - + """ + ) + ) + return [ { - 'id': str(row.id), - 'hostname': row.hostname, - 'display_name': row.display_name or row.hostname, - 'ip_address': row.ip_address, - 'status': row.status + "id": str(row.id), + "hostname": row.hostname, + "display_name": row.display_name or row.hostname, + "ip_address": row.ip_address, + "status": row.status, } for row in result ] - + except Exception as e: logger.error(f"Error getting host details: {e}") - return [{'id': host_id, 'hostname': 'unknown', 'display_name': 'unknown'} for host_id in host_ids] - - async def _create_batch_scans_with_authorization( + return [ + {"id": host_id, "hostname": "unknown", "display_name": "unknown"} + for host_id in host_ids + ] + + def _create_batch_scans_with_authorization( self, batch: ScanBatch, session_id: str, name_prefix: str, user_id: str, stagger_delay: int, - authorized_hosts: List['AuthorizedHost'] + authorized_hosts: List["AuthorizedHost"], ) -> List[str]: """ Create batch scans with additional authorization validation - + This method provides a final authorization check before scan creation to ensure only authorized hosts have scans created. """ try: # Create lookup of authorized host IDs authorized_host_ids = {host.host_id for host in authorized_hosts} - + scan_ids = [] - + for i, host in enumerate(batch.hosts): # Additional authorization check if host.id not in authorized_host_ids: logger.warning(f"Skipping scan creation for unauthorized host {host.id}") continue - + scan_id = str(uuid.uuid4()) scan_name = f"{name_prefix} - {host.hostname}" - + # Calculate staggered start time start_delay = i * stagger_delay # seconds - + # Create scan record with authorization metadata - self.db.execute(text(""" + self.db.execute( + text( + """ INSERT INTO scans (id, name, host_id, content_id, profile_id, status, progress, scan_options, started_by, started_at, remediation_requested, verification_scan) VALUES (:id, :name, :host_id, :content_id, :profile_id, :status, :progress, :scan_options, :started_by, :started_at, :remediation_requested, :verification_scan) - """), { - "id": scan_id, - "name": scan_name, - "host_id": host.id, - "content_id": batch.content_id, - "profile_id": batch.profile_id, - "status": "pending", - "progress": 0, - "scan_options": json.dumps({ - "bulk_scan": True, - "session_id": session_id, - "batch_id": batch.id, - "start_delay": start_delay, - "authorized": True, # Mark as explicitly authorized - "authorization_timestamp": datetime.utcnow().isoformat() - }), - "started_by": user_id, - "started_at": datetime.utcnow(), - "remediation_requested": False, - "verification_scan": False - }) - + """ + ), + { + "id": scan_id, + "name": scan_name, + "host_id": host.id, + "content_id": batch.content_id, + "profile_id": batch.profile_id, + "status": "pending", + "progress": 0, + "scan_options": json.dumps( + { + "bulk_scan": True, + "session_id": session_id, + "batch_id": batch.id, + "start_delay": start_delay, + "authorized": True, # Mark as explicitly authorized + "authorization_timestamp": datetime.utcnow().isoformat(), + } + ), + "started_by": user_id, + "started_at": datetime.utcnow(), + "remediation_requested": False, + "verification_scan": False, + }, + ) + scan_ids.append(scan_id) - + self.db.commit() - - logger.info(f"Created {len(scan_ids)} authorized scans out of {len(batch.hosts)} hosts in batch") + + logger.info( + f"Created {len(scan_ids)} authorized scans out of {len(batch.hosts)} hosts in batch" + ) return scan_ids - + except Exception as e: logger.error(f"Error creating authorized batch scans: {e}") raise @@ -878,12 +970,13 @@ async def _create_batch_scans_with_authorization( @dataclass class AuthorizedHost: """Represents a host that has been authorized for scanning""" + host_id: str hostname: str display_name: str authorization_reason: str timestamp: datetime = None - + def __post_init__(self): if self.timestamp is None: - self.timestamp = datetime.utcnow() \ No newline at end of file + self.timestamp = datetime.utcnow() diff --git a/backend/app/services/command_sandbox.py b/backend/app/services/command_sandbox.py index 6510da90..5e874726 100644 --- a/backend/app/services/command_sandbox.py +++ b/backend/app/services/command_sandbox.py @@ -6,7 +6,7 @@ Security Features: - Containerized execution environment -- Cryptographic command signature verification +- Cryptographic command signature verification - Command allowlisting with parameter validation - Multi-factor approval workflow for privileged operations - Complete audit trail and rollback capabilities @@ -41,14 +41,16 @@ class CommandSecurityLevel(str, Enum): """Security classification levels for commands""" - SAFE = "safe" # No system modifications, read-only operations - MODERATE = "moderate" # Limited system changes, no privilege escalation + + SAFE = "safe" # No system modifications, read-only operations + MODERATE = "moderate" # Limited system changes, no privilege escalation PRIVILEGED = "privileged" # Requires elevated privileges, system changes - CRITICAL = "critical" # High-impact system modifications + CRITICAL = "critical" # High-impact system modifications class ExecutionStatus(str, Enum): """Command execution status tracking""" + PENDING_APPROVAL = "pending_approval" APPROVED = "approved" REJECTED = "rejected" @@ -60,6 +62,7 @@ class ExecutionStatus(str, Enum): class SecureCommand(BaseModel): """Secure command definition with cryptographic verification""" + command_id: str template: str description: str @@ -75,6 +78,7 @@ class SecureCommand(BaseModel): class ExecutionRequest(BaseModel): """Command execution request with security context""" + request_id: str = Field(default_factory=lambda: str(uuid.uuid4())) command_id: str parameters: Dict[str, Any] = Field(default_factory=dict) @@ -94,22 +98,22 @@ class ExecutionRequest(BaseModel): class SandboxEnvironment: """Containerized sandbox environment for secure command execution""" - + def __init__(self, container_image: str = "ubuntu:22.04"): self.container_image = container_image self.docker_client = docker.from_env() self.container = None self.sandbox_id = str(uuid.uuid4()) - + async def __aenter__(self): """Async context manager entry - create sandbox""" await self._create_sandbox() return self - + async def __aexit__(self, exc_type, exc_val, exc_tb): """Async context manager exit - cleanup sandbox""" await self._cleanup_sandbox() - + async def _create_sandbox(self): """Create secure containerized environment""" try: @@ -130,65 +134,65 @@ async def _create_sandbox(self): pids_limit=100, # Process limit tmpfs={"/tmp": "size=100m,noexec"}, # Temporary filesystem ) - + # Wait for container to be ready await asyncio.sleep(1) - + # Install basic tools in sandbox await self._setup_sandbox_tools() - + logger.info(f"Created secure sandbox: {self.sandbox_id}") - + except Exception as e: logger.error(f"Failed to create sandbox {self.sandbox_id}: {e}") raise - - async def _setup_sandbox_tools(self): + + def _setup_sandbox_tools(self): """Install essential tools in sandbox""" try: # Update package lists result = self.container.exec_run("apt-get update -qq") if result.exit_code != 0: logger.warning("Failed to update package lists in sandbox") - + # Install essential tools essential_tools = [ - "curl", "wget", "netcat-traditional", "dnsutils", - "net-tools", "procps", "lsof" + "curl", + "wget", + "netcat-traditional", + "dnsutils", + "net-tools", + "procps", + "lsof", ] - + for tool in essential_tools: result = self.container.exec_run(f"apt-get install -y -qq {tool}") if result.exit_code != 0: logger.warning(f"Failed to install {tool} in sandbox") - + except Exception as e: logger.warning(f"Failed to setup sandbox tools: {e}") - - async def execute_command(self, command: str, timeout: int = 300) -> Tuple[int, str, str]: + + def execute_command(self, command: str, timeout: int = 300) -> Tuple[int, str, str]: """Execute command in sandbox with timeout""" if not self.container: raise RuntimeError("Sandbox not initialized") - + try: # Execute command with timeout - result = self.container.exec_run( - command, - timeout=timeout, - stdout=True, - stderr=True - ) - - stdout = result.output.decode('utf-8') if result.output else "" + result = self.container.exec_run(command, timeout=timeout, stdout=True, stderr=True) + + stdout = result.output.decode("utf-8") if result.output else "" stderr = "" # Docker exec_run combines stdout/stderr - + return result.exit_code, stdout, stderr - + except Exception as e: logger.error(f"Command execution failed in sandbox {self.sandbox_id}: {e}") raise - - async def _cleanup_sandbox(self): + + def _cleanup_sandbox(self): """Clean up sandbox container""" try: if self.container: @@ -200,80 +204,71 @@ async def _cleanup_sandbox(self): class CommandSignatureService: """Cryptographic signature service for command verification""" - + def __init__(self, crypto_service: CryptoService): self.crypto_service = crypto_service self.signature_algorithm = hashes.SHA256() - + def sign_command(self, command: SecureCommand, private_key_path: str) -> str: """Generate cryptographic signature for command""" try: # Load private key - with open(private_key_path, 'rb') as f: - private_key = serialization.load_pem_private_key( - f.read(), - password=None - ) - + with open(private_key_path, "rb") as f: + private_key = serialization.load_pem_private_key(f.read(), password=None) + # Create command payload for signing payload = { "command_id": command.command_id, "template": command.template, "security_level": command.security_level, "allowed_parameters": command.allowed_parameters, - "parameter_patterns": command.parameter_patterns + "parameter_patterns": command.parameter_patterns, } - + payload_bytes = json.dumps(payload, sort_keys=True).encode() - + # Generate signature signature = private_key.sign( payload_bytes, - padding.PSS( - mgf=padding.MGF1(hashes.SHA256()), - salt_length=padding.PSS.MAX_LENGTH - ), - hashes.SHA256() + padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH), + hashes.SHA256(), ) - + return signature.hex() - + except Exception as e: logger.error(f"Failed to sign command {command.command_id}: {e}") raise - + def verify_command(self, command: SecureCommand, signature: str, public_key_path: str) -> bool: """Verify cryptographic signature for command""" try: # Load public key - with open(public_key_path, 'rb') as f: + with open(public_key_path, "rb") as f: public_key = serialization.load_pem_public_key(f.read()) - + # Recreate command payload payload = { "command_id": command.command_id, "template": command.template, "security_level": command.security_level, "allowed_parameters": command.allowed_parameters, - "parameter_patterns": command.parameter_patterns + "parameter_patterns": command.parameter_patterns, } - + payload_bytes = json.dumps(payload, sort_keys=True).encode() signature_bytes = bytes.fromhex(signature) - + # Verify signature public_key.verify( signature_bytes, payload_bytes, - padding.PSS( - mgf=padding.MGF1(hashes.SHA256()), - salt_length=padding.PSS.MAX_LENGTH - ), - hashes.SHA256() + padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH), + hashes.SHA256(), ) - + return True - + except Exception as e: logger.warning(f"Command signature verification failed for {command.command_id}: {e}") return False @@ -281,14 +276,14 @@ def verify_command(self, command: SecureCommand, signature: str, public_key_path class CommandSandboxService: """Main service for secure command sandboxing""" - + def __init__(self): self.crypto_service = CryptoService() self.signature_service = CommandSignatureService(self.crypto_service) self.allowed_commands: Dict[str, SecureCommand] = {} self.execution_requests: Dict[str, ExecutionRequest] = {} self._load_allowed_commands() - + def _load_allowed_commands(self): """Load pre-approved secure commands""" # Define secure command templates with strict parameter validation @@ -301,7 +296,7 @@ def _load_allowed_commands(self): allowed_parameters=["service_name"], parameter_patterns={"service_name": r"^[a-zA-Z0-9\-_.]+$"}, requires_approval=False, - max_execution_time=30 + max_execution_time=30, ), SecureCommand( command_id="check_network_port", @@ -311,7 +306,7 @@ def _load_allowed_commands(self): allowed_parameters=["port"], parameter_patterns={"port": r"^[1-9][0-9]{0,4}$"}, requires_approval=False, - max_execution_time=30 + max_execution_time=30, ), SecureCommand( command_id="install_openscap_ubuntu", @@ -321,17 +316,17 @@ def _load_allowed_commands(self): allowed_parameters=[], requires_approval=True, max_execution_time=300, - rollback_template="apt-get remove -y libopenscap8 ssg-base" + rollback_template="apt-get remove -y libopenscap8 ssg-base", ), SecureCommand( command_id="install_openscap_rhel", template="yum install -y openscap-scanner scap-security-guide", - description="Install OpenSCAP on RHEL/CentOS systems", + description="Install OpenSCAP on RHEL/CentOS systems", security_level=CommandSecurityLevel.PRIVILEGED, allowed_parameters=[], requires_approval=True, max_execution_time=300, - rollback_template="yum remove -y openscap-scanner scap-security-guide" + rollback_template="yum remove -y openscap-scanner scap-security-guide", ), SecureCommand( command_id="cleanup_temp_files", @@ -340,54 +335,62 @@ def _load_allowed_commands(self): security_level=CommandSecurityLevel.MODERATE, allowed_parameters=[], requires_approval=True, - max_execution_time=120 - ) + max_execution_time=120, + ), ] - + for cmd in commands: self.allowed_commands[cmd.command_id] = cmd - + logger.info(f"Loaded {len(self.allowed_commands)} secure command templates") - + def validate_command_parameters(self, command_id: str, parameters: Dict[str, Any]) -> bool: """Validate command parameters against allowed patterns""" if command_id not in self.allowed_commands: return False - + command = self.allowed_commands[command_id] - + # Check all required parameters are present for param in command.allowed_parameters: if param not in parameters: logger.warning(f"Missing required parameter {param} for command {command_id}") return False - + # Validate parameter patterns for param, value in parameters.items(): if param not in command.allowed_parameters: logger.warning(f"Unauthorized parameter {param} for command {command_id}") return False - + if param in command.parameter_patterns: pattern = command.parameter_patterns[param] import re + if not re.match(pattern, str(value)): - logger.warning(f"Parameter {param} value '{value}' doesn't match pattern {pattern}") + logger.warning( + f"Parameter {param} value '{value}' doesn't match pattern {pattern}" + ) return False - + return True - - async def request_command_execution(self, command_id: str, parameters: Dict[str, Any], - target_host: str, requested_by: str, - justification: str) -> ExecutionRequest: + + def request_command_execution( + self, + command_id: str, + parameters: Dict[str, Any], + target_host: str, + requested_by: str, + justification: str, + ) -> ExecutionRequest: """Request execution of a secure command""" - + # Validate command exists and parameters are valid if not self.validate_command_parameters(command_id, parameters): raise ValueError(f"Invalid command or parameters for {command_id}") - + command = self.allowed_commands[command_id] - + # Create execution request request = ExecutionRequest( command_id=command_id, @@ -395,65 +398,70 @@ async def request_command_execution(self, command_id: str, parameters: Dict[str, target_host=target_host, requested_by=requested_by, justification=justification, - status=ExecutionStatus.PENDING_APPROVAL if command.requires_approval else ExecutionStatus.APPROVED + status=( + ExecutionStatus.PENDING_APPROVAL + if command.requires_approval + else ExecutionStatus.APPROVED + ), ) - + self.execution_requests[request.request_id] = request - + # Log security event - logger.info(f"Command execution requested: {command_id} by {requested_by} for {target_host}") - + logger.info( + f"Command execution requested: {command_id} by {requested_by} for {target_host}" + ) + return request - - async def approve_request(self, request_id: str, approved_by: str) -> bool: + + def approve_request(self, request_id: str, approved_by: str) -> bool: """Approve a pending execution request""" if request_id not in self.execution_requests: return False - + request = self.execution_requests[request_id] if request.status != ExecutionStatus.PENDING_APPROVAL: return False - + request.status = ExecutionStatus.APPROVED request.approved_by = approved_by - + logger.info(f"Command execution approved: {request.command_id} by {approved_by}") return True - + async def execute_secure_command(self, request_id: str) -> ExecutionRequest: """Execute approved command in secure sandbox""" if request_id not in self.execution_requests: raise ValueError(f"Execution request {request_id} not found") - + request = self.execution_requests[request_id] if request.status != ExecutionStatus.APPROVED: raise ValueError(f"Request {request_id} not approved for execution") - + command = self.allowed_commands[request.command_id] - + try: request.status = ExecutionStatus.EXECUTING request.executed_at = datetime.utcnow() - + # Build command from template command_str = command.template for param, value in request.parameters.items(): command_str = command_str.replace(f"{{{param}}}", str(value)) - + logger.info(f"Executing secure command: {command_str}") - + # Execute in sandbox async with SandboxEnvironment() as sandbox: exit_code, stdout, stderr = await sandbox.execute_command( - command_str, - timeout=command.max_execution_time + command_str, timeout=command.max_execution_time ) - + request.exit_code = exit_code request.output = stdout request.error_output = stderr request.completed_at = datetime.utcnow() - + if exit_code == 0: request.status = ExecutionStatus.COMPLETED # Set up rollback if available @@ -462,61 +470,67 @@ async def execute_secure_command(self, request_id: str) -> ExecutionRequest: request.rollback_command = command.rollback_template else: request.status = ExecutionStatus.FAILED - - logger.info(f"Command execution completed: {request.command_id} (exit_code: {exit_code})") - + + logger.info( + f"Command execution completed: {request.command_id} (exit_code: {exit_code})" + ) + except Exception as e: request.status = ExecutionStatus.FAILED request.error_output = str(e) request.completed_at = datetime.utcnow() logger.error(f"Command execution failed: {request.command_id} - {e}") - + return request - + async def rollback_execution(self, request_id: str, rollback_by: str) -> bool: """Rollback a previously executed command""" if request_id not in self.execution_requests: return False - + request = self.execution_requests[request_id] if not request.rollback_available or not request.rollback_command: return False - + try: logger.info(f"Rolling back command execution: {request.command_id}") - + # Execute rollback in sandbox async with SandboxEnvironment() as sandbox: exit_code, stdout, stderr = await sandbox.execute_command( - request.rollback_command, - timeout=300 + request.rollback_command, timeout=300 ) - + if exit_code == 0: request.status = ExecutionStatus.ROLLED_BACK logger.info(f"Command rollback successful: {request.command_id}") return True else: - logger.error(f"Command rollback failed: {request.command_id} (exit_code: {exit_code})") + logger.error( + f"Command rollback failed: {request.command_id} (exit_code: {exit_code})" + ) return False - + except Exception as e: logger.error(f"Command rollback error: {request.command_id} - {e}") return False - + def get_execution_request(self, request_id: str) -> Optional[ExecutionRequest]: """Get execution request by ID""" return self.execution_requests.get(request_id) - + def list_pending_approvals(self) -> List[ExecutionRequest]: """List all pending approval requests""" - return [req for req in self.execution_requests.values() - if req.status == ExecutionStatus.PENDING_APPROVAL] - + return [ + req + for req in self.execution_requests.values() + if req.status == ExecutionStatus.PENDING_APPROVAL + ] + def get_command_info(self, command_id: str) -> Optional[SecureCommand]: """Get information about a secure command""" return self.allowed_commands.get(command_id) - + def list_available_commands(self) -> List[SecureCommand]: """List all available secure commands""" - return list(self.allowed_commands.values()) \ No newline at end of file + return list(self.allowed_commands.values()) diff --git a/backend/app/services/compliance_framework_mapper.py b/backend/app/services/compliance_framework_mapper.py index 081f78e5..6714ba0d 100644 --- a/backend/app/services/compliance_framework_mapper.py +++ b/backend/app/services/compliance_framework_mapper.py @@ -2,6 +2,7 @@ Compliance Framework Mapper Service Maps SCAP rules to multiple compliance frameworks (NIST, CIS, STIG, CMMC 2.0) """ + import json import logging from typing import Dict, List, Optional, Set, Tuple @@ -14,6 +15,7 @@ class ComplianceFramework(Enum): """Supported compliance frameworks""" + DISA_STIG = "DISA-STIG" NIST_800_53 = "NIST-800-53" CIS_CONTROLS = "CIS-Controls" @@ -27,6 +29,7 @@ class ComplianceFramework(Enum): @dataclass class FrameworkMapping: """Framework mapping details""" + framework: ComplianceFramework control_id: str control_title: str @@ -36,11 +39,12 @@ class FrameworkMapping: related_controls: List[str] severity: str # low, medium, high, critical maturity_level: int # 1-5 for CMMC - + @dataclass class ComplianceControl: """Unified compliance control across frameworks""" + rule_id: str title: str description: str @@ -53,12 +57,12 @@ class ComplianceControl: class ComplianceFrameworkMapper: """Service for mapping SCAP rules to compliance frameworks""" - + def __init__(self): self.framework_mappings = self._initialize_mappings() self.control_families = self._initialize_control_families() self.cmmc_practices = self._initialize_cmmc_practices() - + def _initialize_mappings(self) -> Dict[str, List[FrameworkMapping]]: """Initialize comprehensive framework mappings""" # This would be loaded from a database or configuration file @@ -72,10 +76,13 @@ def _initialize_mappings(self) -> Dict[str, List[FrameworkMapping]]: control_title="SSH daemon must disable root login", control_family="Access Control", implementation_guidance="Configure SSH daemon to prevent root login by setting PermitRootLogin to 'no' in /etc/ssh/sshd_config", - assessment_objectives=["Verify PermitRootLogin is set to 'no'", "Verify SSH service is restarted after changes"], + assessment_objectives=[ + "Verify PermitRootLogin is set to 'no'", + "Verify SSH service is restarted after changes", + ], related_controls=["AC-6", "IA-2"], severity="high", - maturity_level=3 + maturity_level=3, ), FrameworkMapping( framework=ComplianceFramework.NIST_800_53, @@ -83,10 +90,13 @@ def _initialize_mappings(self) -> Dict[str, List[FrameworkMapping]]: control_title="Non-Privileged Access for Nonsecurity Functions", control_family="Access Control", implementation_guidance="Require users to use non-privileged accounts when accessing nonsecurity functions", - assessment_objectives=["Verify root access is restricted", "Ensure privilege escalation is controlled"], + assessment_objectives=[ + "Verify root access is restricted", + "Ensure privilege escalation is controlled", + ], related_controls=["AC-6", "AC-6(1)", "AC-6(5)"], severity="high", - maturity_level=3 + maturity_level=3, ), FrameworkMapping( framework=ComplianceFramework.CIS_CONTROLS, @@ -94,10 +104,13 @@ def _initialize_mappings(self) -> Dict[str, List[FrameworkMapping]]: control_title="Restrict Administrator Privileges to Dedicated Administrator Accounts", control_family="Account Management", implementation_guidance="Ensure administrative privileges are restricted to dedicated admin accounts", - assessment_objectives=["Verify separation of admin and user accounts", "Confirm root login restrictions"], + assessment_objectives=[ + "Verify separation of admin and user accounts", + "Confirm root login restrictions", + ], related_controls=["5.1", "5.2", "5.3"], severity="high", - maturity_level=3 + maturity_level=3, ), FrameworkMapping( framework=ComplianceFramework.CMMC_2_0, @@ -105,13 +118,15 @@ def _initialize_mappings(self) -> Dict[str, List[FrameworkMapping]]: control_title="Employ the principle of least privilege", control_family="Access Control", implementation_guidance="Employ the principle of least privilege, including for specific security functions and privileged accounts", - assessment_objectives=["Verify least privilege implementation", "Assess privileged account restrictions"], + assessment_objectives=[ + "Verify least privilege implementation", + "Assess privileged account restrictions", + ], related_controls=["AC.L1-3.1.1", "AC.L2-3.1.6"], severity="high", - maturity_level=2 - ) + maturity_level=2, + ), ], - # Password Policy Controls "xccdf_mil.disa.stig_rule_SV-230365r792936_rule": [ FrameworkMapping( @@ -120,10 +135,13 @@ def _initialize_mappings(self) -> Dict[str, List[FrameworkMapping]]: control_title="System must enforce minimum password length", control_family="Identification and Authentication", implementation_guidance="Configure PAM to enforce minimum password length of 15 characters", - assessment_objectives=["Verify password length configuration", "Test password creation with various lengths"], + assessment_objectives=[ + "Verify password length configuration", + "Test password creation with various lengths", + ], related_controls=["IA-5"], severity="medium", - maturity_level=2 + maturity_level=2, ), FrameworkMapping( framework=ComplianceFramework.NIST_800_53, @@ -131,10 +149,13 @@ def _initialize_mappings(self) -> Dict[str, List[FrameworkMapping]]: control_title="Password-Based Authentication - Complexity", control_family="Identification and Authentication", implementation_guidance="Enforce minimum password complexity requirements including length", - assessment_objectives=["Verify password complexity settings", "Validate enforcement mechanisms"], + assessment_objectives=[ + "Verify password complexity settings", + "Validate enforcement mechanisms", + ], related_controls=["IA-5", "IA-5(1)"], severity="medium", - maturity_level=2 + maturity_level=2, ), FrameworkMapping( framework=ComplianceFramework.CIS_CONTROLS, @@ -142,10 +163,13 @@ def _initialize_mappings(self) -> Dict[str, List[FrameworkMapping]]: control_title="Use Unique Passwords", control_family="Account Management", implementation_guidance="Ensure all accounts have unique, complex passwords meeting minimum requirements", - assessment_objectives=["Verify password policy enforcement", "Check password uniqueness requirements"], + assessment_objectives=[ + "Verify password policy enforcement", + "Check password uniqueness requirements", + ], related_controls=["5.1", "5.3"], severity="medium", - maturity_level=2 + maturity_level=2, ), FrameworkMapping( framework=ComplianceFramework.CMMC_2_0, @@ -153,13 +177,15 @@ def _initialize_mappings(self) -> Dict[str, List[FrameworkMapping]]: control_title="Enforce a minimum password complexity", control_family="Identification and Authentication", implementation_guidance="Enforce a minimum password complexity and change of characters when new passwords are created", - assessment_objectives=["Verify password complexity requirements", "Test password change enforcement"], + assessment_objectives=[ + "Verify password complexity requirements", + "Test password change enforcement", + ], related_controls=["IA.L1-3.5.1", "IA.L1-3.5.2"], severity="medium", - maturity_level=2 - ) + maturity_level=2, + ), ], - # Audit Configuration Controls "xccdf_mil.disa.stig_rule_SV-230423r793041_rule": [ FrameworkMapping( @@ -168,10 +194,13 @@ def _initialize_mappings(self) -> Dict[str, List[FrameworkMapping]]: control_title="Audit daemon must be enabled", control_family="Audit and Accountability", implementation_guidance="Enable and configure auditd service to capture security-relevant events", - assessment_objectives=["Verify auditd service is enabled", "Confirm audit rules are loaded"], + assessment_objectives=[ + "Verify auditd service is enabled", + "Confirm audit rules are loaded", + ], related_controls=["AU-12", "AU-3"], severity="high", - maturity_level=2 + maturity_level=2, ), FrameworkMapping( framework=ComplianceFramework.NIST_800_53, @@ -179,10 +208,13 @@ def _initialize_mappings(self) -> Dict[str, List[FrameworkMapping]]: control_title="Audit Record Generation", control_family="Audit and Accountability", implementation_guidance="Generate audit records for security-relevant events", - assessment_objectives=["Verify audit capability", "Confirm event capture configuration"], + assessment_objectives=[ + "Verify audit capability", + "Confirm event capture configuration", + ], related_controls=["AU-2", "AU-3", "AU-12(1)"], severity="high", - maturity_level=2 + maturity_level=2, ), FrameworkMapping( framework=ComplianceFramework.CIS_CONTROLS, @@ -193,7 +225,7 @@ def _initialize_mappings(self) -> Dict[str, List[FrameworkMapping]]: assessment_objectives=["Verify log collection", "Assess log completeness"], related_controls=["8.1", "8.3", "8.4"], severity="high", - maturity_level=2 + maturity_level=2, ), FrameworkMapping( framework=ComplianceFramework.CMMC_2_0, @@ -201,14 +233,17 @@ def _initialize_mappings(self) -> Dict[str, List[FrameworkMapping]]: control_title="Create and retain system audit logs", control_family="Audit and Accountability", implementation_guidance="Create and retain system audit logs and records to monitor, analyze, investigate, and report unlawful or unauthorized activity", - assessment_objectives=["Verify audit log generation", "Confirm retention policies"], + assessment_objectives=[ + "Verify audit log generation", + "Confirm retention policies", + ], related_controls=["AU.L2-3.3.2", "AU.L2-3.3.3"], severity="high", - maturity_level=2 - ) - ] + maturity_level=2, + ), + ], } - + def _initialize_control_families(self) -> Dict[ComplianceFramework, List[str]]: """Initialize control families for each framework""" return { @@ -231,7 +266,7 @@ def _initialize_control_families(self) -> Dict[ComplianceFramework, List[str]]: "System and Services Acquisition (SA)", "System and Communications Protection (SC)", "System and Information Integrity (SI)", - "Supply Chain Risk Management (SR)" + "Supply Chain Risk Management (SR)", ], ComplianceFramework.CIS_CONTROLS: [ "Inventory and Control of Enterprise Assets", @@ -251,7 +286,7 @@ def _initialize_control_families(self) -> Dict[ComplianceFramework, List[str]]: "Service Provider Management", "Application Software Security", "Incident Response Management", - "Penetration Testing" + "Penetration Testing", ], ComplianceFramework.CMMC_2_0: [ "Access Control (AC)", @@ -267,85 +302,168 @@ def _initialize_control_families(self) -> Dict[ComplianceFramework, List[str]]: "Risk Assessment (RA)", "Security Assessment (CA)", "System and Communications Protection (SC)", - "System and Information Integrity (SI)" - ] + "System and Information Integrity (SI)", + ], } - + def _initialize_cmmc_practices(self) -> Dict[int, List[str]]: """Initialize CMMC maturity level practices""" return { 1: [ # Foundational - "AC.L1-3.1.1", "AC.L1-3.1.2", "AC.L1-3.1.20", "AC.L1-3.1.22", - "IA.L1-3.5.1", "IA.L1-3.5.2", + "AC.L1-3.1.1", + "AC.L1-3.1.2", + "AC.L1-3.1.20", + "AC.L1-3.1.22", + "IA.L1-3.5.1", + "IA.L1-3.5.2", "MP.L1-3.8.3", - "PE.L1-3.10.1", "PE.L1-3.10.3", "PE.L1-3.10.4", "PE.L1-3.10.5", - "SC.L1-3.13.1", "SC.L1-3.13.5", - "SI.L1-3.14.1", "SI.L1-3.14.2", "SI.L1-3.14.3" + "PE.L1-3.10.1", + "PE.L1-3.10.3", + "PE.L1-3.10.4", + "PE.L1-3.10.5", + "SC.L1-3.13.1", + "SC.L1-3.13.5", + "SI.L1-3.14.1", + "SI.L1-3.14.2", + "SI.L1-3.14.3", ], 2: [ # Advanced - "AC.L2-3.1.3", "AC.L2-3.1.4", "AC.L2-3.1.5", "AC.L2-3.1.6", "AC.L2-3.1.7", - "AC.L2-3.1.8", "AC.L2-3.1.9", "AC.L2-3.1.10", "AC.L2-3.1.11", "AC.L2-3.1.12", - "AT.L2-3.2.1", "AT.L2-3.2.2", "AT.L2-3.2.3", - "AU.L2-3.3.1", "AU.L2-3.3.2", "AU.L2-3.3.3", "AU.L2-3.3.4", "AU.L2-3.3.5", - "CM.L2-3.4.1", "CM.L2-3.4.2", "CM.L2-3.4.3", "CM.L2-3.4.4", "CM.L2-3.4.5", - "IA.L2-3.5.3", "IA.L2-3.5.4", "IA.L2-3.5.5", "IA.L2-3.5.6", "IA.L2-3.5.7", - "IR.L2-3.6.1", "IR.L2-3.6.2", "IR.L2-3.6.3", - "MA.L2-3.7.1", "MA.L2-3.7.2", "MA.L2-3.7.3", "MA.L2-3.7.4", "MA.L2-3.7.5", - "MP.L2-3.8.1", "MP.L2-3.8.2", "MP.L2-3.8.4", "MP.L2-3.8.5", "MP.L2-3.8.6", - "PE.L2-3.10.2", "PE.L2-3.10.6", - "PS.L2-3.9.1", "PS.L2-3.9.2", - "RA.L2-3.11.1", "RA.L2-3.11.2", "RA.L2-3.11.3", - "CA.L2-3.12.1", "CA.L2-3.12.2", "CA.L2-3.12.3", "CA.L2-3.12.4", - "SC.L2-3.13.2", "SC.L2-3.13.3", "SC.L2-3.13.4", "SC.L2-3.13.6", "SC.L2-3.13.7", - "SI.L2-3.14.4", "SI.L2-3.14.5", "SI.L2-3.14.6", "SI.L2-3.14.7" + "AC.L2-3.1.3", + "AC.L2-3.1.4", + "AC.L2-3.1.5", + "AC.L2-3.1.6", + "AC.L2-3.1.7", + "AC.L2-3.1.8", + "AC.L2-3.1.9", + "AC.L2-3.1.10", + "AC.L2-3.1.11", + "AC.L2-3.1.12", + "AT.L2-3.2.1", + "AT.L2-3.2.2", + "AT.L2-3.2.3", + "AU.L2-3.3.1", + "AU.L2-3.3.2", + "AU.L2-3.3.3", + "AU.L2-3.3.4", + "AU.L2-3.3.5", + "CM.L2-3.4.1", + "CM.L2-3.4.2", + "CM.L2-3.4.3", + "CM.L2-3.4.4", + "CM.L2-3.4.5", + "IA.L2-3.5.3", + "IA.L2-3.5.4", + "IA.L2-3.5.5", + "IA.L2-3.5.6", + "IA.L2-3.5.7", + "IR.L2-3.6.1", + "IR.L2-3.6.2", + "IR.L2-3.6.3", + "MA.L2-3.7.1", + "MA.L2-3.7.2", + "MA.L2-3.7.3", + "MA.L2-3.7.4", + "MA.L2-3.7.5", + "MP.L2-3.8.1", + "MP.L2-3.8.2", + "MP.L2-3.8.4", + "MP.L2-3.8.5", + "MP.L2-3.8.6", + "PE.L2-3.10.2", + "PE.L2-3.10.6", + "PS.L2-3.9.1", + "PS.L2-3.9.2", + "RA.L2-3.11.1", + "RA.L2-3.11.2", + "RA.L2-3.11.3", + "CA.L2-3.12.1", + "CA.L2-3.12.2", + "CA.L2-3.12.3", + "CA.L2-3.12.4", + "SC.L2-3.13.2", + "SC.L2-3.13.3", + "SC.L2-3.13.4", + "SC.L2-3.13.6", + "SC.L2-3.13.7", + "SI.L2-3.14.4", + "SI.L2-3.14.5", + "SI.L2-3.14.6", + "SI.L2-3.14.7", ], 3: [ # Expert (includes all L1 and L2 plus additional) - "AC.L3-3.1.13", "AC.L3-3.1.14", "AC.L3-3.1.15", "AC.L3-3.1.16", "AC.L3-3.1.17", + "AC.L3-3.1.13", + "AC.L3-3.1.14", + "AC.L3-3.1.15", + "AC.L3-3.1.16", + "AC.L3-3.1.17", "AT.L3-3.2.4", - "AU.L3-3.3.6", "AU.L3-3.3.7", "AU.L3-3.3.8", "AU.L3-3.3.9", - "CM.L3-3.4.6", "CM.L3-3.4.7", "CM.L3-3.4.8", "CM.L3-3.4.9", - "IA.L3-3.5.8", "IA.L3-3.5.9", "IA.L3-3.5.10", "IA.L3-3.5.11", "IA.L3-3.5.12", - "IR.L3-3.6.4", "IR.L3-3.6.5", + "AU.L3-3.3.6", + "AU.L3-3.3.7", + "AU.L3-3.3.8", + "AU.L3-3.3.9", + "CM.L3-3.4.6", + "CM.L3-3.4.7", + "CM.L3-3.4.8", + "CM.L3-3.4.9", + "IA.L3-3.5.8", + "IA.L3-3.5.9", + "IA.L3-3.5.10", + "IA.L3-3.5.11", + "IA.L3-3.5.12", + "IR.L3-3.6.4", + "IR.L3-3.6.5", "MA.L3-3.7.6", - "MP.L3-3.8.7", "MP.L3-3.8.8", "MP.L3-3.8.9", + "MP.L3-3.8.7", + "MP.L3-3.8.8", + "MP.L3-3.8.9", "PE.L3-3.10.7", - "RA.L3-3.11.4", "RA.L3-3.11.5", "RA.L3-3.11.6", "RA.L3-3.11.7", + "RA.L3-3.11.4", + "RA.L3-3.11.5", + "RA.L3-3.11.6", + "RA.L3-3.11.7", "CA.L3-3.12.5", - "SC.L3-3.13.8", "SC.L3-3.13.9", "SC.L3-3.13.10", "SC.L3-3.13.11", "SC.L3-3.13.12", - "SI.L3-3.14.8", "SI.L3-3.14.9", "SI.L3-3.14.10" - ] + "SC.L3-3.13.8", + "SC.L3-3.13.9", + "SC.L3-3.13.10", + "SC.L3-3.13.11", + "SC.L3-3.13.12", + "SI.L3-3.14.8", + "SI.L3-3.14.9", + "SI.L3-3.14.10", + ], } - + def map_scap_rule_to_frameworks(self, scap_rule_id: str) -> List[FrameworkMapping]: """Map a SCAP rule to all applicable compliance frameworks""" return self.framework_mappings.get(scap_rule_id, []) - - def get_unified_control(self, scap_rule_id: str, rule_title: str = "", - rule_description: str = "") -> Optional[ComplianceControl]: + + def get_unified_control( + self, scap_rule_id: str, rule_title: str = "", rule_description: str = "" + ) -> Optional[ComplianceControl]: """Get unified compliance control information across all frameworks""" mappings = self.map_scap_rule_to_frameworks(scap_rule_id) - + if not mappings: # Try to infer mappings from rule ID patterns mappings = self._infer_mappings_from_rule_id(scap_rule_id, rule_title) - + if not mappings: return None - + # Extract unique tags and categories tags = set() categories = set() - + for mapping in mappings: tags.add(mapping.control_family.lower().replace(" ", "_")) categories.add(mapping.control_family) - + # Add severity as tag tags.add(f"severity_{mapping.severity}") - + # Add framework as tag tags.add(mapping.framework.value.lower().replace("-", "_")) - + return ComplianceControl( rule_id=scap_rule_id, title=rule_title or mappings[0].control_title, @@ -354,77 +472,85 @@ def get_unified_control(self, scap_rule_id: str, rule_title: str = "", tags=list(tags), categories=list(categories), automated_remediation=self._check_automated_remediation(scap_rule_id), - aegis_rule_id=self._get_aegis_rule_id(scap_rule_id) + aegis_rule_id=self._get_aegis_rule_id(scap_rule_id), ) - - def _infer_mappings_from_rule_id(self, scap_rule_id: str, - rule_title: str) -> List[FrameworkMapping]: + + def _infer_mappings_from_rule_id( + self, scap_rule_id: str, rule_title: str + ) -> List[FrameworkMapping]: """Infer framework mappings from SCAP rule ID patterns""" mappings = [] - + # Extract STIG rule pattern - stig_match = re.search(r'SV-\d+r\d+', scap_rule_id) + stig_match = re.search(r"SV-\d+r\d+", scap_rule_id) if stig_match: stig_id = stig_match.group() - + # Infer control family from title control_family = self._infer_control_family(rule_title) severity = self._infer_severity(rule_title) - - mappings.append(FrameworkMapping( - framework=ComplianceFramework.DISA_STIG, - control_id=stig_id, - control_title=rule_title, - control_family=control_family, - implementation_guidance="Implement control as specified in STIG guidance", - assessment_objectives=["Verify control implementation", "Validate effectiveness"], - related_controls=[], - severity=severity, - maturity_level=2 - )) - - # Try to map to NIST based on common patterns - nist_control = self._infer_nist_control(rule_title, control_family) - if nist_control: - mappings.append(FrameworkMapping( - framework=ComplianceFramework.NIST_800_53, - control_id=nist_control, + + mappings.append( + FrameworkMapping( + framework=ComplianceFramework.DISA_STIG, + control_id=stig_id, control_title=rule_title, control_family=control_family, - implementation_guidance="Implement per NIST 800-53 guidelines", - assessment_objectives=["Verify NIST control implementation"], + implementation_guidance="Implement control as specified in STIG guidance", + assessment_objectives=[ + "Verify control implementation", + "Validate effectiveness", + ], related_controls=[], severity=severity, - maturity_level=2 - )) - + maturity_level=2, + ) + ) + + # Try to map to NIST based on common patterns + nist_control = self._infer_nist_control(rule_title, control_family) + if nist_control: + mappings.append( + FrameworkMapping( + framework=ComplianceFramework.NIST_800_53, + control_id=nist_control, + control_title=rule_title, + control_family=control_family, + implementation_guidance="Implement per NIST 800-53 guidelines", + assessment_objectives=["Verify NIST control implementation"], + related_controls=[], + severity=severity, + maturity_level=2, + ) + ) + return mappings - + def _infer_control_family(self, rule_title: str) -> str: """Infer control family from rule title""" title_lower = rule_title.lower() - - if any(word in title_lower for word in ['ssh', 'password', 'authentication', 'login']): + + if any(word in title_lower for word in ["ssh", "password", "authentication", "login"]): return "Identification and Authentication" - elif any(word in title_lower for word in ['audit', 'log', 'logging']): + elif any(word in title_lower for word in ["audit", "log", "logging"]): return "Audit and Accountability" - elif any(word in title_lower for word in ['access', 'permission', 'privilege']): + elif any(word in title_lower for word in ["access", "permission", "privilege"]): return "Access Control" - elif any(word in title_lower for word in ['firewall', 'network', 'port']): + elif any(word in title_lower for word in ["firewall", "network", "port"]): return "System and Communications Protection" - elif any(word in title_lower for word in ['update', 'patch', 'vulnerability']): + elif any(word in title_lower for word in ["update", "patch", "vulnerability"]): return "System and Information Integrity" else: return "Configuration Management" - + def _infer_severity(self, rule_title: str) -> str: """Infer severity from rule title keywords""" title_lower = rule_title.lower() - - critical_keywords = ['must not', 'prohibited', 'disabled', 'root', 'admin'] - high_keywords = ['must', 'required', 'audit', 'authentication', 'firewall'] - medium_keywords = ['should', 'recommended', 'configuration'] - + + critical_keywords = ["must not", "prohibited", "disabled", "root", "admin"] + high_keywords = ["must", "required", "audit", "authentication", "firewall"] + medium_keywords = ["should", "recommended", "configuration"] + if any(word in title_lower for word in critical_keywords): return "critical" elif any(word in title_lower for word in high_keywords): @@ -433,11 +559,11 @@ def _infer_severity(self, rule_title: str) -> str: return "medium" else: return "low" - + def _infer_nist_control(self, rule_title: str, control_family: str) -> Optional[str]: """Infer NIST control ID from rule title and family""" title_lower = rule_title.lower() - + # Common NIST control mappings nist_mappings = { "ssh": "AC-17", # Remote Access @@ -448,13 +574,13 @@ def _infer_nist_control(self, rule_title: str, control_family: str) -> Optional[ "update": "SI-2", # Flaw Remediation "encryption": "SC-13", # Cryptographic Protection } - + for keyword, control in nist_mappings.items(): if keyword in title_lower: return control - + return None - + def _check_automated_remediation(self, scap_rule_id: str) -> bool: """Check if automated remediation is available for this rule""" # This would check against AEGIS rule database @@ -464,7 +590,7 @@ def _check_automated_remediation(self, scap_rule_id: str) -> bool: "xccdf_mil.disa.stig_rule_SV-230423r793041_rule", # Audit daemon } return scap_rule_id in automated_rules - + def _get_aegis_rule_id(self, scap_rule_id: str) -> Optional[str]: """Get corresponding AEGIS rule ID for automated remediation""" aegis_mappings = { @@ -473,7 +599,7 @@ def _get_aegis_rule_id(self, scap_rule_id: str) -> Optional[str]: "xccdf_mil.disa.stig_rule_SV-230423r793041_rule": "auditd_service_enabled", } return aegis_mappings.get(scap_rule_id) - + def get_framework_summary(self, scap_rules: List[str]) -> Dict[str, Dict]: """Get compliance summary across all frameworks for a list of SCAP rules""" summary = { @@ -482,125 +608,126 @@ def get_framework_summary(self, scap_rules: List[str]) -> Dict[str, Dict]: "covered_controls": set(), "control_families": {}, "maturity_levels": {}, - "severity_distribution": {"critical": 0, "high": 0, "medium": 0, "low": 0} + "severity_distribution": {"critical": 0, "high": 0, "medium": 0, "low": 0}, } for framework in ComplianceFramework } - + for rule_id in scap_rules: control = self.get_unified_control(rule_id) if not control: continue - + for mapping in control.frameworks: framework_key = mapping.framework.value summary[framework_key]["total_controls"] += 1 summary[framework_key]["covered_controls"].add(mapping.control_id) - + # Count by control family family = mapping.control_family if family not in summary[framework_key]["control_families"]: summary[framework_key]["control_families"][family] = 0 summary[framework_key]["control_families"][family] += 1 - + # Count by maturity level (for CMMC) if mapping.maturity_level: level = f"Level {mapping.maturity_level}" if level not in summary[framework_key]["maturity_levels"]: summary[framework_key]["maturity_levels"][level] = 0 summary[framework_key]["maturity_levels"][level] += 1 - + # Count by severity summary[framework_key]["severity_distribution"][mapping.severity] += 1 - + # Convert sets to lists for JSON serialization for framework in summary.values(): framework["covered_controls"] = list(framework["covered_controls"]) - + return summary - + def get_remediation_priorities(self, failed_rules: List[Dict[str, str]]) -> List[Dict]: """Prioritize failed rules for remediation based on framework requirements""" priorities = [] - + for rule in failed_rules: rule_id = rule.get("rule_id", "") control = self.get_unified_control(rule_id) - + if not control: continue - + # Calculate priority score priority_score = 0 severity_scores = {"critical": 4, "high": 3, "medium": 2, "low": 1} - + # Get highest severity across frameworks max_severity = "low" frameworks_affected = [] - + for mapping in control.frameworks: if severity_scores.get(mapping.severity, 0) > severity_scores.get(max_severity, 0): max_severity = mapping.severity frameworks_affected.append(mapping.framework.value) - + # Add extra weight for CMMC Level 2+ requirements - if mapping.framework == ComplianceFramework.CMMC_2_0 and mapping.maturity_level >= 2: + if ( + mapping.framework == ComplianceFramework.CMMC_2_0 + and mapping.maturity_level >= 2 + ): priority_score += 10 - + priority_score += severity_scores.get(max_severity, 0) * 10 priority_score += len(frameworks_affected) * 5 - + # Boost priority if automated remediation is available if control.automated_remediation: priority_score += 20 - - priorities.append({ - "rule_id": rule_id, - "title": control.title, - "priority_score": priority_score, - "severity": max_severity, - "frameworks_affected": frameworks_affected, - "automated_remediation": control.automated_remediation, - "aegis_rule_id": control.aegis_rule_id, - "remediation_effort": self._estimate_remediation_effort(control) - }) - + + priorities.append( + { + "rule_id": rule_id, + "title": control.title, + "priority_score": priority_score, + "severity": max_severity, + "frameworks_affected": frameworks_affected, + "automated_remediation": control.automated_remediation, + "aegis_rule_id": control.aegis_rule_id, + "remediation_effort": self._estimate_remediation_effort(control), + } + ) + # Sort by priority score (highest first) priorities.sort(key=lambda x: x["priority_score"], reverse=True) - + return priorities - + def _estimate_remediation_effort(self, control: ComplianceControl) -> str: """Estimate remediation effort based on control characteristics""" if control.automated_remediation: return "minimal" - + # Check control categories if any(cat in ["Configuration Management", "Access Control"] for cat in control.categories): return "moderate" - elif any(cat in ["Audit and Accountability", "System and Information Integrity"] for cat in control.categories): + elif any( + cat in ["Audit and Accountability", "System and Information Integrity"] + for cat in control.categories + ): return "significant" else: return "moderate" - + def export_compliance_matrix(self, scap_rules: List[str]) -> Dict: """Export a compliance matrix showing coverage across all frameworks""" - matrix = { - "frameworks": list(ComplianceFramework.__members__.keys()), - "rules": [] - } - + matrix = {"frameworks": list(ComplianceFramework.__members__.keys()), "rules": []} + for rule_id in scap_rules: control = self.get_unified_control(rule_id) if not control: continue - - rule_entry = { - "rule_id": rule_id, - "title": control.title, - "mappings": {} - } - + + rule_entry = {"rule_id": rule_id, "title": control.title, "mappings": {}} + for framework in ComplianceFramework: framework_mappings = [m for m in control.frameworks if m.framework == framework] if framework_mappings: @@ -608,11 +735,11 @@ def export_compliance_matrix(self, scap_rules: List[str]) -> Dict: rule_entry["mappings"][framework.value] = { "control_id": mapping.control_id, "severity": mapping.severity, - "family": mapping.control_family + "family": mapping.control_family, } else: rule_entry["mappings"][framework.value] = None - + matrix["rules"].append(rule_entry) - - return matrix \ No newline at end of file + + return matrix diff --git a/backend/app/services/crypto.py b/backend/app/services/crypto.py index bc40fc86..1e38b556 100644 --- a/backend/app/services/crypto.py +++ b/backend/app/services/crypto.py @@ -2,6 +2,7 @@ Encryption/Decryption service for sensitive data FIPS-compliant AES-256-GCM encryption """ + import os import base64 from typing import Union @@ -24,7 +25,7 @@ def _derive_key(password: str, salt: bytes): length=32, salt=salt, iterations=100000, - backend=default_backend() + backend=default_backend(), ) return kdf.derive(password.encode()) @@ -38,24 +39,20 @@ def encrypt_credentials(data: str): # Generate random salt and nonce salt = os.urandom(16) nonce = os.urandom(12) - + # Derive key key = _derive_key(ENCRYPTION_KEY, salt) - + # Encrypt data - cipher = Cipher( - algorithms.AES(key), - modes.GCM(nonce), - backend=default_backend() - ) + cipher = Cipher(algorithms.AES(key), modes.GCM(nonce), backend=default_backend()) encryptor = cipher.encryptor() ciphertext = encryptor.update(data.encode()) + encryptor.finalize() - + # Combine salt + nonce + tag + ciphertext encrypted_data = salt + nonce + encryptor.tag + ciphertext - + return encrypted_data - + except Exception as e: logger.error(f"Encryption failed: {e}") raise ValueError("Failed to encrypt credentials") @@ -69,27 +66,23 @@ def decrypt_credentials(encrypted_data): try: if len(encrypted_data) < 44: # 16 (salt) + 12 (nonce) + 16 (tag) minimum raise ValueError("Invalid encrypted data length") - + # Extract components salt = encrypted_data[:16] nonce = encrypted_data[16:28] tag = encrypted_data[28:44] ciphertext = encrypted_data[44:] - + # Derive key key = _derive_key(ENCRYPTION_KEY, salt) - + # Decrypt data - cipher = Cipher( - algorithms.AES(key), - modes.GCM(nonce, tag), - backend=default_backend() - ) + cipher = Cipher(algorithms.AES(key), modes.GCM(nonce, tag), backend=default_backend()) decryptor = cipher.decryptor() plaintext = decryptor.update(ciphertext) + decryptor.finalize() - + return plaintext.decode() - + except Exception as e: logger.error(f"Decryption failed: {e}") raise ValueError("Failed to decrypt credentials") @@ -122,4 +115,4 @@ def verify_encryption() -> bool: return decrypted == test_data except Exception as e: logger.error(f"Encryption verification failed: {e}") - return False \ No newline at end of file + return False diff --git a/backend/app/services/csv_analyzer.py b/backend/app/services/csv_analyzer.py index db288e25..3351a456 100644 --- a/backend/app/services/csv_analyzer.py +++ b/backend/app/services/csv_analyzer.py @@ -2,6 +2,7 @@ CSV Analysis Service Provides intelligent analysis of CSV files for frictionless import """ + import csv import io import re @@ -13,6 +14,7 @@ class FieldType(Enum): """Detected field types with confidence scoring""" + HOSTNAME = "hostname" IP_ADDRESS = "ip_address" DISPLAY_NAME = "display_name" @@ -29,6 +31,7 @@ class FieldType(Enum): @dataclass class FieldAnalysis: """Analysis result for a CSV column""" + column_name: str detected_type: FieldType confidence: float # 0.0 to 1.0 @@ -41,6 +44,7 @@ class FieldAnalysis: @dataclass class CSVAnalysis: """Complete analysis result for a CSV file""" + total_rows: int total_columns: int headers: List[str] @@ -51,162 +55,158 @@ class CSVAnalysis: class CSVAnalyzer: """Intelligent CSV analysis for field mapping""" - + def __init__(self): self.ip_patterns = [ - re.compile(r'^(?:[0-9]{1,3}\.){3}[0-9]{1,3}$'), # IPv4 - re.compile(r'^[0-9a-fA-F:]+$'), # IPv6 (simplified) + re.compile(r"^(?:[0-9]{1,3}\.){3}[0-9]{1,3}$"), # IPv4 + re.compile(r"^[0-9a-fA-F:]+$"), # IPv6 (simplified) ] - + self.hostname_patterns = [ - re.compile(r'^[a-zA-Z0-9]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?(\.[a-zA-Z0-9]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?)*$'), - re.compile(r'^[a-zA-Z0-9\-_]+$'), # Simple hostname + re.compile( + r"^[a-zA-Z0-9]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?(\.[a-zA-Z0-9]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?)*$" + ), + re.compile(r"^[a-zA-Z0-9\-_]+$"), # Simple hostname ] - + self.os_keywords = { - 'rhel': ['rhel', 'red hat', 'redhat'], - 'centos': ['centos', 'cent os'], - 'ubuntu': ['ubuntu'], - 'windows': ['windows', 'win'], - 'suse': ['suse', 'sles'], - 'debian': ['debian'], - 'linux': ['linux'], + "rhel": ["rhel", "red hat", "redhat"], + "centos": ["centos", "cent os"], + "ubuntu": ["ubuntu"], + "windows": ["windows", "win"], + "suse": ["suse", "sles"], + "debian": ["debian"], + "linux": ["linux"], } - + self.environment_keywords = { - 'production': ['prod', 'production', 'prd'], - 'staging': ['staging', 'stage', 'stg'], - 'development': ['dev', 'development', 'test'], - 'qa': ['qa', 'testing', 'qual'], + "production": ["prod", "production", "prd"], + "staging": ["staging", "stage", "stg"], + "development": ["dev", "development", "test"], + "qa": ["qa", "testing", "qual"], } - - self.auth_methods = ['password', 'ssh_key', 'system_default'] - + + self.auth_methods = ["password", "ssh_key", "system_default"] + # Common column name mappings self.column_mappings = { # Hostname variations - 'hostname': FieldType.HOSTNAME, - 'host_name': FieldType.HOSTNAME, - 'name': FieldType.HOSTNAME, - 'vm_name': FieldType.HOSTNAME, - 'machine_name': FieldType.HOSTNAME, - 'server_name': FieldType.HOSTNAME, - 'computer_name': FieldType.HOSTNAME, - + "hostname": FieldType.HOSTNAME, + "host_name": FieldType.HOSTNAME, + "name": FieldType.HOSTNAME, + "vm_name": FieldType.HOSTNAME, + "machine_name": FieldType.HOSTNAME, + "server_name": FieldType.HOSTNAME, + "computer_name": FieldType.HOSTNAME, # IP Address variations - 'ip': FieldType.IP_ADDRESS, - 'ip_address': FieldType.IP_ADDRESS, - 'ipv4': FieldType.IP_ADDRESS, - 'ipv6': FieldType.IP_ADDRESS, - 'address': FieldType.IP_ADDRESS, - 'host_ip': FieldType.IP_ADDRESS, - + "ip": FieldType.IP_ADDRESS, + "ip_address": FieldType.IP_ADDRESS, + "ipv4": FieldType.IP_ADDRESS, + "ipv6": FieldType.IP_ADDRESS, + "address": FieldType.IP_ADDRESS, + "host_ip": FieldType.IP_ADDRESS, # Display name variations - 'display_name': FieldType.DISPLAY_NAME, - 'description': FieldType.DISPLAY_NAME, - 'friendly_name': FieldType.DISPLAY_NAME, - 'label': FieldType.DISPLAY_NAME, - + "display_name": FieldType.DISPLAY_NAME, + "description": FieldType.DISPLAY_NAME, + "friendly_name": FieldType.DISPLAY_NAME, + "label": FieldType.DISPLAY_NAME, # Operating System variations - 'os': FieldType.OPERATING_SYSTEM, - 'operating_system': FieldType.OPERATING_SYSTEM, - 'os_type': FieldType.OPERATING_SYSTEM, - 'guest_os': FieldType.OPERATING_SYSTEM, - 'platform': FieldType.OPERATING_SYSTEM, - + "os": FieldType.OPERATING_SYSTEM, + "operating_system": FieldType.OPERATING_SYSTEM, + "os_type": FieldType.OPERATING_SYSTEM, + "guest_os": FieldType.OPERATING_SYSTEM, + "platform": FieldType.OPERATING_SYSTEM, # Port variations - 'port': FieldType.PORT, - 'ssh_port': FieldType.PORT, - 'port_ssh': FieldType.PORT, - + "port": FieldType.PORT, + "ssh_port": FieldType.PORT, + "port_ssh": FieldType.PORT, # Username variations - 'user': FieldType.USERNAME, - 'username': FieldType.USERNAME, - 'ssh_user': FieldType.USERNAME, - 'admin_user': FieldType.USERNAME, - + "user": FieldType.USERNAME, + "username": FieldType.USERNAME, + "ssh_user": FieldType.USERNAME, + "admin_user": FieldType.USERNAME, # Environment variations - 'env': FieldType.ENVIRONMENT, - 'environment': FieldType.ENVIRONMENT, - 'stage': FieldType.ENVIRONMENT, - + "env": FieldType.ENVIRONMENT, + "environment": FieldType.ENVIRONMENT, + "stage": FieldType.ENVIRONMENT, # Tags variations - 'tags': FieldType.TAGS, - 'labels': FieldType.TAGS, - 'categories': FieldType.TAGS, - + "tags": FieldType.TAGS, + "labels": FieldType.TAGS, + "categories": FieldType.TAGS, # Owner variations - 'owner': FieldType.OWNER, - 'responsible': FieldType.OWNER, - 'contact': FieldType.OWNER, - 'admin': FieldType.OWNER, + "owner": FieldType.OWNER, + "responsible": FieldType.OWNER, + "contact": FieldType.OWNER, + "admin": FieldType.OWNER, } def analyze_csv(self, csv_content: str, max_preview_rows: int = 10) -> CSVAnalysis: """Analyze CSV content and provide intelligent field mapping suggestions""" - + if not csv_content or not csv_content.strip(): raise ValueError("CSV content is empty") - + try: # Parse CSV csv_reader = csv.DictReader(io.StringIO(csv_content)) headers = csv_reader.fieldnames or [] - + if not headers: raise ValueError("CSV file has no headers or is malformed") - + # Read all rows for analysis rows = list(csv_reader) total_rows = len(rows) - + if total_rows == 0: raise ValueError("CSV file is empty or has no data rows") - + except csv.Error as e: raise ValueError(f"Invalid CSV format: {e}") except Exception as e: raise ValueError(f"Error parsing CSV: {e}") - + # Analyze each column field_analyses = [] for header in headers: analysis = self._analyze_column(header, rows, max_preview_rows) field_analyses.append(analysis) - + # Generate auto-mappings auto_mappings = self._generate_auto_mappings(field_analyses) - + # Check for template matches template_matches = self._check_template_matches(headers) - + return CSVAnalysis( total_rows=total_rows, total_columns=len(headers), headers=headers, field_analyses=field_analyses, auto_mappings=auto_mappings, - template_matches=template_matches + template_matches=template_matches, ) - def _analyze_column(self, column_name: str, rows: List[Dict], max_preview: int) -> FieldAnalysis: + def _analyze_column( + self, column_name: str, rows: List[Dict], max_preview: int + ) -> FieldAnalysis: """Analyze a single column and detect its likely type""" - + # Extract values for this column, handling None values - values = [(row.get(column_name) or '').strip() for row in rows] + values = [(row.get(column_name) or "").strip() for row in rows] non_empty_values = [v for v in values if v] - + # Basic stats unique_count = len(set(non_empty_values)) null_count = len(values) - len(non_empty_values) sample_values = non_empty_values[:max_preview] - + # Detect field type detected_type, confidence = self._detect_field_type(column_name, non_empty_values) - + # Generate suggestions suggestions = self._generate_suggestions(detected_type, non_empty_values) - + return FieldAnalysis( column_name=column_name, detected_type=detected_type, @@ -214,20 +214,20 @@ def _analyze_column(self, column_name: str, rows: List[Dict], max_preview: int) sample_values=sample_values, unique_count=unique_count, null_count=null_count, - suggestions=suggestions + suggestions=suggestions, ) def _detect_field_type(self, column_name: str, values: List[str]) -> Tuple[FieldType, float]: """Detect the most likely field type for a column""" - + if not values: return FieldType.UNKNOWN, 0.0 - + # Check column name first (high confidence) - normalized_name = column_name.lower().replace(' ', '_').replace('-', '_') + normalized_name = column_name.lower().replace(" ", "_").replace("-", "_") if normalized_name in self.column_mappings: return self.column_mappings[normalized_name], 0.95 - + # Content-based detection detectors = [ (self._is_ip_address_column, FieldType.IP_ADDRESS), @@ -237,37 +237,37 @@ def _detect_field_type(self, column_name: str, values: List[str]) -> Tuple[Field (self._is_environment_column, FieldType.ENVIRONMENT), (self._is_auth_method_column, FieldType.AUTH_METHOD), ] - + best_type = FieldType.UNKNOWN best_confidence = 0.0 - + for detector, field_type in detectors: confidence = detector(values) if confidence > best_confidence: best_type = field_type best_confidence = confidence - + # Fallback heuristics if best_confidence < 0.3: - if 'name' in normalized_name: + if "name" in normalized_name: if len(set(values)) / len(values) > 0.8: # High uniqueness return FieldType.HOSTNAME, 0.6 else: return FieldType.DISPLAY_NAME, 0.6 - elif 'user' in normalized_name: + elif "user" in normalized_name: return FieldType.USERNAME, 0.6 - elif 'tag' in normalized_name or 'label' in normalized_name: + elif "tag" in normalized_name or "label" in normalized_name: return FieldType.TAGS, 0.6 - elif 'owner' in normalized_name or 'contact' in normalized_name: + elif "owner" in normalized_name or "contact" in normalized_name: return FieldType.OWNER, 0.6 - + return best_type, best_confidence def _is_ip_address_column(self, values: List[str]) -> float: """Check if column contains IP addresses""" if not values: return 0.0 - + valid_ips = 0 for value in values[:20]: # Sample first 20 values try: @@ -275,26 +275,26 @@ def _is_ip_address_column(self, values: List[str]) -> float: valid_ips += 1 except ValueError: pass - + return valid_ips / min(len(values), 20) def _is_hostname_column(self, values: List[str]) -> float: """Check if column contains hostnames""" if not values: return 0.0 - + valid_hostnames = 0 for value in values[:20]: if any(pattern.match(value) for pattern in self.hostname_patterns): valid_hostnames += 1 - + return valid_hostnames / min(len(values), 20) def _is_port_column(self, values: List[str]) -> float: """Check if column contains port numbers""" if not values: return 0.0 - + valid_ports = 0 for value in values[:20]: try: @@ -303,14 +303,14 @@ def _is_port_column(self, values: List[str]) -> float: valid_ports += 1 except ValueError: pass - + return valid_ports / min(len(values), 20) def _is_os_column(self, values: List[str]) -> float: """Check if column contains operating system names""" if not values: return 0.0 - + os_matches = 0 for value in values[:20]: value_lower = value.lower() @@ -318,14 +318,14 @@ def _is_os_column(self, values: List[str]) -> float: if any(keyword in value_lower for keyword in keywords): os_matches += 1 break - + return os_matches / min(len(values), 20) def _is_environment_column(self, values: List[str]) -> float: """Check if column contains environment names""" if not values: return 0.0 - + env_matches = 0 for value in values[:20]: value_lower = value.lower() @@ -333,95 +333,97 @@ def _is_environment_column(self, values: List[str]) -> float: if any(keyword in value_lower for keyword in keywords): env_matches += 1 break - + return env_matches / min(len(values), 20) def _is_auth_method_column(self, values: List[str]) -> float: """Check if column contains authentication methods""" if not values: return 0.0 - + auth_matches = 0 for value in values[:20]: if value.lower() in self.auth_methods: auth_matches += 1 - + return auth_matches / min(len(values), 20) def _generate_suggestions(self, field_type: FieldType, values: List[str]) -> List[str]: """Generate helpful suggestions for field mapping""" suggestions = [] - + if field_type == FieldType.IP_ADDRESS: # Check for IPv6 addresses - has_ipv6 = any(':' in v for v in values[:10]) + has_ipv6 = any(":" in v for v in values[:10]) if has_ipv6: suggestions.append("Contains IPv6 addresses - ensure proper formatting") - + elif field_type == FieldType.PORT: unique_ports = set(values[:20]) if len(unique_ports) == 1: suggestions.append(f"All hosts use port {list(unique_ports)[0]}") - + elif field_type == FieldType.OPERATING_SYSTEM: unique_os = set(v.lower() for v in values[:20]) if len(unique_os) <= 3: suggestions.append(f"Limited OS variety: {', '.join(unique_os)}") - + elif field_type == FieldType.ENVIRONMENT: unique_envs = set(v.lower() for v in values[:20]) suggestions.append(f"Environments detected: {', '.join(unique_envs)}") - + return suggestions def _generate_auto_mappings(self, field_analyses: List[FieldAnalysis]) -> Dict[str, str]: """Generate automatic field mappings based on confidence scores""" mappings = {} - + # Required fields that should be mapped required_fields = [FieldType.HOSTNAME, FieldType.IP_ADDRESS] - + # Track which target fields have been assigned assigned_targets = set() - + # Sort by confidence (highest first) sorted_analyses = sorted(field_analyses, key=lambda x: x.confidence, reverse=True) - + for analysis in sorted_analyses: if analysis.confidence >= 0.7 and analysis.detected_type != FieldType.UNKNOWN: target_field = analysis.detected_type.value - + # Avoid duplicate mappings if target_field not in assigned_targets: mappings[analysis.column_name] = target_field assigned_targets.add(target_field) - + return mappings def _check_template_matches(self, headers: List[str]) -> List[str]: """Check if headers match known source templates""" templates = [] - + headers_lower = [h.lower() for h in headers] - + # VMware vCenter patterns - vmware_indicators = ['vm_name', 'guest_os', 'ip_address', 'power_state'] - if any(indicator in ' '.join(headers_lower) for indicator in vmware_indicators): + vmware_indicators = ["vm_name", "guest_os", "ip_address", "power_state"] + if any(indicator in " ".join(headers_lower) for indicator in vmware_indicators): templates.append("VMware vCenter Export") - + # Red Hat Satellite patterns - satellite_indicators = ['name', 'operating_system', 'ip', 'environment'] - if all(any(indicator in h for h in headers_lower) for indicator in satellite_indicators[:2]): + satellite_indicators = ["name", "operating_system", "ip", "environment"] + if all( + any(indicator in h for h in headers_lower) for indicator in satellite_indicators[:2] + ): templates.append("Red Hat Satellite Export") - + # AWS EC2 patterns - aws_indicators = ['instance_id', 'instance_type', 'public_ip', 'private_ip'] - if any(indicator in ' '.join(headers_lower) for indicator in aws_indicators): + aws_indicators = ["instance_id", "instance_type", "public_ip", "private_ip"] + if any(indicator in " ".join(headers_lower) for indicator in aws_indicators): templates.append("AWS EC2 Instance List") - + # Azure VM patterns - azure_indicators = ['vm_name', 'resource_group', 'location', 'vm_size'] - if any(indicator in ' '.join(headers_lower) for indicator in azure_indicators): + azure_indicators = ["vm_name", "resource_group", "location", "vm_size"] + if any(indicator in " ".join(headers_lower) for indicator in azure_indicators): templates.append("Azure VM Export") - - return templates \ No newline at end of file + + return templates diff --git a/backend/app/services/email_service.py b/backend/app/services/email_service.py index 58a95657..6e08f30d 100644 --- a/backend/app/services/email_service.py +++ b/backend/app/services/email_service.py @@ -1,6 +1,7 @@ """ Email Service for sending notifications """ + import aiosmtplib import logging from email.mime.text import MIMEText @@ -14,28 +15,24 @@ class EmailService: def __init__(self): - self.smtp_host = os.getenv('SMTP_HOST', 'localhost') - self.smtp_port = int(os.getenv('SMTP_PORT', '587')) - self.smtp_username = os.getenv('SMTP_USERNAME', '') - self.smtp_password = os.getenv('SMTP_PASSWORD', '') - self.smtp_use_tls = os.getenv('SMTP_USE_TLS', 'true').lower() == 'true' - self.from_email = os.getenv('FROM_EMAIL', 'openwatch@example.com') - self.from_name = os.getenv('FROM_NAME', 'OpenWatch Security Scanner') - + self.smtp_host = os.getenv("SMTP_HOST", "localhost") + self.smtp_port = int(os.getenv("SMTP_PORT", "587")) + self.smtp_username = os.getenv("SMTP_USERNAME", "") + self.smtp_password = os.getenv("SMTP_PASSWORD", "") + self.smtp_use_tls = os.getenv("SMTP_USE_TLS", "true").lower() == "true" + self.from_email = os.getenv("FROM_EMAIL", "openwatch@example.com") + self.from_name = os.getenv("FROM_NAME", "OpenWatch Security Scanner") + async def send_host_offline_alert( - self, - host_name: str, - host_ip: str, - last_check: datetime, - recipients: List[str] + self, host_name: str, host_ip: str, last_check: datetime, recipients: List[str] ) -> bool: """Send host offline alert email""" if not recipients: logger.warning("No recipients provided for host offline alert") return False - + subject = f"🚨 Host Offline Alert: {host_name}" - + # Create HTML body html_body = f""" @@ -86,7 +83,7 @@ async def send_host_offline_alert( """ - + # Create plain text body as fallback plain_body = f""" HOST OFFLINE ALERT @@ -108,23 +105,19 @@ async def send_host_offline_alert( --- This is an automated message from OpenWatch Security Scanner. """ - + return await self._send_email(recipients, subject, plain_body, html_body) - + async def send_host_online_alert( - self, - host_name: str, - host_ip: str, - check_time: datetime, - recipients: List[str] + self, host_name: str, host_ip: str, check_time: datetime, recipients: List[str] ) -> bool: """Send host back online alert email""" if not recipients: logger.warning("No recipients provided for host online alert") return False - + subject = f"✅ Host Online: {host_name}" - + # Create HTML body html_body = f""" @@ -167,7 +160,7 @@ async def send_host_online_alert( """ - + # Create plain text body as fallback plain_body = f""" HOST BACK ONLINE @@ -184,15 +177,11 @@ async def send_host_online_alert( --- This is an automated message from OpenWatch Security Scanner. """ - + return await self._send_email(recipients, subject, plain_body, html_body) - + async def _send_email( - self, - recipients: List[str], - subject: str, - plain_body: str, - html_body: Optional[str] = None + self, recipients: List[str], subject: str, plain_body: str, html_body: Optional[str] = None ) -> bool: """Send email using SMTP""" try: @@ -200,20 +189,20 @@ async def _send_email( if not self.smtp_host or not self.from_email: logger.warning("Email not configured (missing SMTP_HOST or FROM_EMAIL)") return False - + # Create message - msg = MIMEMultipart('alternative') - msg['Subject'] = subject - msg['From'] = f"{self.from_name} <{self.from_email}>" - msg['To'] = ', '.join(recipients) - + msg = MIMEMultipart("alternative") + msg["Subject"] = subject + msg["From"] = f"{self.from_name} <{self.from_email}>" + msg["To"] = ", ".join(recipients) + # Add plain text part - msg.attach(MIMEText(plain_body, 'plain')) - + msg.attach(MIMEText(plain_body, "plain")) + # Add HTML part if provided if html_body: - msg.attach(MIMEText(html_body, 'html')) - + msg.attach(MIMEText(html_body, "html")) + # Connect and send if self.smtp_use_tls: await aiosmtplib.send( @@ -222,7 +211,7 @@ async def _send_email( port=self.smtp_port, username=self.smtp_username if self.smtp_username else None, password=self.smtp_password if self.smtp_password else None, - use_tls=True + use_tls=True, ) else: await aiosmtplib.send( @@ -230,16 +219,16 @@ async def _send_email( hostname=self.smtp_host, port=self.smtp_port, username=self.smtp_username if self.smtp_username else None, - password=self.smtp_password if self.smtp_password else None + password=self.smtp_password if self.smtp_password else None, ) - + logger.info(f"Email sent successfully to {len(recipients)} recipients: {subject}") return True - + except Exception as e: logger.error(f"Failed to send email: {e}") return False # Global email service instance -email_service = EmailService() \ No newline at end of file +email_service = EmailService() diff --git a/backend/app/services/encryption.py b/backend/app/services/encryption.py index 088dd9da..a4e8020f 100644 --- a/backend/app/services/encryption.py +++ b/backend/app/services/encryption.py @@ -2,6 +2,7 @@ Encryption service for sensitive data like SSH credentials Uses AES-256-GCM for authenticated encryption """ + import os import base64 from cryptography.hazmat.primitives.ciphers.aead import AESGCM @@ -11,11 +12,12 @@ logger = logging.getLogger(__name__) + class EncryptionService: def __init__(self, master_key: str): """Initialize encryption service with master key""" self.master_key = master_key.encode() - + def _derive_key(self, salt: bytes) -> bytes: """Derive encryption key from master key and salt""" kdf = PBKDF2HMAC( @@ -25,30 +27,30 @@ def _derive_key(self, salt: bytes) -> bytes: iterations=100000, ) return kdf.derive(self.master_key) - + def encrypt(self, data: bytes) -> bytes: """Encrypt data using AES-256-GCM""" try: # Generate random salt and nonce salt = os.urandom(16) nonce = os.urandom(12) # GCM recommended nonce size - + # Derive key from master key and salt key = self._derive_key(salt) - + # Encrypt data aesgcm = AESGCM(key) ciphertext = aesgcm.encrypt(nonce, data, None) - + # Combine salt + nonce + ciphertext encrypted_data = salt + nonce + ciphertext - + return encrypted_data - + except Exception as e: logger.error(f"Encryption error: {e}") raise - + def decrypt(self, encrypted_data: bytes) -> bytes: """Decrypt data using AES-256-GCM""" try: @@ -56,38 +58,43 @@ def decrypt(self, encrypted_data: bytes) -> bytes: salt = encrypted_data[:16] nonce = encrypted_data[16:28] ciphertext = encrypted_data[28:] - + # Derive key from master key and salt key = self._derive_key(salt) - + # Decrypt data aesgcm = AESGCM(key) plaintext = aesgcm.decrypt(nonce, ciphertext, None) - + return plaintext - + except Exception as e: logger.error(f"Decryption error: {e}") raise + # Global encryption service instance _encryption_service = None + def get_encryption_service() -> EncryptionService: """Get global encryption service instance""" global _encryption_service if _encryption_service is None: from ..config import get_settings + settings = get_settings() _encryption_service = EncryptionService(settings.master_key) return _encryption_service + def encrypt_data(data: bytes) -> str: """Encrypt data using global encryption service and return base64 string""" encrypted_bytes = get_encryption_service().encrypt(data) - return base64.b64encode(encrypted_bytes).decode('ascii') + return base64.b64encode(encrypted_bytes).decode("ascii") + def decrypt_data(encrypted_data: str) -> bytes: """Decrypt base64-encoded data using global encryption service""" - encrypted_bytes = base64.b64decode(encrypted_data.encode('ascii')) - return get_encryption_service().decrypt(encrypted_bytes) \ No newline at end of file + encrypted_bytes = base64.b64decode(encrypted_data.encode("ascii")) + return get_encryption_service().decrypt(encrypted_bytes) diff --git a/backend/app/services/error_classification.py b/backend/app/services/error_classification.py index a2025710..5894c087 100644 --- a/backend/app/services/error_classification.py +++ b/backend/app/services/error_classification.py @@ -3,6 +3,7 @@ Provides comprehensive error taxonomy and user-friendly guidance Enhanced with security sanitization to prevent information disclosure """ + import socket import os import json @@ -14,8 +15,12 @@ from pydantic import BaseModel, Field from ..models.error_models import ( - ScanErrorInternal, ScanErrorResponse, ValidationResultInternal, ValidationResultResponse, - ErrorCategory, ErrorSeverity + ScanErrorInternal, + ScanErrorResponse, + ValidationResultInternal, + ValidationResultResponse, + ErrorCategory, + ErrorSeverity, ) from .error_sanitization import get_error_sanitization_service, SanitizationLevel from .security_audit_logger import get_security_audit_logger @@ -30,6 +35,7 @@ class AutomatedFix(BaseModel): """Represents an automated fix option""" + fix_id: str description: str requires_sudo: bool = False @@ -49,145 +55,163 @@ class AutomatedFix(BaseModel): class NetworkValidator: """Network connectivity validation""" - + @staticmethod async def validate_connectivity(hostname: str, port: int = 22) -> List[ScanErrorInternal]: """Comprehensive network connectivity validation""" errors = [] - + try: # Stage 1: DNS Resolution try: ip_address = socket.gethostbyname(hostname) logger.debug(f"DNS resolution successful: {hostname} -> {ip_address}") except socket.gaierror as e: - errors.append(ScanErrorInternal( - error_code="NET_001", - category=ErrorCategory.NETWORK, - severity=ErrorSeverity.ERROR, - message=f"DNS resolution failed for {hostname}", - technical_details={"hostname": hostname, "error": str(e)}, - user_guidance="Verify the hostname is correct or use an IP address directly. Check your DNS server configuration.", - automated_fixes=[ - AutomatedFix( - fix_id="use_ip_address", - description="Use IP address instead of hostname", - requires_sudo=False, - estimated_time=5 - ) - ], - can_retry=True, - retry_after=30, - documentation_url="https://docs.openwatch.dev/troubleshooting/network#dns-resolution" - )) - return errors - - # Stage 2: TCP Connection Test - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(10) - try: - result = sock.connect_ex((ip_address, port)) - if result != 0: - errors.append(ScanErrorInternal( - error_code="NET_002", + errors.append( + ScanErrorInternal( + error_code="NET_001", category=ErrorCategory.NETWORK, severity=ErrorSeverity.ERROR, - message=f"Cannot connect to {hostname}:{port}", - technical_details={"hostname": hostname, "port": port, "ip_address": ip_address, "connection_result": result}, - user_guidance=f"Check if SSH service is running on port {port} and firewall rules allow connections.", + message=f"DNS resolution failed for {hostname}", + technical_details={"hostname": hostname, "error": str(e)}, + user_guidance="Verify the hostname is correct or use an IP address directly. Check your DNS server configuration.", automated_fixes=[ AutomatedFix( - fix_id="check_firewall", - description="Check firewall rules for SSH port", - command=f"netstat -tlnp | grep {port}", + fix_id="use_ip_address", + description="Use IP address instead of hostname", requires_sudo=False, - estimated_time=10, - is_safe=True # Read-only command is safe + estimated_time=5, ) ], can_retry=True, - retry_after=60, - documentation_url="https://docs.openwatch.dev/troubleshooting/network#connection-refused" - )) + retry_after=30, + documentation_url="https://docs.openwatch.dev/troubleshooting/network#dns-resolution", + ) + ) + return errors + + # Stage 2: TCP Connection Test + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(10) + try: + result = sock.connect_ex((ip_address, port)) + if result != 0: + errors.append( + ScanErrorInternal( + error_code="NET_002", + category=ErrorCategory.NETWORK, + severity=ErrorSeverity.ERROR, + message=f"Cannot connect to {hostname}:{port}", + technical_details={ + "hostname": hostname, + "port": port, + "ip_address": ip_address, + "connection_result": result, + }, + user_guidance=f"Check if SSH service is running on port {port} and firewall rules allow connections.", + automated_fixes=[ + AutomatedFix( + fix_id="check_firewall", + description="Check firewall rules for SSH port", + command=f"netstat -tlnp | grep {port}", + requires_sudo=False, + estimated_time=10, + is_safe=True, # Read-only command is safe + ) + ], + can_retry=True, + retry_after=60, + documentation_url="https://docs.openwatch.dev/troubleshooting/network#connection-refused", + ) + ) return errors except socket.timeout: - errors.append(ScanErrorInternal( - error_code="NET_003", - category=ErrorCategory.NETWORK, - severity=ErrorSeverity.ERROR, - message=f"Connection timeout to {hostname}:{port}", - technical_details={"hostname": hostname, "port": port, "timeout": 10}, - user_guidance="Host may be unreachable or behind a firewall. Check network connectivity and firewall rules.", - can_retry=True, - retry_after=120, - documentation_url="https://docs.openwatch.dev/troubleshooting/network#timeout" - )) + errors.append( + ScanErrorInternal( + error_code="NET_003", + category=ErrorCategory.NETWORK, + severity=ErrorSeverity.ERROR, + message=f"Connection timeout to {hostname}:{port}", + technical_details={"hostname": hostname, "port": port, "timeout": 10}, + user_guidance="Host may be unreachable or behind a firewall. Check network connectivity and firewall rules.", + can_retry=True, + retry_after=120, + documentation_url="https://docs.openwatch.dev/troubleshooting/network#timeout", + ) + ) return errors finally: sock.close() - + # Stage 3: SSH Banner Check try: transport = paramiko.Transport((hostname, port)) transport.start_client(timeout=5) banner = transport.get_banner() transport.close() - - if banner and b'ssh' not in banner.lower(): - errors.append(ScanErrorInternal( - error_code="NET_004", + + if banner and b"ssh" not in banner.lower(): + errors.append( + ScanErrorInternal( + error_code="NET_004", + category=ErrorCategory.NETWORK, + severity=ErrorSeverity.WARNING, + message=f"Unexpected service on port {port}", + technical_details={"banner": banner.decode("utf-8", errors="ignore")}, + user_guidance=f"Port {port} is not running SSH service. Verify SSH daemon is running on the correct port.", + documentation_url="https://docs.openwatch.dev/troubleshooting/network#wrong-service", + ) + ) + + except Exception as e: + errors.append( + ScanErrorInternal( + error_code="NET_005", category=ErrorCategory.NETWORK, severity=ErrorSeverity.WARNING, - message=f"Unexpected service on port {port}", - technical_details={"banner": banner.decode('utf-8', errors='ignore')}, - user_guidance=f"Port {port} is not running SSH service. Verify SSH daemon is running on the correct port.", - documentation_url="https://docs.openwatch.dev/troubleshooting/network#wrong-service" - )) - - except Exception as e: - errors.append(ScanErrorInternal( - error_code="NET_005", + message="SSH service not responding properly", + technical_details={"error": str(e)}, + user_guidance="SSH daemon may not be running or configured properly. Check SSH service status.", + automated_fixes=[ + AutomatedFix( + fix_id="check_ssh_service", + description="Check SSH service status", + command="systemctl status sshd", + requires_sudo=False, + estimated_time=5, + is_safe=True, # Read-only command is safe + ) + ], + documentation_url="https://docs.openwatch.dev/troubleshooting/network#ssh-daemon", + ) + ) + + except Exception as e: + errors.append( + ScanErrorInternal( + error_code="NET_999", category=ErrorCategory.NETWORK, - severity=ErrorSeverity.WARNING, - message="SSH service not responding properly", + severity=ErrorSeverity.ERROR, + message="Unexpected network validation error", technical_details={"error": str(e)}, - user_guidance="SSH daemon may not be running or configured properly. Check SSH service status.", - automated_fixes=[ - AutomatedFix( - fix_id="check_ssh_service", - description="Check SSH service status", - command="systemctl status sshd", - requires_sudo=False, - estimated_time=5, - is_safe=True # Read-only command is safe - ) - ], - documentation_url="https://docs.openwatch.dev/troubleshooting/network#ssh-daemon" - )) - - except Exception as e: - errors.append(ScanErrorInternal( - error_code="NET_999", - category=ErrorCategory.NETWORK, - severity=ErrorSeverity.ERROR, - message="Unexpected network validation error", - technical_details={"error": str(e)}, - user_guidance="An unexpected error occurred during network validation. Please check logs and try again.", - can_retry=True - )) - + user_guidance="An unexpected error occurred during network validation. Please check logs and try again.", + can_retry=True, + ) + ) + return errors class AuthenticationValidator: """SSH authentication validation""" - + @staticmethod - async def validate_credentials(hostname: str, port: int, username: str, - auth_method: str, credential: str) -> List[ScanErrorInternal]: + async def validate_credentials( + hostname: str, port: int, username: str, auth_method: str, credential: str + ) -> List[ScanErrorInternal]: """Validate SSH authentication credentials""" errors = [] - + try: ssh = paramiko.SSHClient() # Security Fix: Use strict host key checking instead of AutoAddPolicy @@ -195,367 +219,423 @@ async def validate_credentials(hostname: str, port: int, username: str, # Load system and user host keys for validation try: ssh.load_system_host_keys() - ssh.load_host_keys(os.path.expanduser('~/.ssh/known_hosts')) + ssh.load_host_keys(os.path.expanduser("~/.ssh/known_hosts")) except FileNotFoundError: - logger.warning("No known_hosts files found - SSH connections may fail without proper host key management") - + logger.warning( + "No known_hosts files found - SSH connections may fail without proper host key management" + ) + if auth_method == "password": try: - ssh.connect(hostname, port=port, username=username, - password=credential, timeout=10) + ssh.connect( + hostname, port=port, username=username, password=credential, timeout=10 + ) ssh.close() return errors # Success except paramiko.AuthenticationException as e: error_msg = str(e).lower() if "too many authentication failures" in error_msg: - errors.append(ScanErrorInternal( - error_code="AUTH_001", - category=ErrorCategory.AUTHENTICATION, - severity=ErrorSeverity.ERROR, - message="Account temporarily locked due to failed login attempts", - technical_details={"username": username, "hostname": hostname}, - user_guidance="Wait for account lockout to expire or contact system administrator to unlock the account.", - can_retry=True, - retry_after=900, # 15 minutes - documentation_url="https://docs.openwatch.dev/troubleshooting/auth#account-locked" - )) + errors.append( + ScanErrorInternal( + error_code="AUTH_001", + category=ErrorCategory.AUTHENTICATION, + severity=ErrorSeverity.ERROR, + message="Account temporarily locked due to failed login attempts", + technical_details={"username": username, "hostname": hostname}, + user_guidance="Wait for account lockout to expire or contact system administrator to unlock the account.", + can_retry=True, + retry_after=900, # 15 minutes + documentation_url="https://docs.openwatch.dev/troubleshooting/auth#account-locked", + ) + ) else: - errors.append(ScanErrorInternal( - error_code="AUTH_002", - category=ErrorCategory.AUTHENTICATION, - severity=ErrorSeverity.ERROR, - message="Invalid username or password", - technical_details={"username": username, "auth_method": auth_method}, - user_guidance="Verify the username and password are correct. Check if account is disabled or expired.", - automated_fixes=[ - AutomatedFix( - fix_id="test_password_reset", - description="Test password reset if available", - requires_sudo=False, - estimated_time=60 - ) - ], - documentation_url="https://docs.openwatch.dev/troubleshooting/auth#invalid-credentials" - )) - + errors.append( + ScanErrorInternal( + error_code="AUTH_002", + category=ErrorCategory.AUTHENTICATION, + severity=ErrorSeverity.ERROR, + message="Invalid username or password", + technical_details={ + "username": username, + "auth_method": auth_method, + }, + user_guidance="Verify the username and password are correct. Check if account is disabled or expired.", + automated_fixes=[ + AutomatedFix( + fix_id="test_password_reset", + description="Test password reset if available", + requires_sudo=False, + estimated_time=60, + ) + ], + documentation_url="https://docs.openwatch.dev/troubleshooting/auth#invalid-credentials", + ) + ) + elif auth_method in ["ssh_key", "ssh-key"]: try: from .ssh_utils import parse_ssh_key, validate_ssh_key - + # First validate the key format validation_result = validate_ssh_key(credential) if not validation_result.is_valid: - errors.append(ScanErrorInternal( - error_code="AUTH_003", + errors.append( + ScanErrorInternal( + error_code="AUTH_003", + category=ErrorCategory.AUTHENTICATION, + severity=ErrorSeverity.ERROR, + message=f"Invalid SSH key format: {validation_result.error_message}", + technical_details={ + "validation_error": validation_result.error_message + }, + user_guidance="Ensure SSH private key is in correct format (RSA, DSA, ECDSA, or Ed25519). Check key file integrity.", + automated_fixes=[ + AutomatedFix( + fix_id="regenerate_key", + description="⚠️ SECURITY: Use secure automated fix system to generate SSH key", + command=None, # No direct command - use secure system + requires_sudo=False, + estimated_time=30, + is_safe=False, # Mark as unsafe for direct execution + ) + ], + documentation_url="https://docs.openwatch.dev/troubleshooting/auth#invalid-key-format", + ) + ) + return errors + + # Parse and test the key + key = parse_ssh_key(credential) + ssh.connect(hostname, port=port, username=username, pkey=key, timeout=10) + ssh.close() + return errors # Success + + except paramiko.AuthenticationException: + errors.append( + ScanErrorInternal( + error_code="AUTH_004", category=ErrorCategory.AUTHENTICATION, severity=ErrorSeverity.ERROR, - message=f"Invalid SSH key format: {validation_result.error_message}", - technical_details={"validation_error": validation_result.error_message}, - user_guidance="Ensure SSH private key is in correct format (RSA, DSA, ECDSA, or Ed25519). Check key file integrity.", + message="SSH key authentication failed", + technical_details={"username": username, "auth_method": auth_method}, + user_guidance="SSH public key not authorized for this user. Add public key to ~/.ssh/authorized_keys on target host.", automated_fixes=[ AutomatedFix( - fix_id="regenerate_key", - description="⚠️ SECURITY: Use secure automated fix system to generate SSH key", + fix_id="copy_public_key", + description="⚠️ SECURITY: Use secure automated fix system to copy SSH key", command=None, # No direct command - use secure system requires_sudo=False, - estimated_time=30, - is_safe=False # Mark as unsafe for direct execution + estimated_time=15, + is_safe=False, # Mark as unsafe for direct execution ) ], - documentation_url="https://docs.openwatch.dev/troubleshooting/auth#invalid-key-format" - )) - return errors - - # Parse and test the key - key = parse_ssh_key(credential) - ssh.connect(hostname, port=port, username=username, pkey=key, timeout=10) - ssh.close() - return errors # Success - - except paramiko.AuthenticationException: - errors.append(ScanErrorInternal( - error_code="AUTH_004", - category=ErrorCategory.AUTHENTICATION, - severity=ErrorSeverity.ERROR, - message="SSH key authentication failed", - technical_details={"username": username, "auth_method": auth_method}, - user_guidance="SSH public key not authorized for this user. Add public key to ~/.ssh/authorized_keys on target host.", - automated_fixes=[ - AutomatedFix( - fix_id="copy_public_key", - description="⚠️ SECURITY: Use secure automated fix system to copy SSH key", - command=None, # No direct command - use secure system - requires_sudo=False, - estimated_time=15, - is_safe=False # Mark as unsafe for direct execution - ) - ], - documentation_url="https://docs.openwatch.dev/troubleshooting/auth#key-not-authorized" - )) - + documentation_url="https://docs.openwatch.dev/troubleshooting/auth#key-not-authorized", + ) + ) + except Exception as e: - errors.append(ScanErrorInternal( - error_code="AUTH_999", - category=ErrorCategory.AUTHENTICATION, - severity=ErrorSeverity.ERROR, - message="Unexpected authentication error", - technical_details={"error": str(e)}, - user_guidance="An unexpected error occurred during authentication validation. Check network connectivity and try again.", - can_retry=True - )) - + errors.append( + ScanErrorInternal( + error_code="AUTH_999", + category=ErrorCategory.AUTHENTICATION, + severity=ErrorSeverity.ERROR, + message="Unexpected authentication error", + technical_details={"error": str(e)}, + user_guidance="An unexpected error occurred during authentication validation. Check network connectivity and try again.", + can_retry=True, + ) + ) + return errors class PrivilegeValidator: """System privilege validation""" - + @staticmethod async def validate_privileges(ssh_client: paramiko.SSHClient) -> List[ScanErrorInternal]: """Check if user has required privileges for scanning""" errors = [] - + try: # Check sudo access for oscap - stdin, stdout, stderr = ssh_client.exec_command('sudo -n oscap --version', timeout=10) + stdin, stdout, stderr = ssh_client.exec_command("sudo -n oscap --version", timeout=10) exit_status = stdout.channel.recv_exit_status() stderr_output = stderr.read().decode() - + if exit_status != 0: if "password is required" in stderr_output.lower(): - errors.append(ScanErrorInternal( - error_code="PRIV_001", - category=ErrorCategory.PRIVILEGE, - severity=ErrorSeverity.ERROR, - message="User lacks passwordless sudo access for OpenSCAP", - technical_details={"command": "sudo -n oscap --version", "stderr": stderr_output}, - user_guidance="Configure passwordless sudo for oscap command", - automated_fixes=[ - AutomatedFix( - fix_id="add_sudoers_oscap", - description="⚠️ SECURITY: Use secure automated fix system to configure sudo access", - command=None, # No direct command - use secure system - requires_sudo=True, - estimated_time=30, - is_safe=False # Mark as unsafe for direct execution - ) - ], - documentation_url="https://docs.openwatch.dev/troubleshooting/privileges#sudo-access" - )) - + errors.append( + ScanErrorInternal( + error_code="PRIV_001", + category=ErrorCategory.PRIVILEGE, + severity=ErrorSeverity.ERROR, + message="User lacks passwordless sudo access for OpenSCAP", + technical_details={ + "command": "sudo -n oscap --version", + "stderr": stderr_output, + }, + user_guidance="Configure passwordless sudo for oscap command", + automated_fixes=[ + AutomatedFix( + fix_id="add_sudoers_oscap", + description="⚠️ SECURITY: Use secure automated fix system to configure sudo access", + command=None, # No direct command - use secure system + requires_sudo=True, + estimated_time=30, + is_safe=False, # Mark as unsafe for direct execution + ) + ], + documentation_url="https://docs.openwatch.dev/troubleshooting/privileges#sudo-access", + ) + ) + # Check SELinux enforcement (if applicable) - stdin, stdout, stderr = ssh_client.exec_command('getenforce 2>/dev/null', timeout=5) + stdin, stdout, stderr = ssh_client.exec_command("getenforce 2>/dev/null", timeout=5) selinux_status = stdout.read().decode().strip().lower() - + if selinux_status == "enforcing": # Check OpenSCAP SELinux policies - stdin, stdout, stderr = ssh_client.exec_command('getsebool openscap_can_network 2>/dev/null', timeout=5) + stdin, stdout, stderr = ssh_client.exec_command( + "getsebool openscap_can_network 2>/dev/null", timeout=5 + ) bool_output = stdout.read().decode().strip() - + if "off" in bool_output: - errors.append(ScanErrorInternal( - error_code="PRIV_002", - category=ErrorCategory.PRIVILEGE, - severity=ErrorSeverity.WARNING, - message="SELinux blocking OpenSCAP network operations", - technical_details={"selinux_status": "enforcing", "openscap_can_network": "off"}, - user_guidance="Enable SELinux boolean to allow OpenSCAP network operations", - automated_fixes=[ - AutomatedFix( - fix_id="enable_selinux_openscap", - description="⚠️ SECURITY: Use secure automated fix system to configure SELinux", - command=None, # No direct command - use secure system - requires_sudo=True, - estimated_time=15, - is_safe=False # Mark as unsafe for direct execution - ) - ], - documentation_url="https://docs.openwatch.dev/troubleshooting/privileges#selinux" - )) - + errors.append( + ScanErrorInternal( + error_code="PRIV_002", + category=ErrorCategory.PRIVILEGE, + severity=ErrorSeverity.WARNING, + message="SELinux blocking OpenSCAP network operations", + technical_details={ + "selinux_status": "enforcing", + "openscap_can_network": "off", + }, + user_guidance="Enable SELinux boolean to allow OpenSCAP network operations", + automated_fixes=[ + AutomatedFix( + fix_id="enable_selinux_openscap", + description="⚠️ SECURITY: Use secure automated fix system to configure SELinux", + command=None, # No direct command - use secure system + requires_sudo=True, + estimated_time=15, + is_safe=False, # Mark as unsafe for direct execution + ) + ], + documentation_url="https://docs.openwatch.dev/troubleshooting/privileges#selinux", + ) + ) + except Exception as e: logger.warning(f"Privilege validation error (non-critical): {e}") # Don't add critical errors for privilege checks - they're warnings - + return errors class ResourceValidator: """System resource validation""" - + MIN_DISK_SPACE_MB = 500 MIN_MEMORY_MB = 512 - + @classmethod async def validate_resources(cls, ssh_client: paramiko.SSHClient) -> List[ScanErrorInternal]: """Check system resource availability""" errors = [] - + try: # Check disk space in /tmp - stdin, stdout, stderr = ssh_client.exec_command("df -BM /tmp | tail -1 | awk '{print $4}'", timeout=10) + stdin, stdout, stderr = ssh_client.exec_command( + "df -BM /tmp | tail -1 | awk '{print $4}'", timeout=10 + ) available_output = stdout.read().decode().strip() - + if available_output: try: - available_mb = int(available_output.rstrip('M')) + available_mb = int(available_output.rstrip("M")) if available_mb < cls.MIN_DISK_SPACE_MB: - errors.append(ScanErrorInternal( - error_code="RES_001", - category=ErrorCategory.RESOURCE, - severity=ErrorSeverity.ERROR, - message=f"Insufficient disk space: {available_mb}MB available in /tmp", - technical_details={"available_space_mb": available_mb, "required_space_mb": cls.MIN_DISK_SPACE_MB}, - user_guidance=f"Free up disk space in /tmp directory. Need at least {cls.MIN_DISK_SPACE_MB}MB for scan results.", - automated_fixes=[ - AutomatedFix( - fix_id="cleanup_tmp", - description="⚠️ SECURITY: Use secure automated fix system to clean up files", - command=None, # No direct command - use secure system - requires_sudo=True, - estimated_time=60, - is_safe=False # Mark as unsafe for direct execution - ) - ], - can_retry=True, - retry_after=300, - documentation_url="https://docs.openwatch.dev/troubleshooting/resources#disk-space" - )) + errors.append( + ScanErrorInternal( + error_code="RES_001", + category=ErrorCategory.RESOURCE, + severity=ErrorSeverity.ERROR, + message=f"Insufficient disk space: {available_mb}MB available in /tmp", + technical_details={ + "available_space_mb": available_mb, + "required_space_mb": cls.MIN_DISK_SPACE_MB, + }, + user_guidance=f"Free up disk space in /tmp directory. Need at least {cls.MIN_DISK_SPACE_MB}MB for scan results.", + automated_fixes=[ + AutomatedFix( + fix_id="cleanup_tmp", + description="⚠️ SECURITY: Use secure automated fix system to clean up files", + command=None, # No direct command - use secure system + requires_sudo=True, + estimated_time=60, + is_safe=False, # Mark as unsafe for direct execution + ) + ], + can_retry=True, + retry_after=300, + documentation_url="https://docs.openwatch.dev/troubleshooting/resources#disk-space", + ) + ) except ValueError: logger.warning(f"Could not parse disk space output: {available_output}") - + # Check memory availability - stdin, stdout, stderr = ssh_client.exec_command("free -m | grep '^Mem:' | awk '{print $7}'", timeout=10) + stdin, stdout, stderr = ssh_client.exec_command( + "free -m | grep '^Mem:' | awk '{print $7}'", timeout=10 + ) available_memory = stdout.read().decode().strip() - + if available_memory: try: available_mem_mb = int(available_memory) if available_mem_mb < cls.MIN_MEMORY_MB: - errors.append(ScanErrorInternal( - error_code="RES_002", - category=ErrorCategory.RESOURCE, - severity=ErrorSeverity.WARNING, - message=f"Low available memory: {available_mem_mb}MB", - technical_details={"available_memory_mb": available_mem_mb, "recommended_memory_mb": cls.MIN_MEMORY_MB}, - user_guidance="Available memory is low. Scan may run slower or fail. Consider stopping other processes.", - documentation_url="https://docs.openwatch.dev/troubleshooting/resources#memory" - )) + errors.append( + ScanErrorInternal( + error_code="RES_002", + category=ErrorCategory.RESOURCE, + severity=ErrorSeverity.WARNING, + message=f"Low available memory: {available_mem_mb}MB", + technical_details={ + "available_memory_mb": available_mem_mb, + "recommended_memory_mb": cls.MIN_MEMORY_MB, + }, + user_guidance="Available memory is low. Scan may run slower or fail. Consider stopping other processes.", + documentation_url="https://docs.openwatch.dev/troubleshooting/resources#memory", + ) + ) except ValueError: logger.warning(f"Could not parse memory output: {available_memory}") - + except Exception as e: logger.warning(f"Resource validation error (non-critical): {e}") - + return errors class DependencyValidator: """System dependency validation""" - + MIN_OPENSCAP_VERSION = "1.3.0" - + @classmethod async def validate_dependencies(cls, ssh_client: paramiko.SSHClient) -> List[ScanErrorInternal]: """Validate OpenSCAP installation and dependencies""" errors = [] - + try: # Check if OpenSCAP is installed - stdin, stdout, stderr = ssh_client.exec_command('which oscap', timeout=10) + stdin, stdout, stderr = ssh_client.exec_command("which oscap", timeout=10) oscap_path = stdout.read().decode().strip() - + if not oscap_path: - errors.append(ScanErrorInternal( - error_code="DEP_001", - category=ErrorCategory.DEPENDENCY, - severity=ErrorSeverity.ERROR, - message="OpenSCAP not installed on target system", - technical_details={"missing_command": "oscap"}, - user_guidance="Install OpenSCAP scanner package on the target system", - automated_fixes=[ - AutomatedFix( - fix_id="install_openscap_rhel", - description="⚠️ SECURITY: Use secure automated fix system to install OpenSCAP on RHEL/CentOS", - command=None, # No direct command - use secure system - requires_sudo=True, - estimated_time=120, - is_safe=False # Mark as unsafe for direct execution - ), - AutomatedFix( - fix_id="install_openscap_ubuntu", - description="⚠️ SECURITY: Use secure automated fix system to install OpenSCAP on Ubuntu/Debian", - command=None, # No direct command - use secure system - requires_sudo=True, - estimated_time=120, - is_safe=False # Mark as unsafe for direct execution - ) - ], - documentation_url="https://docs.openwatch.dev/troubleshooting/dependencies#openscap-installation" - )) + errors.append( + ScanErrorInternal( + error_code="DEP_001", + category=ErrorCategory.DEPENDENCY, + severity=ErrorSeverity.ERROR, + message="OpenSCAP not installed on target system", + technical_details={"missing_command": "oscap"}, + user_guidance="Install OpenSCAP scanner package on the target system", + automated_fixes=[ + AutomatedFix( + fix_id="install_openscap_rhel", + description="⚠️ SECURITY: Use secure automated fix system to install OpenSCAP on RHEL/CentOS", + command=None, # No direct command - use secure system + requires_sudo=True, + estimated_time=120, + is_safe=False, # Mark as unsafe for direct execution + ), + AutomatedFix( + fix_id="install_openscap_ubuntu", + description="⚠️ SECURITY: Use secure automated fix system to install OpenSCAP on Ubuntu/Debian", + command=None, # No direct command - use secure system + requires_sudo=True, + estimated_time=120, + is_safe=False, # Mark as unsafe for direct execution + ), + ], + documentation_url="https://docs.openwatch.dev/troubleshooting/dependencies#openscap-installation", + ) + ) return errors - + # Check OpenSCAP version - stdin, stdout, stderr = ssh_client.exec_command('oscap --version', timeout=10) + stdin, stdout, stderr = ssh_client.exec_command("oscap --version", timeout=10) version_output = stdout.read().decode() - + version = cls._parse_openscap_version(version_output) if version and cls._version_compare(version, cls.MIN_OPENSCAP_VERSION) < 0: - errors.append(ScanErrorInternal( - error_code="DEP_002", - category=ErrorCategory.DEPENDENCY, - severity=ErrorSeverity.WARNING, - message=f"OpenSCAP version {version} installed, recommended >= {cls.MIN_OPENSCAP_VERSION}", - technical_details={"current_version": version, "minimum_version": cls.MIN_OPENSCAP_VERSION}, - user_guidance="Update OpenSCAP to latest version for best compatibility", - automated_fixes=[ - AutomatedFix( - fix_id="update_openscap", - description="⚠️ SECURITY: Use secure automated fix system to update OpenSCAP", - command=None, # No direct command - use secure system - requires_sudo=True, - estimated_time=60, - is_safe=False # Mark as unsafe for direct execution - ) - ], - documentation_url="https://docs.openwatch.dev/troubleshooting/dependencies#version-upgrade" - )) - + errors.append( + ScanErrorInternal( + error_code="DEP_002", + category=ErrorCategory.DEPENDENCY, + severity=ErrorSeverity.WARNING, + message=f"OpenSCAP version {version} installed, recommended >= {cls.MIN_OPENSCAP_VERSION}", + technical_details={ + "current_version": version, + "minimum_version": cls.MIN_OPENSCAP_VERSION, + }, + user_guidance="Update OpenSCAP to latest version for best compatibility", + automated_fixes=[ + AutomatedFix( + fix_id="update_openscap", + description="⚠️ SECURITY: Use secure automated fix system to update OpenSCAP", + command=None, # No direct command - use secure system + requires_sudo=True, + estimated_time=60, + is_safe=False, # Mark as unsafe for direct execution + ) + ], + documentation_url="https://docs.openwatch.dev/troubleshooting/dependencies#version-upgrade", + ) + ) + except Exception as e: - errors.append(ScanErrorInternal( - error_code="DEP_999", - category=ErrorCategory.DEPENDENCY, - severity=ErrorSeverity.ERROR, - message="Failed to validate system dependencies", - technical_details={"error": str(e)}, - user_guidance="Could not check system dependencies. Ensure SSH access is working properly.", - can_retry=True - )) - + errors.append( + ScanErrorInternal( + error_code="DEP_999", + category=ErrorCategory.DEPENDENCY, + severity=ErrorSeverity.ERROR, + message="Failed to validate system dependencies", + technical_details={"error": str(e)}, + user_guidance="Could not check system dependencies. Ensure SSH access is working properly.", + can_retry=True, + ) + ) + return errors - + @staticmethod def _parse_openscap_version(version_output: str) -> Optional[str]: """Extract version number from oscap --version output""" import re - match = re.search(r'(\d+\.\d+\.\d+)', version_output) + + match = re.search(r"(\d+\.\d+\.\d+)", version_output) return match.group(1) if match else None - + @staticmethod def _version_compare(version1: str, version2: str) -> int: """Compare two version strings. Returns -1, 0, or 1""" + def version_tuple(v): - return tuple(map(int, v.split('.'))) - + return tuple(map(int, v.split("."))) + v1_tuple = version_tuple(version1) v2_tuple = version_tuple(version2) - + return (v1_tuple > v2_tuple) - (v1_tuple < v2_tuple) # Stub classes needed for credential validation class SecurityContext(BaseModel): """Security context for error classification""" + hostname: str = "" username: str = "" auth_method: str = "" @@ -571,27 +651,31 @@ def classify_authentication_error(context: SecurityContext) -> ScanErrorInternal severity=ErrorSeverity.ERROR, message="Authentication error occurred", technical_details={"context": context.dict()}, - user_guidance="Please check your authentication credentials and try again." + user_guidance="Please check your authentication credentials and try again.", ) class ErrorClassificationService: """Main error classification service""" - + def __init__(self): self.network_validator = NetworkValidator() self.auth_validator = AuthenticationValidator() self.privilege_validator = PrivilegeValidator() self.resource_validator = ResourceValidator() self.dependency_validator = DependencyValidator() - - async def classify_error(self, error: Exception, context: Dict[str, Any] = None) -> ScanErrorInternal: + + async def classify_error( + self, error: Exception, context: Dict[str, Any] = None + ) -> ScanErrorInternal: """Classify and enhance a generic error with actionable guidance""" context = context or {} error_str = str(error).lower() - + # Network errors - if any(keyword in error_str for keyword in ['connection refused', 'timeout', 'unreachable']): + if any( + keyword in error_str for keyword in ["connection refused", "timeout", "unreachable"] + ): return ScanErrorInternal( error_code="NET_006", category=ErrorCategory.NETWORK, @@ -600,22 +684,25 @@ async def classify_error(self, error: Exception, context: Dict[str, Any] = None) technical_details={"original_error": str(error), "context": context}, user_guidance="Check network connectivity and ensure target host is reachable", can_retry=True, - retry_after=60 + retry_after=60, ) - + # Authentication errors - if any(keyword in error_str for keyword in ['permission denied', 'authentication failed', 'invalid credentials']): + if any( + keyword in error_str + for keyword in ["permission denied", "authentication failed", "invalid credentials"] + ): return ScanErrorInternal( error_code="AUTH_005", category=ErrorCategory.AUTHENTICATION, severity=ErrorSeverity.ERROR, message=f"Authentication failed: {str(error)}", technical_details={"original_error": str(error), "context": context}, - user_guidance="Verify username and credentials are correct and have proper access" + user_guidance="Verify username and credentials are correct and have proper access", ) - + # Resource errors - if any(keyword in error_str for keyword in ['no space', 'disk full', 'out of memory']): + if any(keyword in error_str for keyword in ["no space", "disk full", "out of memory"]): return ScanErrorInternal( error_code="RES_003", category=ErrorCategory.RESOURCE, @@ -624,9 +711,9 @@ async def classify_error(self, error: Exception, context: Dict[str, Any] = None) technical_details={"original_error": str(error), "context": context}, user_guidance="Free up system resources (disk space, memory) and try again", can_retry=True, - retry_after=300 + retry_after=300, ) - + # Default to execution error return ScanErrorInternal( error_code="EXEC_001", @@ -635,29 +722,41 @@ async def classify_error(self, error: Exception, context: Dict[str, Any] = None) message=f"Scan execution failed: {str(error)}", technical_details={"original_error": str(error), "context": context}, user_guidance="An unexpected error occurred during scan execution. Check logs for more details.", - can_retry=True + can_retry=True, ) - - async def validate_scan_prerequisites(self, hostname: str, port: int, username: str, - auth_method: str, credential: str, - user_id: Optional[str] = None, - source_ip: Optional[str] = None) -> ValidationResultInternal: + + async def validate_scan_prerequisites( + self, + hostname: str, + port: int, + username: str, + auth_method: str, + credential: str, + user_id: Optional[str] = None, + source_ip: Optional[str] = None, + ) -> ValidationResultInternal: """Comprehensive pre-flight validation""" start_time = datetime.utcnow() errors = [] warnings = [] system_info = {} validation_checks = {} - + logger.info(f"Starting pre-flight validation for {username}@{hostname}:{port}") - + # Stage 1: Network Connectivity try: network_errors = await self.network_validator.validate_connectivity(hostname, port) validation_checks["network_connectivity"] = len(network_errors) == 0 - errors.extend([e for e in network_errors if e.severity in [ErrorSeverity.ERROR, ErrorSeverity.CRITICAL]]) + errors.extend( + [ + e + for e in network_errors + if e.severity in [ErrorSeverity.ERROR, ErrorSeverity.CRITICAL] + ] + ) warnings.extend([e for e in network_errors if e.severity == ErrorSeverity.WARNING]) - + if errors: # Can't proceed if network fails duration = (datetime.utcnow() - start_time).total_seconds() return ValidationResultInternal( @@ -665,20 +764,28 @@ async def validate_scan_prerequisites(self, hostname: str, port: int, username: errors=errors, warnings=warnings, pre_flight_duration=duration, - validation_checks=validation_checks + validation_checks=validation_checks, ) except Exception as e: logger.error(f"Network validation failed: {e}") validation_checks["network_connectivity"] = False errors.append(await self.classify_error(e, {"stage": "network_validation"})) - - # Stage 2: Authentication + + # Stage 2: Authentication try: - auth_errors = await self.auth_validator.validate_credentials(hostname, port, username, auth_method, credential) + auth_errors = await self.auth_validator.validate_credentials( + hostname, port, username, auth_method, credential + ) validation_checks["authentication"] = len(auth_errors) == 0 - errors.extend([e for e in auth_errors if e.severity in [ErrorSeverity.ERROR, ErrorSeverity.CRITICAL]]) + errors.extend( + [ + e + for e in auth_errors + if e.severity in [ErrorSeverity.ERROR, ErrorSeverity.CRITICAL] + ] + ) warnings.extend([e for e in auth_errors if e.severity == ErrorSeverity.WARNING]) - + if errors: # Can't proceed if auth fails duration = (datetime.utcnow() - start_time).total_seconds() return ValidationResultInternal( @@ -686,14 +793,14 @@ async def validate_scan_prerequisites(self, hostname: str, port: int, username: errors=errors, warnings=warnings, pre_flight_duration=duration, - validation_checks=validation_checks + validation_checks=validation_checks, ) - + except Exception as e: logger.error(f"Authentication validation failed: {e}") validation_checks["authentication"] = False errors.append(await self.classify_error(e, {"stage": "authentication_validation"})) - + # Stage 3: Advanced validations (if we can connect) ssh_client = None try: @@ -703,63 +810,99 @@ async def validate_scan_prerequisites(self, hostname: str, port: int, username: # Load system and user host keys for validation try: ssh_client.load_system_host_keys() - ssh_client.load_host_keys(os.path.expanduser('~/.ssh/known_hosts')) + ssh_client.load_host_keys(os.path.expanduser("~/.ssh/known_hosts")) except FileNotFoundError: - logger.warning("No known_hosts files found - SSH connections may fail without proper host key management") - + logger.warning( + "No known_hosts files found - SSH connections may fail without proper host key management" + ) + if auth_method == "password": - ssh_client.connect(hostname, port=port, username=username, password=credential, timeout=10) + ssh_client.connect( + hostname, port=port, username=username, password=credential, timeout=10 + ) else: from .ssh_utils import parse_ssh_key + key = parse_ssh_key(credential) ssh_client.connect(hostname, port=port, username=username, pkey=key, timeout=10) - + # Get system information (will be sanitized later) - stdin, stdout, stderr = ssh_client.exec_command('uname -a && cat /etc/os-release 2>/dev/null || echo "OS info not available"', timeout=10) + stdin, stdout, stderr = ssh_client.exec_command( + 'uname -a && cat /etc/os-release 2>/dev/null || echo "OS info not available"', + timeout=10, + ) system_info_output = stdout.read().decode() - + # Store raw system info for sanitization system_info["system_details"] = system_info_output.strip() system_info["collection_timestamp"] = datetime.utcnow().isoformat() - + # Add additional system information safely # Check OpenSCAP availability for compliance - stdin, stdout, stderr = ssh_client.exec_command('which oscap', timeout=5) + stdin, stdout, stderr = ssh_client.exec_command("which oscap", timeout=5) oscap_path = stdout.read().decode().strip() system_info["openscap_available"] = bool(oscap_path) - + # Check SSH availability (we're already connected, so it's available) system_info["ssh_available"] = True - + # Check basic resource info (will be sanitized) - stdin, stdout, stderr = ssh_client.exec_command('df /tmp | tail -1 | awk \'{print $4}\'', timeout=5) + stdin, stdout, stderr = ssh_client.exec_command( + "df /tmp | tail -1 | awk '{print $4}'", timeout=5 + ) disk_output = stdout.read().decode().strip() - if disk_output and disk_output.rstrip('M').isdigit(): - system_info["disk_space"] = int(disk_output.rstrip('M')) - - stdin, stdout, stderr = ssh_client.exec_command('free -m | grep "^Mem:" | awk \'{print $7}\'', timeout=5) + if disk_output and disk_output.rstrip("M").isdigit(): + system_info["disk_space"] = int(disk_output.rstrip("M")) + + stdin, stdout, stderr = ssh_client.exec_command( + "free -m | grep \"^Mem:\" | awk '{print $7}'", timeout=5 + ) memory_output = stdout.read().decode().strip() if memory_output and memory_output.isdigit(): system_info["memory"] = int(memory_output) - + # Privilege validation privilege_errors = await self.privilege_validator.validate_privileges(ssh_client) - validation_checks["privileges"] = len([e for e in privilege_errors if e.severity == ErrorSeverity.ERROR]) == 0 - errors.extend([e for e in privilege_errors if e.severity in [ErrorSeverity.ERROR, ErrorSeverity.CRITICAL]]) + validation_checks["privileges"] = ( + len([e for e in privilege_errors if e.severity == ErrorSeverity.ERROR]) == 0 + ) + errors.extend( + [ + e + for e in privilege_errors + if e.severity in [ErrorSeverity.ERROR, ErrorSeverity.CRITICAL] + ] + ) warnings.extend([e for e in privilege_errors if e.severity == ErrorSeverity.WARNING]) - + # Resource validation resource_errors = await self.resource_validator.validate_resources(ssh_client) - validation_checks["resources"] = len([e for e in resource_errors if e.severity == ErrorSeverity.ERROR]) == 0 - errors.extend([e for e in resource_errors if e.severity in [ErrorSeverity.ERROR, ErrorSeverity.CRITICAL]]) + validation_checks["resources"] = ( + len([e for e in resource_errors if e.severity == ErrorSeverity.ERROR]) == 0 + ) + errors.extend( + [ + e + for e in resource_errors + if e.severity in [ErrorSeverity.ERROR, ErrorSeverity.CRITICAL] + ] + ) warnings.extend([e for e in resource_errors if e.severity == ErrorSeverity.WARNING]) - + # Dependency validation dependency_errors = await self.dependency_validator.validate_dependencies(ssh_client) - validation_checks["dependencies"] = len([e for e in dependency_errors if e.severity == ErrorSeverity.ERROR]) == 0 - errors.extend([e for e in dependency_errors if e.severity in [ErrorSeverity.ERROR, ErrorSeverity.CRITICAL]]) + validation_checks["dependencies"] = ( + len([e for e in dependency_errors if e.severity == ErrorSeverity.ERROR]) == 0 + ) + errors.extend( + [ + e + for e in dependency_errors + if e.severity in [ErrorSeverity.ERROR, ErrorSeverity.CRITICAL] + ] + ) warnings.extend([e for e in dependency_errors if e.severity == ErrorSeverity.WARNING]) - + except Exception as e: logger.error(f"Advanced validation failed: {e}") # Don't fail completely - basic connectivity/auth worked @@ -767,12 +910,14 @@ async def validate_scan_prerequisites(self, hostname: str, port: int, username: finally: if ssh_client: ssh_client.close() - + duration = (datetime.utcnow() - start_time).total_seconds() can_proceed = len(errors) == 0 - - logger.info(f"Pre-flight validation completed in {duration:.2f}s: can_proceed={can_proceed}, errors={len(errors)}, warnings={len(warnings)}") - + + logger.info( + f"Pre-flight validation completed in {duration:.2f}s: can_proceed={can_proceed}, errors={len(errors)}, warnings={len(warnings)}" + ) + # Log the internal validation result for audit (contains sensitive data) if errors or warnings: for error in errors + warnings: @@ -780,14 +925,14 @@ async def validate_scan_prerequisites(self, hostname: str, port: int, username: error_code=error.error_code, technical_details=error.technical_details, sanitized_response={ - 'error_code': error.error_code, - 'category': error.category.value, - 'severity': error.severity.value, - 'can_retry': error.can_retry + "error_code": error.error_code, + "category": error.category.value, + "severity": error.severity.value, + "can_retry": error.can_retry, }, user_id=user_id, source_ip=source_ip, - severity=error.severity + severity=error.severity, ) return ValidationResultInternal( @@ -796,29 +941,27 @@ async def validate_scan_prerequisites(self, hostname: str, port: int, username: warnings=warnings, pre_flight_duration=duration, system_info=system_info, - validation_checks=validation_checks + validation_checks=validation_checks, ) - + def get_sanitized_validation_result( - self, + self, internal_result: ValidationResultInternal, user_id: Optional[str] = None, source_ip: Optional[str] = None, user_role: Optional[str] = None, - is_admin: bool = False + is_admin: bool = False, ) -> ValidationResultResponse: """ Convert internal validation result to sanitized user response. This integrates with Security Fix 5 system information sanitization. """ - + # Sanitize errors using existing error sanitization sanitized_errors = [] for error in internal_result.errors: sanitized_error = sanitization_service.sanitize_error( - error.dict(), - user_id=user_id, - source_ip=source_ip + error.dict(), user_id=user_id, source_ip=source_ip ) # Convert SanitizedError to ScanErrorResponse scan_error_response = ScanErrorResponse( @@ -830,17 +973,15 @@ def get_sanitized_validation_result( can_retry=sanitized_error.can_retry, retry_after=sanitized_error.retry_after, documentation_url=sanitized_error.documentation_url, - timestamp=sanitized_error.timestamp + timestamp=sanitized_error.timestamp, ) sanitized_errors.append(scan_error_response) - + # Sanitize warnings using existing error sanitization - sanitized_warnings = [] + sanitized_warnings = [] for warning in internal_result.warnings: sanitized_warning = sanitization_service.sanitize_error( - warning.dict(), - user_id=user_id, - source_ip=source_ip + warning.dict(), user_id=user_id, source_ip=source_ip ) # Convert SanitizedError to ScanErrorResponse scan_warning_response = ScanErrorResponse( @@ -852,10 +993,10 @@ def get_sanitized_validation_result( can_retry=sanitized_warning.can_retry, retry_after=sanitized_warning.retry_after, documentation_url=sanitized_warning.documentation_url, - timestamp=sanitized_warning.timestamp + timestamp=sanitized_warning.timestamp, ) sanitized_warnings.append(scan_warning_response) - + # Sanitize system information using Security Fix 5 integration sanitized_system_info = {} if internal_result.system_info: @@ -864,14 +1005,14 @@ def get_sanitized_validation_result( user_role=user_role, is_admin=is_admin, user_id=user_id, - source_ip=source_ip + source_ip=source_ip, ) - + return ValidationResultResponse( can_proceed=internal_result.can_proceed, errors=sanitized_errors, warnings=sanitized_warnings, pre_flight_duration=internal_result.pre_flight_duration, validation_checks=internal_result.validation_checks, - system_info=sanitized_system_info # Now includes sanitized system info - ) \ No newline at end of file + system_info=sanitized_system_info, # Now includes sanitized system info + ) diff --git a/backend/app/services/error_sanitization.py b/backend/app/services/error_sanitization.py index 67aff992..2e89c21c 100644 --- a/backend/app/services/error_sanitization.py +++ b/backend/app/services/error_sanitization.py @@ -2,6 +2,7 @@ OpenWatch Error Response Sanitization Service Removes sensitive information from error responses while maintaining actionable user guidance """ + import re import hashlib import logging @@ -16,25 +17,31 @@ # Will be imported later to avoid circular imports _system_info_sanitization_service = None + class SanitizationLevel(str, Enum): """Levels of error information sanitization""" - MINIMAL = "minimal" # Remove only critical PII - STANDARD = "standard" # Remove all sensitive technical details - STRICT = "strict" # Remove all technical information + + MINIMAL = "minimal" # Remove only critical PII + STANDARD = "standard" # Remove all sensitive technical details + STRICT = "strict" # Remove all technical information + class AuditLogEntry(BaseModel): """Audit log entry for security events""" + timestamp: datetime = Field(default_factory=datetime.utcnow) event_type: str error_code: str user_id: Optional[str] = None - source_ip: Optional[str] = None + source_ip: Optional[str] = None technical_details: Dict[str, Any] sanitized_response: Dict[str, Any] severity: str + class SanitizedError(BaseModel): """User-safe error response model""" + error_code: str category: str severity: str @@ -46,8 +53,10 @@ class SanitizedError(BaseModel): timestamp: datetime = Field(default_factory=datetime.utcnow) # No technical_details field - removed for security + class RateLimitState(BaseModel): """Rate limiting state for error endpoint access""" + ip_address: str error_count: int = 0 first_error_time: datetime = Field(default_factory=datetime.utcnow) @@ -55,177 +64,170 @@ class RateLimitState(BaseModel): is_blocked: bool = False block_until: Optional[datetime] = None + class ErrorSanitizationService: """Service to sanitize error responses and prevent information disclosure""" - + # Rate limiting configuration MAX_ERRORS_PER_HOUR = 50 MAX_ERRORS_PER_MINUTE = 10 BLOCK_DURATION_MINUTES = 60 - + # Sensitive information patterns to remove SENSITIVE_PATTERNS = [ # Usernames and hostnames r'\b(username|user|login)\s*[:=]\s*["\']?([^"\':\s]+)["\']?', r'\b(hostname|host|server)\s*[:=]\s*["\']?([^"\':\s]+)["\']?', - # SSH details r'publickey authentication failed for user\s+["\']?([^"\':\s]+)["\']?', r'SSH authentication failed:.*?for user\s+["\']?([^"\':\s]+)["\']?', - r'ssh_exchange_identification:\s+.*', - + r"ssh_exchange_identification:\s+.*", # System information - r'(Linux|Windows|Darwin)\s+[\w\-\.]+\s+[\d\.]+', - r'/[a-zA-Z0-9_\-/\.]+\.(sh|py|conf|cfg|ini|xml)', - r'[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}', - - # Configuration details - r'port\s+\d+', - r'timeout\s+\d+', - r'connection\s+result\s+\d+', - + r"(Linux|Windows|Darwin)\s+[\w\-\.]+\s+[\d\.]+", + r"/[a-zA-Z0-9_\-/\.]+\.(sh|py|conf|cfg|ini|xml)", + r"[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}", + # Configuration details + r"port\s+\d+", + r"timeout\s+\d+", + r"connection\s+result\s+\d+", # OS release information r'VERSION_ID\s*=\s*["\'][^"\']+["\']', r'PRETTY_NAME\s*=\s*["\'][^"\']+["\']', r'NAME\s*=\s*["\'][^"\']+["\']', - # Error details that leak information - r'stderr:\s*.*', - r'command:\s*.*', - r'banner:\s*.*' + r"stderr:\s*.*", + r"command:\s*.*", + r"banner:\s*.*", ] - + # Generic error messages mapped by error code patterns GENERIC_MESSAGES = { # Network errors - 'NET_001': "Unable to resolve the target host address", - 'NET_002': "Cannot establish connection to target host", - 'NET_003': "Connection to target host timed out", - 'NET_004': "Unexpected service detected on target port", - 'NET_005': "Remote service not responding properly", - 'NET_006': "Network connectivity issue detected", - + "NET_001": "Unable to resolve the target host address", + "NET_002": "Cannot establish connection to target host", + "NET_003": "Connection to target host timed out", + "NET_004": "Unexpected service detected on target port", + "NET_005": "Remote service not responding properly", + "NET_006": "Network connectivity issue detected", # Authentication errors - 'AUTH_001': "Account is temporarily locked", - 'AUTH_002': "Authentication credentials are invalid", - 'AUTH_003': "SSH key format is invalid", - 'AUTH_004': "SSH key authentication failed", - 'AUTH_005': "Authentication failed", - + "AUTH_001": "Account is temporarily locked", + "AUTH_002": "Authentication credentials are invalid", + "AUTH_003": "SSH key format is invalid", + "AUTH_004": "SSH key authentication failed", + "AUTH_005": "Authentication failed", # Privilege errors - 'PRIV_001': "Insufficient privileges for scan operations", - 'PRIV_002': "Security policy blocking scan operations", - + "PRIV_001": "Insufficient privileges for scan operations", + "PRIV_002": "Security policy blocking scan operations", # Resource errors - 'RES_001': "Insufficient disk space available", - 'RES_002': "Insufficient memory available", - 'RES_003': "System resource constraint detected", - + "RES_001": "Insufficient disk space available", + "RES_002": "Insufficient memory available", + "RES_003": "System resource constraint detected", # Dependency errors - 'DEP_001': "Required scanner software not found", - 'DEP_002': "Scanner software version incompatible", - 'DEP_999': "System dependency validation failed", - + "DEP_001": "Required scanner software not found", + "DEP_002": "Scanner software version incompatible", + "DEP_999": "System dependency validation failed", # Execution errors - 'EXEC_001': "Scan execution failed due to unexpected error" + "EXEC_001": "Scan execution failed due to unexpected error", } - + def __init__(self): self.rate_limit_cache: Dict[str, RateLimitState] = {} self._cleanup_rate_limit_cache() - + def sanitize_error( - self, - error_data: Dict[str, Any], + self, + error_data: Dict[str, Any], sanitization_level: SanitizationLevel = SanitizationLevel.STANDARD, user_id: Optional[str] = None, - source_ip: Optional[str] = None + source_ip: Optional[str] = None, ) -> SanitizedError: """ Sanitize error response by removing sensitive information - + Args: error_data: Original error data from ErrorClassificationService sanitization_level: Level of sanitization to apply user_id: User ID for audit logging source_ip: Source IP for rate limiting - + Returns: SanitizedError: Clean error response safe for users """ - + # Check rate limiting first if source_ip and self._is_rate_limited(source_ip): logger.warning(f"Rate limited error request from IP: {source_ip}") return self._create_rate_limit_error() - + # Log full technical details for audit self._log_security_event(error_data, user_id, source_ip) - + # Extract error code and get generic message - error_code = error_data.get('error_code', 'UNKNOWN') - generic_message = self.GENERIC_MESSAGES.get(error_code, "An error occurred during the operation") - + error_code = error_data.get("error_code", "UNKNOWN") + generic_message = self.GENERIC_MESSAGES.get( + error_code, "An error occurred during the operation" + ) + # Create sanitized error response sanitized = SanitizedError( error_code=error_code, - category=error_data.get('category', 'execution'), - severity=error_data.get('severity', 'error'), + category=error_data.get("category", "execution"), + severity=error_data.get("severity", "error"), message=generic_message, - user_guidance=self._sanitize_guidance(error_data.get('user_guidance', '')), - can_retry=error_data.get('can_retry', False), - retry_after=error_data.get('retry_after'), - documentation_url=error_data.get('documentation_url', '') + user_guidance=self._sanitize_guidance(error_data.get("user_guidance", "")), + can_retry=error_data.get("can_retry", False), + retry_after=error_data.get("retry_after"), + documentation_url=error_data.get("documentation_url", ""), ) - + # Update rate limiting if source_ip: self._update_rate_limit(source_ip) - + return sanitized - + def _sanitize_guidance(self, guidance: str) -> str: """Remove sensitive information from user guidance text""" sanitized = guidance - + for pattern in self.SENSITIVE_PATTERNS: - sanitized = re.sub(pattern, '[REDACTED]', sanitized, flags=re.IGNORECASE) - + sanitized = re.sub(pattern, "[REDACTED]", sanitized, flags=re.IGNORECASE) + # Remove specific technical commands - sanitized = re.sub(r'`[^`]+`', '[COMMAND_REDACTED]', sanitized) - + sanitized = re.sub(r"`[^`]+`", "[COMMAND_REDACTED]", sanitized) + # Remove file paths - sanitized = re.sub(r'/[/\w\-\.]+', '[PATH_REDACTED]', sanitized) - + sanitized = re.sub(r"/[/\w\-\.]+", "[PATH_REDACTED]", sanitized) + # Keep guidance actionable but generic sanitized = self._make_guidance_generic(sanitized) - + return sanitized - + def _make_guidance_generic(self, guidance: str) -> str: """Convert specific guidance to generic actionable advice""" generic_replacements = { - 'Check if SSH service is running on port [REDACTED]': 'Verify SSH service is running on the correct port', - 'Verify the hostname [REDACTED] is correct': 'Verify the target hostname is correct', - 'Add public key to [PATH_REDACTED] on target host': 'Ensure SSH public key is authorized on target host', - 'Configure passwordless sudo for [COMMAND_REDACTED]': 'Configure appropriate system privileges for scanning', - 'Free up disk space in [PATH_REDACTED]': 'Free up sufficient disk space on target system', - 'Check [COMMAND_REDACTED] service status': 'Check required service status on target system' + "Check if SSH service is running on port [REDACTED]": "Verify SSH service is running on the correct port", + "Verify the hostname [REDACTED] is correct": "Verify the target hostname is correct", + "Add public key to [PATH_REDACTED] on target host": "Ensure SSH public key is authorized on target host", + "Configure passwordless sudo for [COMMAND_REDACTED]": "Configure appropriate system privileges for scanning", + "Free up disk space in [PATH_REDACTED]": "Free up sufficient disk space on target system", + "Check [COMMAND_REDACTED] service status": "Check required service status on target system", } - + result = guidance for specific, generic in generic_replacements.items(): result = result.replace(specific, generic) - + return result - + def _is_rate_limited(self, source_ip: str) -> bool: """Check if source IP is rate limited""" if source_ip not in self.rate_limit_cache: return False - + state = self.rate_limit_cache[source_ip] - + # Check if currently blocked if state.is_blocked and state.block_until: if datetime.utcnow() < state.block_until: @@ -235,37 +237,37 @@ def _is_rate_limited(self, source_ip: str) -> bool: state.is_blocked = False state.block_until = None state.error_count = 0 - + return False - + def _update_rate_limit(self, source_ip: str): """Update rate limiting state for source IP""" now = datetime.utcnow() - + if source_ip not in self.rate_limit_cache: self.rate_limit_cache[source_ip] = RateLimitState(ip_address=source_ip) - + state = self.rate_limit_cache[source_ip] state.error_count += 1 state.last_error_time = now - + # Check rate limits time_diff_minutes = (now - state.first_error_time).total_seconds() / 60 time_diff_seconds = (now - state.first_error_time).total_seconds() - + # Reset counter if more than 1 hour has passed if time_diff_minutes > 60: state.error_count = 1 state.first_error_time = now - + # Check per-minute limit elif time_diff_seconds < 60 and state.error_count > self.MAX_ERRORS_PER_MINUTE: self._block_ip(source_ip) - - # Check per-hour limit + + # Check per-hour limit elif time_diff_minutes <= 60 and state.error_count > self.MAX_ERRORS_PER_HOUR: self._block_ip(source_ip) - + def _block_ip(self, source_ip: str): """Block IP address due to rate limit violation""" if source_ip in self.rate_limit_cache: @@ -274,86 +276,81 @@ def _block_ip(self, source_ip: str): state.block_until = datetime.utcnow().replace( minute=datetime.utcnow().minute + self.BLOCK_DURATION_MINUTES ) - - logger.warning(f"IP {source_ip} blocked for {self.BLOCK_DURATION_MINUTES} minutes due to rate limiting") - + + logger.warning( + f"IP {source_ip} blocked for {self.BLOCK_DURATION_MINUTES} minutes due to rate limiting" + ) + def _create_rate_limit_error(self) -> SanitizedError: """Create error response for rate-limited requests""" return SanitizedError( error_code="RATE_LIMIT", category="security", - severity="error", + severity="error", message="Too many error requests detected", user_guidance="Please wait before retrying. Contact support if this continues.", can_retry=True, retry_after=self.BLOCK_DURATION_MINUTES * 60, # Convert to seconds - documentation_url="https://docs.openwatch.dev/security/rate-limits" + documentation_url="https://docs.openwatch.dev/security/rate-limits", ) - + def _log_security_event( - self, - error_data: Dict[str, Any], - user_id: Optional[str], - source_ip: Optional[str] + self, error_data: Dict[str, Any], user_id: Optional[str], source_ip: Optional[str] ): """Log full error details for security audit""" - + # Create audit log entry audit_entry = AuditLogEntry( event_type="error_classification", - error_code=error_data.get('error_code', 'UNKNOWN'), + error_code=error_data.get("error_code", "UNKNOWN"), user_id=user_id, source_ip=source_ip, - technical_details=error_data.get('technical_details', {}), + technical_details=error_data.get("technical_details", {}), sanitized_response={ - 'error_code': error_data.get('error_code', 'UNKNOWN'), - 'category': error_data.get('category', 'execution'), - 'severity': error_data.get('severity', 'error'), - 'message_pattern': self.GENERIC_MESSAGES.get( - error_data.get('error_code', 'UNKNOWN'), - 'Generic error message' - ) + "error_code": error_data.get("error_code", "UNKNOWN"), + "category": error_data.get("category", "execution"), + "severity": error_data.get("severity", "error"), + "message_pattern": self.GENERIC_MESSAGES.get( + error_data.get("error_code", "UNKNOWN"), "Generic error message" + ), }, - severity=error_data.get('severity', 'error') + severity=error_data.get("severity", "error"), ) - + # Log to security audit log - security_logger = logging.getLogger('security_audit') + security_logger = logging.getLogger("security_audit") security_logger.info( f"Error Classification Event: {audit_entry.json()}", extra={ - 'event_type': 'error_classification', - 'error_code': audit_entry.error_code, - 'user_id': user_id, - 'source_ip': source_ip, - 'severity': audit_entry.severity - } + "event_type": "error_classification", + "error_code": audit_entry.error_code, + "user_id": user_id, + "source_ip": source_ip, + "severity": audit_entry.severity, + }, ) - + # Also log summary to main logger logger.info(f"Sanitized error response for {audit_entry.error_code} from IP {source_ip}") - + def _cleanup_rate_limit_cache(self): """Clean up expired rate limit entries""" now = datetime.utcnow() expired_ips = [] - + for ip, state in self.rate_limit_cache.items(): # Remove entries older than 2 hours time_diff_hours = (now - state.first_error_time).total_seconds() / 3600 if time_diff_hours > 2: expired_ips.append(ip) - + for ip in expired_ips: del self.rate_limit_cache[ip] - + logger.debug(f"Cleaned up {len(expired_ips)} expired rate limit entries") - + def _sanitize_system_info_integration( - self, - system_info: Dict[str, Any], - user_id: Optional[str], - source_ip: Optional[str] + self, system_info: Dict[str, Any], user_id: Optional[str], source_ip: Optional[str] ) -> Dict[str, Any]: """ Integrate with system information sanitization service from Security Fix 5. @@ -364,48 +361,50 @@ def _sanitize_system_info_integration( global _system_info_sanitization_service if _system_info_sanitization_service is None: from .system_info_sanitization import get_system_info_sanitization_service + _system_info_sanitization_service = get_system_info_sanitization_service() - + # Create sanitization context from ..models.system_models import SystemInfoSanitizationContext, SystemInfoLevel + context = SystemInfoSanitizationContext( user_id=user_id, source_ip=source_ip, access_level=SystemInfoLevel.BASIC, # Default to basic for error contexts is_admin=False, # Conservative default - compliance_only=True + compliance_only=True, ) - + # Apply integrated sanitization - sanitized_info, metadata = _system_info_sanitization_service.sanitize_system_information( - system_info, context + sanitized_info, metadata = ( + _system_info_sanitization_service.sanitize_system_information(system_info, context) ) - + # Only keep safe metadata return { - 'validation_timestamp': sanitized_info.get('validation_timestamp'), - 'system_compatible': sanitized_info.get('system_compatible', True), - 'sanitization_applied': True, - 'access_level': 'basic' + "validation_timestamp": sanitized_info.get("validation_timestamp"), + "system_compatible": sanitized_info.get("system_compatible", True), + "sanitization_applied": True, + "access_level": "basic", } - + except Exception as e: logger.error(f"System info sanitization integration failed: {e}") # Fallback to basic safe info return { - 'validation_timestamp': system_info.get('validation_timestamp'), - 'sanitization_applied': True, - 'access_level': 'basic', - 'error_fallback': True + "validation_timestamp": system_info.get("validation_timestamp"), + "sanitization_applied": True, + "access_level": "basic", + "error_fallback": True, } - + def sanitize_system_info_context( self, system_info: Dict[str, Any], user_role: Optional[str] = None, is_admin: bool = False, user_id: Optional[str] = None, - source_ip: Optional[str] = None + source_ip: Optional[str] = None, ) -> Dict[str, Any]: """ Enhanced system information sanitization with role-based access. @@ -416,107 +415,112 @@ def sanitize_system_info_context( global _system_info_sanitization_service if _system_info_sanitization_service is None: from .system_info_sanitization import get_system_info_sanitization_service + _system_info_sanitization_service = get_system_info_sanitization_service() - + # Create comprehensive sanitization context from ..models.system_models import SystemInfoSanitizationContext, SystemInfoLevel - + # Determine access level based on user role access_level = SystemInfoLevel.BASIC - if is_admin and user_role in ['SUPER_ADMIN', 'SECURITY_ADMIN']: + if is_admin and user_role in ["SUPER_ADMIN", "SECURITY_ADMIN"]: access_level = SystemInfoLevel.ADMIN - elif user_role in ['SYSTEM_ADMIN', 'SCAN_OPERATOR']: + elif user_role in ["SYSTEM_ADMIN", "SCAN_OPERATOR"]: access_level = SystemInfoLevel.OPERATIONAL - elif user_role in ['COMPLIANCE_OFFICER']: + elif user_role in ["COMPLIANCE_OFFICER"]: access_level = SystemInfoLevel.COMPLIANCE - + context = SystemInfoSanitizationContext( user_id=user_id, user_role=user_role, source_ip=source_ip, access_level=access_level, is_admin=is_admin, - compliance_only=(access_level in [SystemInfoLevel.BASIC, SystemInfoLevel.COMPLIANCE]) + compliance_only=( + access_level in [SystemInfoLevel.BASIC, SystemInfoLevel.COMPLIANCE] + ), ) - + # Apply sanitization - sanitized_info, metadata = _system_info_sanitization_service.sanitize_system_information( - system_info, context + sanitized_info, metadata = ( + _system_info_sanitization_service.sanitize_system_information(system_info, context) ) - + # Add sanitization metadata - sanitized_info['_metadata'] = { - 'sanitization_level': metadata.sanitization_level.value, - 'reconnaissance_filtered': metadata.reconnaissance_filtered, - 'admin_access_used': metadata.admin_access_used, - 'timestamp': metadata.collection_timestamp.isoformat() + sanitized_info["_metadata"] = { + "sanitization_level": metadata.sanitization_level.value, + "reconnaissance_filtered": metadata.reconnaissance_filtered, + "admin_access_used": metadata.admin_access_used, + "timestamp": metadata.collection_timestamp.isoformat(), } - + return sanitized_info - + except Exception as e: logger.error(f"Enhanced system info sanitization failed: {e}") # Fallback to basic sanitization return self._sanitize_system_info_integration(system_info, user_id, source_ip) - - def create_validation_result_sanitizer(self, validation_result: Dict[str, Any]) -> Dict[str, Any]: + + def create_validation_result_sanitizer( + self, validation_result: Dict[str, Any] + ) -> Dict[str, Any]: """Sanitize ValidationResult objects for safe user consumption""" - + sanitized_errors = [] sanitized_warnings = [] - + # Sanitize all errors - for error in validation_result.get('errors', []): + for error in validation_result.get("errors", []): if isinstance(error, dict): sanitized_errors.append(self.sanitize_error(error).dict()) else: # Handle ScanError objects sanitized_errors.append(self.sanitize_error(error.dict()).dict()) - - # Sanitize all warnings - for warning in validation_result.get('warnings', []): + + # Sanitize all warnings + for warning in validation_result.get("warnings", []): if isinstance(warning, dict): sanitized_warnings.append(self.sanitize_error(warning).dict()) else: # Handle ScanError objects sanitized_warnings.append(self.sanitize_error(warning.dict()).dict()) - + # Remove sensitive system info using integrated system sanitization sanitized_system_info = self._sanitize_system_info_integration( - validation_result.get('system_info', {}), - user_id, - source_ip + validation_result.get("system_info", {}), user_id, source_ip ) - + return { - 'can_proceed': validation_result.get('can_proceed', False), - 'errors': sanitized_errors, - 'warnings': sanitized_warnings, - 'pre_flight_duration': validation_result.get('pre_flight_duration', 0.0), - 'system_info': sanitized_system_info, # Sanitized system info - 'validation_checks': validation_result.get('validation_checks', {}) + "can_proceed": validation_result.get("can_proceed", False), + "errors": sanitized_errors, + "warnings": sanitized_warnings, + "pre_flight_duration": validation_result.get("pre_flight_duration", 0.0), + "system_info": sanitized_system_info, # Sanitized system info + "validation_checks": validation_result.get("validation_checks", {}), } - + def get_rate_limit_status(self, source_ip: str) -> Dict[str, Any]: """Get current rate limit status for an IP (for monitoring)""" if source_ip not in self.rate_limit_cache: - return {'is_limited': False, 'error_count': 0} - + return {"is_limited": False, "error_count": 0} + state = self.rate_limit_cache[source_ip] return { - 'is_limited': self._is_rate_limited(source_ip), - 'error_count': state.error_count, - 'is_blocked': state.is_blocked, - 'block_until': state.block_until.isoformat() if state.block_until else None, - 'remaining_errors': max(0, self.MAX_ERRORS_PER_HOUR - state.error_count) + "is_limited": self._is_rate_limited(source_ip), + "error_count": state.error_count, + "is_blocked": state.is_blocked, + "block_until": state.block_until.isoformat() if state.block_until else None, + "remaining_errors": max(0, self.MAX_ERRORS_PER_HOUR - state.error_count), } + # Global instance for dependency injection _sanitization_service = None + def get_error_sanitization_service() -> ErrorSanitizationService: """Get or create the global error sanitization service""" global _sanitization_service if _sanitization_service is None: _sanitization_service = ErrorSanitizationService() - return _sanitization_service \ No newline at end of file + return _sanitization_service diff --git a/backend/app/services/group_scan_service.py b/backend/app/services/group_scan_service.py index cd8207b3..e36e9ece 100644 --- a/backend/app/services/group_scan_service.py +++ b/backend/app/services/group_scan_service.py @@ -1,6 +1,7 @@ """ Group Scan Service - Manages group scan sessions and progress tracking """ + import uuid import json import logging @@ -11,9 +12,14 @@ from sqlalchemy import text from ..models.scan_models import ( - GroupScanSession, GroupScanProgress, GroupScanConfig, - HostScanDetail, HostScanStatus, ScanSessionStatus, - GroupScanSummary, ActiveScanSession + GroupScanSession, + GroupScanProgress, + GroupScanConfig, + HostScanDetail, + HostScanStatus, + ScanSessionStatus, + GroupScanSummary, + ActiveScanSession, ) from ..tasks.scan_tasks import execute_scan_task @@ -23,24 +29,26 @@ class GroupScanProgressTracker: """Real-time progress tracking for group scans""" - + def __init__(self, db: Session): self.db = db - + async def update_host_status( - self, - session_id: str, - host_id: str, - status: str, + self, + session_id: str, + host_id: str, + status: str, scan_id: Optional[str] = None, scan_result_id: Optional[str] = None, error_message: Optional[str] = None, - progress: int = 0 + progress: int = 0, ): """Update individual host scan status within a group scan""" try: # Update host progress in group_scan_host_progress table - self.db.execute(text(""" + self.db.execute( + text( + """ UPDATE group_scan_host_progress SET status = :status, scan_result_id = :scan_result_id, @@ -51,61 +59,70 @@ async def update_host_status( scan_start_time = CASE WHEN :status = 'scanning' AND scan_start_time IS NULL THEN :updated_at ELSE scan_start_time END WHERE session_id = :session_id AND host_id = :host_id - """), { - "session_id": session_id, - "host_id": host_id, - "status": status, - "scan_result_id": scan_result_id, - "error_message": error_message, - "updated_at": datetime.utcnow() - }) - + """ + ), + { + "session_id": session_id, + "host_id": host_id, + "status": status, + "scan_result_id": scan_result_id, + "error_message": error_message, + "updated_at": datetime.utcnow(), + }, + ) + # Update the main scan record if scan_id provided if scan_id: - self.db.execute(text(""" + self.db.execute( + text( + """ UPDATE group_scan_host_progress SET scan_id = :scan_id WHERE session_id = :session_id AND host_id = :host_id - """), { - "session_id": session_id, - "host_id": host_id, - "scan_id": scan_id - }) - + """ + ), + {"session_id": session_id, "host_id": host_id, "scan_id": scan_id}, + ) + # Update overall session progress await self._update_session_progress(session_id) - + self.db.commit() - + logger.debug(f"Updated host {host_id} status to {status} in session {session_id}") - + except Exception as e: logger.error(f"Failed to update host status: {e}") self.db.rollback() raise - + async def _update_session_progress(self, session_id: str): """Update overall session progress based on host statuses""" try: # Get host status counts - result = self.db.execute(text(""" + result = self.db.execute( + text( + """ SELECT status, COUNT(*) as count FROM group_scan_host_progress WHERE session_id = :session_id GROUP BY status - """), {"session_id": session_id}) - + """ + ), + {"session_id": session_id}, + ) + status_counts = {row.status: row.count for row in result} - + total_hosts = sum(status_counts.values()) - completed = status_counts.get('completed', 0) - failed = status_counts.get('failed', 0) - scanning = status_counts.get('scanning', 0) - pending = status_counts.get('pending', 0) - cancelled = status_counts.get('cancelled', 0) - + completed = status_counts.get("completed", 0) + failed = status_counts.get("failed", 0) + scanning = status_counts.get("scanning", 0) + pending = status_counts.get("pending", 0) + cancelled = status_counts.get("cancelled", 0) + # Determine overall session status session_status = ScanSessionStatus.IN_PROGRESS if completed + failed + cancelled == total_hosts: @@ -117,85 +134,116 @@ async def _update_session_progress(self, session_id: str): session_status = ScanSessionStatus.CANCELLED elif pending == total_hosts: session_status = ScanSessionStatus.PENDING - + # Calculate progress percentage - progress_percentage = ((completed + failed + cancelled) / total_hosts) * 100 if total_hosts > 0 else 0 - + progress_percentage = ( + ((completed + failed + cancelled) / total_hosts) * 100 if total_hosts > 0 else 0 + ) + # Update session record update_data = { "session_id": session_id, "status": session_status.value, "updated_at": datetime.utcnow(), - "progress_percentage": progress_percentage + "progress_percentage": progress_percentage, } - + # Set completion time if finished - if session_status in [ScanSessionStatus.COMPLETED, ScanSessionStatus.FAILED, ScanSessionStatus.CANCELLED]: + if session_status in [ + ScanSessionStatus.COMPLETED, + ScanSessionStatus.FAILED, + ScanSessionStatus.CANCELLED, + ]: update_data["completed_at"] = datetime.utcnow() - - self.db.execute(text(""" + + self.db.execute( + text( + """ UPDATE group_scan_sessions SET status = :status, updated_at = :updated_at, completed_at = COALESCE(:completed_at, completed_at) WHERE session_id = :session_id - """), update_data) - - logger.debug(f"Updated session {session_id} progress: {progress_percentage:.1f}% complete") - + """ + ), + update_data, + ) + + logger.debug( + f"Updated session {session_id} progress: {progress_percentage:.1f}% complete" + ) + except Exception as e: logger.error(f"Failed to update session progress: {e}") raise - + async def calculate_progress(self, session_id: str) -> GroupScanProgress: """Calculate overall progress for a group scan session""" try: # Get session details - session_result = self.db.execute(text(""" + session_result = self.db.execute( + text( + """ SELECT session_id, group_id, group_name, total_hosts, status, start_time, estimated_completion, updated_at FROM group_scan_sessions WHERE session_id = :session_id - """), {"session_id": session_id}).fetchone() - + """ + ), + {"session_id": session_id}, + ).fetchone() + if not session_result: raise ValueError(f"Session {session_id} not found") - + # Get host status counts - status_result = self.db.execute(text(""" + status_result = self.db.execute( + text( + """ SELECT status, COUNT(*) as count FROM group_scan_host_progress WHERE session_id = :session_id GROUP BY status - """), {"session_id": session_id}) - + """ + ), + {"session_id": session_id}, + ) + status_counts = {row.status: row.count for row in status_result} - - hosts_completed = status_counts.get('completed', 0) - hosts_failed = status_counts.get('failed', 0) - hosts_scanning = status_counts.get('scanning', 0) - hosts_pending = status_counts.get('pending', 0) - - progress_percentage = ((hosts_completed + hosts_failed) / session_result.total_hosts) * 100 \ - if session_result.total_hosts > 0 else 0 - + + hosts_completed = status_counts.get("completed", 0) + hosts_failed = status_counts.get("failed", 0) + hosts_scanning = status_counts.get("scanning", 0) + hosts_pending = status_counts.get("pending", 0) + + progress_percentage = ( + ((hosts_completed + hosts_failed) / session_result.total_hosts) * 100 + if session_result.total_hosts > 0 + else 0 + ) + # Calculate average scan duration if we have completed scans avg_duration = None if hosts_completed > 0: - duration_result = self.db.execute(text(""" + duration_result = self.db.execute( + text( + """ SELECT AVG(EXTRACT(EPOCH FROM (scan_end_time - scan_start_time))) as avg_duration FROM group_scan_host_progress WHERE session_id = :session_id AND status = 'completed' AND scan_start_time IS NOT NULL AND scan_end_time IS NOT NULL - """), {"session_id": session_id}).fetchone() - + """ + ), + {"session_id": session_id}, + ).fetchone() + if duration_result and duration_result.avg_duration: avg_duration = float(duration_result.avg_duration) - + return GroupScanProgress( session_id=session_result.session_id, group_id=session_result.group_id, @@ -210,34 +258,41 @@ async def calculate_progress(self, session_id: str) -> GroupScanProgress: estimated_completion=session_result.estimated_completion, average_scan_duration=avg_duration, started_at=session_result.start_time, - last_updated=session_result.updated_at or session_result.start_time + last_updated=session_result.updated_at or session_result.start_time, ) - + except Exception as e: logger.error(f"Failed to calculate progress for session {session_id}: {e}") raise - + async def estimate_completion(self, session_id: str) -> Optional[datetime]: """Estimate completion time based on current progress""" try: progress = await self.calculate_progress(session_id) - - if progress.status in [ScanSessionStatus.COMPLETED, ScanSessionStatus.FAILED, ScanSessionStatus.CANCELLED]: + + if progress.status in [ + ScanSessionStatus.COMPLETED, + ScanSessionStatus.FAILED, + ScanSessionStatus.CANCELLED, + ]: return None - - if progress.average_scan_duration and progress.hosts_scanning + progress.hosts_pending > 0: + + if ( + progress.average_scan_duration + and progress.hosts_scanning + progress.hosts_pending > 0 + ): remaining_hosts = progress.hosts_scanning + progress.hosts_pending estimated_seconds = remaining_hosts * progress.average_scan_duration return datetime.utcnow() + timedelta(seconds=estimated_seconds) - + # Fallback estimate based on typical scan duration (10 minutes per host) if progress.hosts_scanning + progress.hosts_pending > 0: remaining_hosts = progress.hosts_scanning + progress.hosts_pending estimated_minutes = remaining_hosts * 10 # 10 minutes per host average return datetime.utcnow() + timedelta(minutes=estimated_minutes) - + return None - + except Exception as e: logger.error(f"Failed to estimate completion for session {session_id}: {e}") return None @@ -245,86 +300,104 @@ async def estimate_completion(self, session_id: str) -> Optional[datetime]: class GroupScanService: """Service for managing group scan operations""" - + def __init__(self, db: Session): self.db = db self.progress_tracker = GroupScanProgressTracker(db) - + async def initiate_group_scan( - self, - group_id: int, - user_id: int, - scan_config: Optional[GroupScanConfig] = None + self, group_id: int, user_id: int, scan_config: Optional[GroupScanConfig] = None ) -> GroupScanSession: """Initiate a group scan for all hosts in a group""" try: # Get group details - group_result = self.db.execute(text(""" + group_result = self.db.execute( + text( + """ SELECT id, name, description, scap_content_id, default_profile_id FROM host_groups WHERE id = :group_id - """), {"group_id": group_id}).fetchone() - + """ + ), + {"group_id": group_id}, + ).fetchone() + if not group_result: raise ValueError(f"Group {group_id} not found") - + # Get hosts in the group - hosts_result = self.db.execute(text(""" + hosts_result = self.db.execute( + text( + """ SELECT h.id, h.hostname, h.display_name, h.ip_address, h.username, h.auth_method, h.port FROM hosts h JOIN host_group_memberships hgm ON h.id = hgm.host_id WHERE hgm.group_id = :group_id AND h.is_active = true ORDER BY h.hostname - """), {"group_id": group_id}) - + """ + ), + {"group_id": group_id}, + ) + hosts = hosts_result.fetchall() if not hosts: raise ValueError(f"No active hosts found in group {group_id}") - + # Use group defaults or provided config if not scan_config: scan_config = GroupScanConfig() - + # Use group's default SCAP content if not specified if not scan_config.content_id and group_result.scap_content_id: scan_config.content_id = group_result.scap_content_id - + if not scan_config.profile_id and group_result.default_profile_id: scan_config.profile_id = group_result.default_profile_id - + # Validate SCAP content exists if scan_config.content_id: - content_result = self.db.execute(text(""" + content_result = self.db.execute( + text( + """ SELECT id, name, file_path, profiles FROM scap_content WHERE id = :content_id - """), {"content_id": scan_config.content_id}).fetchone() - + """ + ), + {"content_id": scan_config.content_id}, + ).fetchone() + if not content_result: raise ValueError(f"SCAP content {scan_config.content_id} not found") - + # Validate profile exists if specified if scan_config.profile_id and content_result.profiles: profiles = json.loads(content_result.profiles) profile_ids = [p.get("id") for p in profiles if p.get("id")] if scan_config.profile_id not in profile_ids: - raise ValueError(f"Profile {scan_config.profile_id} not found in SCAP content") + raise ValueError( + f"Profile {scan_config.profile_id} not found in SCAP content" + ) else: # Use default SCAP content - default_content = self.db.execute(text(""" + default_content = self.db.execute( + text( + """ SELECT id, name, file_path, profiles FROM scap_content ORDER BY uploaded_at DESC LIMIT 1 - """)).fetchone() - + """ + ) + ).fetchone() + if not default_content: raise ValueError("No SCAP content available") - + scan_config.content_id = default_content.id if not scan_config.profile_id and default_content.profiles: profiles = json.loads(default_content.profiles) if profiles: scan_config.profile_id = profiles[0].get("id") - + # Create group scan session session_id = str(uuid.uuid4()) session_data = { @@ -335,41 +408,53 @@ async def initiate_group_scan( "initiated_by": user_id, "start_time": datetime.utcnow(), "status": ScanSessionStatus.PENDING.value, - "scan_config": json.dumps(scan_config.dict()) if scan_config else None + "scan_config": json.dumps(scan_config.dict()) if scan_config else None, } - + # Estimate completion time (10 minutes per host average, with stagger delay) estimated_minutes = len(hosts) * 10 + (len(hosts) * scan_config.stagger_delay / 60) - session_data["estimated_completion"] = datetime.utcnow() + timedelta(minutes=estimated_minutes) - - self.db.execute(text(""" + session_data["estimated_completion"] = datetime.utcnow() + timedelta( + minutes=estimated_minutes + ) + + self.db.execute( + text( + """ INSERT INTO group_scan_sessions (session_id, group_id, group_name, total_hosts, initiated_by, start_time, estimated_completion, status, scan_config, created_at, updated_at) VALUES (:session_id, :group_id, :group_name, :total_hosts, :initiated_by, :start_time, :estimated_completion, :status, :scan_config, :start_time, :start_time) - """), session_data) - + """ + ), + session_data, + ) + # Create host progress tracking records for host in hosts: - self.db.execute(text(""" + self.db.execute( + text( + """ INSERT INTO group_scan_host_progress (session_id, host_id, host_name, host_ip, status, created_at, updated_at) VALUES (:session_id, :host_id, :host_name, :host_ip, :status, :created_at, :updated_at) - """), { - "session_id": session_id, - "host_id": str(host.id), - "host_name": host.display_name or host.hostname, - "host_ip": host.ip_address or host.hostname, - "status": HostScanStatus.PENDING.value, - "created_at": datetime.utcnow(), - "updated_at": datetime.utcnow() - }) - + """ + ), + { + "session_id": session_id, + "host_id": str(host.id), + "host_name": host.display_name or host.hostname, + "host_ip": host.ip_address or host.hostname, + "status": HostScanStatus.PENDING.value, + "created_at": datetime.utcnow(), + "updated_at": datetime.utcnow(), + }, + ) + self.db.commit() - + # Create session object session = GroupScanSession( session_id=session_id, @@ -381,196 +466,224 @@ async def initiate_group_scan( estimated_completion=session_data["estimated_completion"], status=ScanSessionStatus.PENDING, hosts_pending=[str(host.id) for host in hosts], - scan_config=scan_config + scan_config=scan_config, + ) + + logger.info( + f"Created group scan session {session_id} for group {group_id} with {len(hosts)} hosts" ) - - logger.info(f"Created group scan session {session_id} for group {group_id} with {len(hosts)} hosts") return session - + except Exception as e: self.db.rollback() logger.error(f"Failed to initiate group scan: {e}") raise - + async def start_group_scan_execution(self, session_id: str) -> bool: """Start executing scans for all hosts in a group scan session""" try: # Get session details - session_result = self.db.execute(text(""" + session_result = self.db.execute( + text( + """ SELECT session_id, group_id, group_name, scan_config, status FROM group_scan_sessions WHERE session_id = :session_id - """), {"session_id": session_id}).fetchone() - + """ + ), + {"session_id": session_id}, + ).fetchone() + if not session_result: raise ValueError(f"Session {session_id} not found") - + if session_result.status != ScanSessionStatus.PENDING.value: raise ValueError(f"Session {session_id} is not in pending status") - + # Parse scan config scan_config = GroupScanConfig() if session_result.scan_config: config_data = json.loads(session_result.scan_config) scan_config = GroupScanConfig(**config_data) - + # Get SCAP content details - content_result = self.db.execute(text(""" + content_result = self.db.execute( + text( + """ SELECT id, name, file_path FROM scap_content WHERE id = :content_id - """), {"content_id": scan_config.content_id}).fetchone() - + """ + ), + {"content_id": scan_config.content_id}, + ).fetchone() + if not content_result: raise ValueError(f"SCAP content {scan_config.content_id} not found") - + # Get pending hosts - hosts_result = self.db.execute(text(""" + hosts_result = self.db.execute( + text( + """ SELECT p.host_id, p.host_name, p.host_ip, h.hostname, h.port, h.username, h.auth_method, h.encrypted_credentials FROM group_scan_host_progress p JOIN hosts h ON p.host_id = h.id WHERE p.session_id = :session_id AND p.status = 'pending' ORDER BY p.host_name - """), {"session_id": session_id}) - + """ + ), + {"session_id": session_id}, + ) + hosts = hosts_result.fetchall() if not hosts: logger.warning(f"No pending hosts found for session {session_id}") return False - + # Update session status to in progress - self.db.execute(text(""" + self.db.execute( + text( + """ UPDATE group_scan_sessions SET status = :status, updated_at = :updated_at WHERE session_id = :session_id - """), { - "session_id": session_id, - "status": ScanSessionStatus.IN_PROGRESS.value, - "updated_at": datetime.utcnow() - }) + """ + ), + { + "session_id": session_id, + "status": ScanSessionStatus.IN_PROGRESS.value, + "updated_at": datetime.utcnow(), + }, + ) self.db.commit() - + # Start scans with stagger delay - asyncio.create_task(self._execute_staggered_scans( - session_id, hosts, content_result, scan_config - )) - - logger.info(f"Started group scan execution for session {session_id} with {len(hosts)} hosts") + asyncio.create_task( + self._execute_staggered_scans(session_id, hosts, content_result, scan_config) + ) + + logger.info( + f"Started group scan execution for session {session_id} with {len(hosts)} hosts" + ) return True - + except Exception as e: logger.error(f"Failed to start group scan execution: {e}") self.db.rollback() raise - + async def _execute_staggered_scans( - self, - session_id: str, - hosts: List[Any], - content_result: Any, - scan_config: GroupScanConfig + self, session_id: str, hosts: List[Any], content_result: Any, scan_config: GroupScanConfig ): """Execute scans with staggered start times""" try: concurrent_scans = 0 max_concurrent = scan_config.max_concurrent or 5 - + for i, host in enumerate(hosts): # Wait for available slot while concurrent_scans >= max_concurrent: await asyncio.sleep(5) # Check every 5 seconds # Update concurrent count by checking active scans - active_result = self.db.execute(text(""" + active_result = self.db.execute( + text( + """ SELECT COUNT(*) as active_count FROM group_scan_host_progress WHERE session_id = :session_id AND status = 'scanning' - """), {"session_id": session_id}) + """ + ), + {"session_id": session_id}, + ) concurrent_scans = active_result.fetchone().active_count - + # Create individual scan scan_id = await self._create_individual_scan( session_id, host, content_result, scan_config ) - + if scan_id: # Start the scan task - asyncio.create_task(self._execute_host_scan( - session_id, scan_id, host, content_result, scan_config - )) + asyncio.create_task( + self._execute_host_scan( + session_id, scan_id, host, content_result, scan_config + ) + ) concurrent_scans += 1 - + # Stagger delay between scans (except for last scan) if i < len(hosts) - 1: await asyncio.sleep(scan_config.stagger_delay) - + logger.info(f"Initiated all scans for session {session_id}") - + except Exception as e: logger.error(f"Error in staggered scan execution: {e}") - + async def _create_individual_scan( - self, - session_id: str, - host: Any, - content_result: Any, - scan_config: GroupScanConfig + self, session_id: str, host: Any, content_result: Any, scan_config: GroupScanConfig ) -> Optional[str]: """Create an individual scan record for a host in the group scan""" try: scan_id = str(uuid.uuid4()) scan_name = f"Group Scan - {host.host_name}" - + # Create scan record - self.db.execute(text(""" + self.db.execute( + text( + """ INSERT INTO scans (id, name, host_id, content_id, profile_id, status, progress, scan_options, started_by, started_at, remediation_requested, verification_scan) VALUES (:id, :name, :host_id, :content_id, :profile_id, :status, :progress, :scan_options, :started_by, :started_at, :remediation_requested, :verification_scan) - """), { - "id": scan_id, - "name": scan_name, - "host_id": host.host_id, - "content_id": scan_config.content_id, - "profile_id": scan_config.profile_id, - "status": "pending", - "progress": 0, - "scan_options": json.dumps({ - "group_scan": True, - "session_id": session_id, - **scan_config.scan_options - }), - "started_by": 1, # System user for group scans - "started_at": datetime.utcnow(), - "remediation_requested": False, - "verification_scan": False - }) - + """ + ), + { + "id": scan_id, + "name": scan_name, + "host_id": host.host_id, + "content_id": scan_config.content_id, + "profile_id": scan_config.profile_id, + "status": "pending", + "progress": 0, + "scan_options": json.dumps( + {"group_scan": True, "session_id": session_id, **scan_config.scan_options} + ), + "started_by": 1, # System user for group scans + "started_at": datetime.utcnow(), + "remediation_requested": False, + "verification_scan": False, + }, + ) + # Update host progress with scan ID await self.progress_tracker.update_host_status( session_id, host.host_id, HostScanStatus.SCANNING.value, scan_id ) - + self.db.commit() return scan_id - + except Exception as e: logger.error(f"Failed to create individual scan for host {host.host_id}: {e}") self.db.rollback() return None - + async def _execute_host_scan( - self, - session_id: str, - scan_id: str, - host: Any, - content_result: Any, - scan_config: GroupScanConfig + self, + session_id: str, + scan_id: str, + host: Any, + content_result: Any, + scan_config: GroupScanConfig, ): """Execute scan for a single host within a group scan""" try: - logger.info(f"Starting host scan {scan_id} for {host.host_name} in session {session_id}") - + logger.info( + f"Starting host scan {scan_id} for {host.host_name} in session {session_id}" + ) + # Prepare host data host_data = { "id": host.host_id, @@ -578,9 +691,9 @@ async def _execute_host_scan( "port": host.port, "username": host.username, "auth_method": host.auth_method, - "encrypted_credentials": host.encrypted_credentials + "encrypted_credentials": host.encrypted_credentials, } - + # Execute the scan task # This will run in the background and update progress via the scan task execute_scan_task( @@ -591,28 +704,29 @@ async def _execute_host_scan( scan_options={ "group_scan": True, "session_id": session_id, - **scan_config.scan_options - } + **scan_config.scan_options, + }, ) - + # The scan task will handle updating the host status when complete - + except Exception as e: logger.error(f"Failed to execute host scan {scan_id}: {e}") # Update host status to failed await self.progress_tracker.update_host_status( - session_id, host.host_id, HostScanStatus.FAILED.value, - scan_id, error_message=str(e) + session_id, host.host_id, HostScanStatus.FAILED.value, scan_id, error_message=str(e) ) - + async def get_scan_progress(self, session_id: str) -> GroupScanProgress: """Get real-time progress of a group scan""" return await self.progress_tracker.calculate_progress(session_id) - + async def get_host_scan_details(self, session_id: str) -> List[HostScanDetail]: """Get detailed status of each host in a group scan""" try: - result = self.db.execute(text(""" + result = self.db.execute( + text( + """ SELECT p.host_id, p.host_name, p.host_ip, p.status, p.scan_id, p.scan_start_time, p.scan_end_time, p.error_message, h.hostname, s.progress, @@ -623,8 +737,11 @@ async def get_host_scan_details(self, session_id: str) -> List[HostScanDetail]: LEFT JOIN scan_results sr ON s.id = sr.scan_id WHERE p.session_id = :session_id ORDER BY p.host_name - """), {"session_id": session_id}) - + """ + ), + {"session_id": session_id}, + ) + host_details = [] for row in result: scan_results = None @@ -633,85 +750,101 @@ async def get_host_scan_details(self, session_id: str) -> List[HostScanDetail]: "total_rules": row.total_rules, "passed_rules": row.passed_rules, "failed_rules": row.failed_rules, - "score": row.score + "score": row.score, } - - host_details.append(HostScanDetail( - host_id=row.host_id, - host_name=row.host_name, - hostname=row.hostname, - ip_address=row.host_ip, - status=HostScanStatus(row.status), - scan_id=row.scan_id, - scan_start_time=row.scan_start_time, - scan_end_time=row.scan_end_time, - progress=row.progress or 0, - error_message=row.error_message, - scan_results=scan_results - )) - + + host_details.append( + HostScanDetail( + host_id=row.host_id, + host_name=row.host_name, + hostname=row.hostname, + ip_address=row.host_ip, + status=HostScanStatus(row.status), + scan_id=row.scan_id, + scan_start_time=row.scan_start_time, + scan_end_time=row.scan_end_time, + progress=row.progress or 0, + error_message=row.error_message, + scan_results=scan_results, + ) + ) + return host_details - + except Exception as e: logger.error(f"Failed to get host scan details: {e}") raise - + async def cancel_group_scan(self, session_id: str) -> bool: """Cancel an ongoing group scan""" try: # Update session status - self.db.execute(text(""" + self.db.execute( + text( + """ UPDATE group_scan_sessions SET status = :status, updated_at = :updated_at, completed_at = :completed_at WHERE session_id = :session_id AND status IN ('pending', 'in_progress') - """), { - "session_id": session_id, - "status": ScanSessionStatus.CANCELLED.value, - "updated_at": datetime.utcnow(), - "completed_at": datetime.utcnow() - }) - + """ + ), + { + "session_id": session_id, + "status": ScanSessionStatus.CANCELLED.value, + "updated_at": datetime.utcnow(), + "completed_at": datetime.utcnow(), + }, + ) + # Cancel pending and running scans - self.db.execute(text(""" + self.db.execute( + text( + """ UPDATE scans SET status = 'cancelled', error_message = 'Group scan cancelled by user' WHERE id IN ( SELECT scan_id FROM group_scan_host_progress WHERE session_id = :session_id AND scan_id IS NOT NULL ) AND status IN ('pending', 'running') - """), {"session_id": session_id}) - + """ + ), + {"session_id": session_id}, + ) + # Update host progress statuses - self.db.execute(text(""" + self.db.execute( + text( + """ UPDATE group_scan_host_progress SET status = 'cancelled', updated_at = :updated_at WHERE session_id = :session_id AND status IN ('pending', 'scanning') - """), { - "session_id": session_id, - "updated_at": datetime.utcnow() - }) - + """ + ), + {"session_id": session_id, "updated_at": datetime.utcnow()}, + ) + self.db.commit() - + logger.info(f"Cancelled group scan session {session_id}") return True - + except Exception as e: logger.error(f"Failed to cancel group scan: {e}") self.db.rollback() return False - + async def get_active_scans(self, user_id: Optional[int] = None) -> List[ActiveScanSession]: """Get all active scan sessions""" try: where_clause = "" params = {} - + if user_id: where_clause = "WHERE s.initiated_by = :user_id" params["user_id"] = user_id - - result = self.db.execute(text(f""" + + result = self.db.execute( + text( + f""" SELECT s.session_id, s.group_id, s.group_name, s.status, s.total_hosts, s.start_time, s.estimated_completion, s.initiated_by, COUNT(CASE WHEN p.status = 'completed' THEN 1 END) as hosts_completed @@ -722,27 +855,34 @@ async def get_active_scans(self, user_id: Optional[int] = None) -> List[ActiveSc s.start_time, s.estimated_completion, s.initiated_by HAVING s.status IN ('pending', 'in_progress') ORDER BY s.start_time DESC - """), params) - + """ + ), + params, + ) + active_sessions = [] for row in result: - progress_percentage = (row.hosts_completed / row.total_hosts) * 100 if row.total_hosts > 0 else 0 - - active_sessions.append(ActiveScanSession( - session_id=row.session_id, - group_id=row.group_id, - group_name=row.group_name, - status=ScanSessionStatus(row.status), - progress_percentage=progress_percentage, - hosts_completed=row.hosts_completed, - total_hosts=row.total_hosts, - started_at=row.start_time, - estimated_completion=row.estimated_completion, - initiated_by=row.initiated_by - )) - + progress_percentage = ( + (row.hosts_completed / row.total_hosts) * 100 if row.total_hosts > 0 else 0 + ) + + active_sessions.append( + ActiveScanSession( + session_id=row.session_id, + group_id=row.group_id, + group_name=row.group_name, + status=ScanSessionStatus(row.status), + progress_percentage=progress_percentage, + hosts_completed=row.hosts_completed, + total_hosts=row.total_hosts, + started_at=row.start_time, + estimated_completion=row.estimated_completion, + initiated_by=row.initiated_by, + ) + ) + return active_sessions - + except Exception as e: logger.error(f"Failed to get active scans: {e}") - return [] \ No newline at end of file + return [] diff --git a/backend/app/services/group_validation_service.py b/backend/app/services/group_validation_service.py index bb55b494..1d23265c 100644 --- a/backend/app/services/group_validation_service.py +++ b/backend/app/services/group_validation_service.py @@ -21,6 +21,7 @@ logger = logging.getLogger(__name__) + # Define ValidationError as a simple exception class class ValidationError(Exception): def __init__(self, message, category=None, severity=None, context=None): @@ -32,6 +33,7 @@ def __init__(self, message, category=None, severity=None, context=None): class OSFamily: """OS family constants""" + RHEL = "rhel" CENTOS = "centos" FEDORA = "fedora" @@ -45,7 +47,7 @@ class OSFamily: FREEBSD = "freebsd" OPENBSD = "openbsd" SOLARIS = "solaris" - + # OS family groupings for compatibility RHEL_FAMILY = {RHEL, CENTOS, FEDORA} DEBIAN_FAMILY = {UBUNTU, DEBIAN} @@ -56,21 +58,18 @@ class OSFamily: class GroupValidationService: """Service for intelligent group validation and compatibility checking""" - + def __init__(self, db: Session): self.db = db self.sanitization_service = SystemInfoSanitizationService() self.cache_duration = timedelta(hours=24) # Cache compatibility results for 24 hours - + def validate_host_group_compatibility( - self, - host_ids: List[str], - group_id: int, - user_role: Optional[str] = None + self, host_ids: List[str], group_id: int, user_role: Optional[str] = None ) -> Dict[str, Any]: """ Validate compatibility between hosts and a group - + Returns detailed validation results including: - Compatible hosts - Incompatible hosts with reasons @@ -83,18 +82,18 @@ def validate_host_group_compatibility( message=f"Group {group_id} not found", category=ErrorCategory.NOT_FOUND, severity=ErrorSeverity.ERROR, - context={"group_id": group_id} + context={"group_id": group_id}, ) - + hosts = self.db.query(Host).filter(Host.id.in_(host_ids)).all() if not hosts: raise ValidationError( message="No hosts found", category=ErrorCategory.NOT_FOUND, severity=ErrorSeverity.ERROR, - context={"host_ids": host_ids} + context={"host_ids": host_ids}, ) - + results = { "group": { "id": group.id, @@ -102,7 +101,7 @@ def validate_host_group_compatibility( "os_family": group.os_family, "os_version_pattern": group.os_version_pattern, "compliance_framework": group.compliance_framework, - "scap_content_id": group.scap_content_id + "scap_content_id": group.scap_content_id, }, "compatible": [], "incompatible": [], @@ -112,14 +111,14 @@ def validate_host_group_compatibility( "total_hosts": len(hosts), "compatible_count": 0, "incompatible_count": 0, - "compatibility_score": 0.0 - } + "compatibility_score": 0.0, + }, } - + # Check each host for host in hosts: compatibility = self._check_host_compatibility(host, group, user_role) - + host_info = { "id": str(host.id), "hostname": host.hostname, @@ -128,9 +127,9 @@ def validate_host_group_compatibility( "os_version": host.os_version, "architecture": host.architecture, "compatibility_score": compatibility["score"], - "validation_details": compatibility["details"] + "validation_details": compatibility["details"], } - + if compatibility["is_compatible"]: results["compatible"].append(host_info) results["summary"]["compatible_count"] += 1 @@ -138,31 +137,30 @@ def validate_host_group_compatibility( host_info["reasons"] = compatibility["reasons"] results["incompatible"].append(host_info) results["summary"]["incompatible_count"] += 1 - + # Generate suggestions for incompatible hosts suggestions = self._generate_group_suggestions(host) if suggestions: results["suggestions"][str(host.id)] = suggestions - + # Add warnings if any if compatibility.get("warnings"): results["warnings"].extend(compatibility["warnings"]) - + # Calculate overall compatibility score if hosts: - total_score = sum(h["compatibility_score"] for h in results["compatible"] + results["incompatible"]) + total_score = sum( + h["compatibility_score"] for h in results["compatible"] + results["incompatible"] + ) results["summary"]["compatibility_score"] = total_score / len(hosts) - + # Cache the results self._cache_compatibility_results(host_ids, group_id, results) - + return results - + def _check_host_compatibility( - self, - host: Host, - group: HostGroup, - user_role: Optional[str] = None + self, host: Host, group: HostGroup, user_role: Optional[str] = None ) -> Dict[str, Any]: """Check compatibility between a host and a group""" compatibility = { @@ -170,13 +168,13 @@ def _check_host_compatibility( "score": 100.0, "reasons": [], "warnings": [], - "details": {} + "details": {}, } - + # Detect OS information if not available if not host.os_family or not host.os_version: self._detect_host_os_info(host) - + # Check OS family compatibility if group.os_family: os_check = self._check_os_family_compatibility(host, group) @@ -187,7 +185,7 @@ def _check_host_compatibility( compatibility["score"] *= 0.0 # Complete mismatch else: compatibility["score"] *= os_check["score"] - + # Check OS version compatibility if group.os_version_pattern: version_check = self._check_os_version_compatibility(host, group) @@ -198,7 +196,7 @@ def _check_host_compatibility( compatibility["score"] *= 0.2 # Severe penalty else: compatibility["score"] *= version_check["score"] - + # Check architecture compatibility if group.architecture: arch_check = self._check_architecture_compatibility(host, group) @@ -206,7 +204,7 @@ def _check_host_compatibility( if not arch_check["compatible"]: compatibility["warnings"].append(arch_check["reason"]) compatibility["score"] *= 0.8 # Minor penalty - + # Check SCAP content compatibility if group.scap_content_id: scap_check = self._check_scap_content_compatibility(host, group) @@ -217,7 +215,7 @@ def _check_host_compatibility( compatibility["score"] *= 0.1 # Severe penalty else: compatibility["score"] *= scap_check["score"] - + # Apply custom validation rules if any if group.validation_rules: custom_check = self._apply_custom_validation_rules(host, group) @@ -230,70 +228,64 @@ def _check_host_compatibility( elif rule_result["severity"] == "warning" and not rule_result["passed"]: compatibility["warnings"].append(rule_result["message"]) compatibility["score"] *= 0.9 - + return compatibility - + def _check_os_family_compatibility(self, host: Host, group: HostGroup) -> Dict[str, Any]: """Check if host OS family matches group requirements""" - result = { - "compatible": True, - "score": 1.0, - "reason": "" - } - + result = {"compatible": True, "score": 1.0, "reason": ""} + if not host.os_family: result["compatible"] = False result["score"] = 0.0 result["reason"] = f"Host {host.hostname} OS family not detected" return result - + # Direct match if host.os_family == group.os_family: result["score"] = 1.0 return result - + # Check OS family groupings (e.g., RHEL and CentOS are compatible) host_family_groups = [] group_family_groups = [] - + for family_name, family_members in [ ("RHEL_FAMILY", OSFamily.RHEL_FAMILY), ("DEBIAN_FAMILY", OSFamily.DEBIAN_FAMILY), ("SUSE_FAMILY", OSFamily.SUSE_FAMILY), ("WINDOWS_FAMILY", OSFamily.WINDOWS_FAMILY), - ("BSD_FAMILY", OSFamily.BSD_FAMILY) + ("BSD_FAMILY", OSFamily.BSD_FAMILY), ]: if host.os_family in family_members: host_family_groups.append(family_name) if group.os_family in family_members: group_family_groups.append(family_name) - + # Check if they belong to the same family group if set(host_family_groups) & set(group_family_groups): result["score"] = 0.9 # High compatibility within same family return result - + # No compatibility result["compatible"] = False result["score"] = 0.0 - result["reason"] = f"Host OS {host.os_family} incompatible with group requirement {group.os_family}" - + result["reason"] = ( + f"Host OS {host.os_family} incompatible with group requirement {group.os_family}" + ) + return result - + def _check_os_version_compatibility(self, host: Host, group: HostGroup) -> Dict[str, Any]: """Check if host OS version matches group pattern""" - result = { - "compatible": True, - "score": 1.0, - "reason": "" - } - + result = {"compatible": True, "score": 1.0, "reason": ""} + if not host.os_version: result["compatible"] = False result["score"] = 0.0 result["reason"] = f"Host {host.hostname} OS version not detected" return result - + try: # Convert pattern to regex pattern = group.os_version_pattern.replace("*", ".*").replace("?", ".") @@ -302,36 +294,34 @@ def _check_os_version_compatibility(self, host: Host, group: HostGroup) -> Dict[ else: result["compatible"] = False result["score"] = 0.0 - result["reason"] = f"Host OS version {host.os_version} doesn't match pattern {group.os_version_pattern}" + result["reason"] = ( + f"Host OS version {host.os_version} doesn't match pattern {group.os_version_pattern}" + ) except Exception as e: logger.warning(f"Invalid version pattern {group.os_version_pattern}: {e}") result["compatible"] = False result["score"] = 0.0 result["reason"] = f"Invalid version pattern: {group.os_version_pattern}" - + return result - + def _check_architecture_compatibility(self, host: Host, group: HostGroup) -> Dict[str, Any]: """Check if host architecture matches group requirements""" - result = { - "compatible": True, - "score": 1.0, - "reason": "" - } - + result = {"compatible": True, "score": 1.0, "reason": ""} + if not host.architecture: result["score"] = 0.8 # Minor penalty result["reason"] = f"Host {host.hostname} architecture not detected" return result - + # Normalize architectures host_arch = host.architecture.lower() group_arch = group.architecture.lower() - + # Direct match if host_arch == group_arch: return result - + # Check compatible architectures arch_compatibility = { "x86_64": ["amd64", "x64"], @@ -340,40 +330,38 @@ def _check_architecture_compatibility(self, host: Host, group: HostGroup) -> Dic "i386": ["i686", "x86"], "i686": ["i386", "x86"], "aarch64": ["arm64"], - "arm64": ["aarch64"] + "arm64": ["aarch64"], } - + compatible_archs = arch_compatibility.get(host_arch, []) if group_arch in compatible_archs: result["score"] = 0.95 return result - + # Not compatible result["compatible"] = False result["score"] = 0.0 - result["reason"] = f"Host architecture {host.architecture} incompatible with group requirement {group.architecture}" - + result["reason"] = ( + f"Host architecture {host.architecture} incompatible with group requirement {group.architecture}" + ) + return result - + def _check_scap_content_compatibility(self, host: Host, group: HostGroup) -> Dict[str, Any]: """Check if SCAP content is compatible with host""" - result = { - "compatible": True, - "score": 1.0, - "reason": "" - } - + result = {"compatible": True, "score": 1.0, "reason": ""} + # Get SCAP content - scap_content = self.db.query(ScapContent).filter( - ScapContent.id == group.scap_content_id - ).first() - + scap_content = ( + self.db.query(ScapContent).filter(ScapContent.id == group.scap_content_id).first() + ) + if not scap_content: result["compatible"] = False result["score"] = 0.0 result["reason"] = "SCAP content not found" return result - + # Check OS family compatibility if scap_content.os_family: if host.os_family != scap_content.os_family: @@ -384,109 +372,126 @@ def _check_scap_content_compatibility(self, host: Host, group: HostGroup) -> Dic OSFamily.DEBIAN_FAMILY, OSFamily.SUSE_FAMILY, OSFamily.WINDOWS_FAMILY, - OSFamily.BSD_FAMILY + OSFamily.BSD_FAMILY, ]: - if host.os_family in family_members and scap_content.os_family in family_members: + if ( + host.os_family in family_members + and scap_content.os_family in family_members + ): compatible = True result["score"] = 0.9 break - + if not compatible: result["compatible"] = False result["score"] = 0.0 - result["reason"] = f"SCAP content for {scap_content.os_family} incompatible with host OS {host.os_family}" + result["reason"] = ( + f"SCAP content for {scap_content.os_family} incompatible with host OS {host.os_family}" + ) return result - + # Check OS version compatibility if scap_content.os_version and host.os_version: content_version = scap_content.os_version.split(".")[0] # Major version host_version = host.os_version.split(".")[0] - + if content_version != host_version: # Check if it's a minor version difference try: content_major = int(content_version) host_major = int(host_version) - + if abs(content_major - host_major) > 1: result["compatible"] = False result["score"] = 0.0 - result["reason"] = f"SCAP content for version {scap_content.os_version} incompatible with host version {host.os_version}" + result["reason"] = ( + f"SCAP content for version {scap_content.os_version} incompatible with host version {host.os_version}" + ) else: result["score"] = 0.7 # Penalty for version mismatch except: pass - + return result - + def _apply_custom_validation_rules(self, host: Host, group: HostGroup) -> List[Dict[str, Any]]: """Apply custom validation rules defined for the group""" results = [] - + if not group.validation_rules: return results - + try: - rules = json.loads(group.validation_rules) if isinstance(group.validation_rules, str) else group.validation_rules + rules = ( + json.loads(group.validation_rules) + if isinstance(group.validation_rules, str) + else group.validation_rules + ) except: logger.error(f"Failed to parse validation rules for group {group.id}") return results - + for rule in rules: rule_result = { "rule_name": rule.get("name", "Unknown"), "passed": True, "message": "", - "severity": rule.get("severity", "warning") + "severity": rule.get("severity", "warning"), } - + try: # Evaluate rule based on type rule_type = rule.get("type") expression = rule.get("expression", "") - + if rule_type == "regex": field = rule.get("field", "hostname") value = getattr(host, field, "") if not re.match(expression, str(value)): rule_result["passed"] = False - rule_result["message"] = rule.get("error_message", f"Field {field} doesn't match pattern") - + rule_result["message"] = rule.get( + "error_message", f"Field {field} doesn't match pattern" + ) + elif rule_type == "range": field = rule.get("field") value = getattr(host, field, None) min_val = rule.get("min") max_val = rule.get("max") - + if value is not None: if min_val is not None and value < min_val: rule_result["passed"] = False - rule_result["message"] = rule.get("error_message", f"{field} below minimum") + rule_result["message"] = rule.get( + "error_message", f"{field} below minimum" + ) elif max_val is not None and value > max_val: rule_result["passed"] = False - rule_result["message"] = rule.get("error_message", f"{field} above maximum") - + rule_result["message"] = rule.get( + "error_message", f"{field} above maximum" + ) + elif rule_type == "custom": # For complex custom rules, we'd evaluate them here # For now, just log that we encountered a custom rule logger.info(f"Custom rule {rule.get('name')} for group {group.id}") - + except Exception as e: logger.error(f"Failed to evaluate rule {rule.get('name')}: {e}") rule_result["passed"] = False rule_result["message"] = "Rule evaluation failed" - + results.append(rule_result) - + return results - + def _detect_host_os_info(self, host: Host) -> None: """Detect and update host OS information""" if not host.operating_system: return - + os_string = host.operating_system.lower() - + # Detect OS family os_family_patterns = { OSFamily.RHEL: r"red\s*hat|rhel", @@ -501,88 +506,89 @@ def _detect_host_os_info(self, host: Host) -> None: OSFamily.MACOS: r"mac\s*os|darwin", OSFamily.FREEBSD: r"freebsd", OSFamily.OPENBSD: r"openbsd", - OSFamily.SOLARIS: r"solaris|sunos" + OSFamily.SOLARIS: r"solaris|sunos", } - + for family, pattern in os_family_patterns.items(): if re.search(pattern, os_string): host.os_family = family break - + # Detect OS version version_match = re.search(r"(\d+\.?\d*)", os_string) if version_match: host.os_version = version_match.group(1) - + # Detect architecture if present arch_patterns = { "x86_64": r"x86_64|x64|amd64", "i386": r"i[3-6]86|x86(?!_64)", "aarch64": r"aarch64|arm64", - "ppc64le": r"ppc64le|powerpc64le" + "ppc64le": r"ppc64le|powerpc64le", } - + for arch, pattern in arch_patterns.items(): if re.search(pattern, os_string): host.architecture = arch break - + # Update last OS detection time host.last_os_detection = datetime.utcnow() - + # Commit changes self.db.add(host) self.db.commit() - + def _generate_group_suggestions(self, host: Host) -> List[Dict[str, Any]]: """Generate group suggestions for an incompatible host""" suggestions = [] - + # Find groups with matching OS family - matching_groups = self.db.query(HostGroup).filter( - HostGroup.os_family == host.os_family - ).all() - + matching_groups = ( + self.db.query(HostGroup).filter(HostGroup.os_family == host.os_family).all() + ) + for group in matching_groups: # Calculate compatibility score compatibility = self._check_host_compatibility(host, group) - + if compatibility["is_compatible"]: - suggestions.append({ - "group_id": group.id, - "group_name": group.name, - "compatibility_score": compatibility["score"], - "compliance_framework": group.compliance_framework, - "reason": f"Compatible {host.os_family} group" - }) - + suggestions.append( + { + "group_id": group.id, + "group_name": group.name, + "compatibility_score": compatibility["score"], + "compliance_framework": group.compliance_framework, + "reason": f"Compatible {host.os_family} group", + } + ) + # Sort by compatibility score suggestions.sort(key=lambda x: x["compatibility_score"], reverse=True) - + # Return top 3 suggestions return suggestions[:3] - + def _cache_compatibility_results( - self, - host_ids: List[str], - group_id: int, - results: Dict[str, Any] + self, host_ids: List[str], group_id: int, results: Dict[str, Any] ) -> None: """Cache compatibility results for performance""" # This would be implemented with Redis or similar caching solution # For now, we'll just log that we would cache the results - logger.info(f"Would cache compatibility results for {len(host_ids)} hosts with group {group_id}") - + logger.info( + f"Would cache compatibility results for {len(host_ids)} hosts with group {group_id}" + ) + def create_smart_group_from_hosts( - self, - host_ids: List[str], + self, + host_ids: List[str], group_name: str, description: Optional[str] = None, - created_by: Optional[int] = None + created_by: Optional[int] = None, ) -> Dict[str, Any]: """ Create a smart group automatically based on host characteristics - + Analyzes the selected hosts and creates a group with appropriate validation rules and SCAP content assignments """ @@ -592,42 +598,42 @@ def create_smart_group_from_hosts( message="No hosts found", category=ErrorCategory.NOT_FOUND, severity=ErrorSeverity.ERROR, - context={"host_ids": host_ids} + context={"host_ids": host_ids}, ) - + # Detect common characteristics os_families = {} os_versions = {} architectures = {} - + for host in hosts: if not host.os_family: self._detect_host_os_info(host) - + if host.os_family: os_families[host.os_family] = os_families.get(host.os_family, 0) + 1 if host.os_version: os_versions[host.os_version] = os_versions.get(host.os_version, 0) + 1 if host.architecture: architectures[host.architecture] = architectures.get(host.architecture, 0) + 1 - + # Determine group characteristics result = { "hosts_analyzed": len(hosts), "characteristics": { "os_families": os_families, "os_versions": os_versions, - "architectures": architectures + "architectures": architectures, }, - "recommendations": {} + "recommendations": {}, } - + # Check if hosts are homogeneous if len(os_families) == 1: # Homogeneous OS family os_family = list(os_families.keys())[0] result["recommendations"]["os_family"] = os_family - + # Check version pattern if os_versions: versions = list(os_versions.keys()) @@ -638,45 +644,49 @@ def create_smart_group_from_hosts( common_prefix = self._find_common_version_pattern(versions) if common_prefix: result["recommendations"]["os_version_pattern"] = f"{common_prefix}*" - + # Recommend SCAP content - scap_content = self._find_matching_scap_content(os_family, result["recommendations"].get("os_version_pattern")) + scap_content = self._find_matching_scap_content( + os_family, result["recommendations"].get("os_version_pattern") + ) if scap_content: result["recommendations"]["scap_content"] = { "id": scap_content.id, "name": scap_content.name, - "compliance_framework": scap_content.compliance_framework + "compliance_framework": scap_content.compliance_framework, } else: # Mixed OS families result["warnings"] = [ f"Mixed OS families detected: {', '.join(os_families.keys())}", - "Consider creating separate groups for each OS family" + "Consider creating separate groups for each OS family", ] - + # Suggest splitting into multiple groups result["split_suggestions"] = [] for os_family in os_families.keys(): family_hosts = [h for h in hosts if h.os_family == os_family] - result["split_suggestions"].append({ - "os_family": os_family, - "host_count": len(family_hosts), - "suggested_name": f"{group_name} - {os_family}" - }) - + result["split_suggestions"].append( + { + "os_family": os_family, + "host_count": len(family_hosts), + "suggested_name": f"{group_name} - {os_family}", + } + ) + return result - + def _find_common_version_pattern(self, versions: List[str]) -> Optional[str]: """Find common version pattern from a list of versions""" if not versions: return None - + # Split versions into components version_parts = [] for version in versions: parts = version.split(".") version_parts.append(parts) - + # Find common prefix common_parts = [] for i in range(min(len(parts) for parts in version_parts)): @@ -685,32 +695,26 @@ def _find_common_version_pattern(self, versions: List[str]) -> Optional[str]: common_parts.append(part_values[0]) else: break - + return ".".join(common_parts) if common_parts else None - + def _find_matching_scap_content( - self, - os_family: str, - os_version_pattern: Optional[str] = None + self, os_family: str, os_version_pattern: Optional[str] = None ) -> Optional[ScapContent]: """Find SCAP content matching OS characteristics""" - query = self.db.query(ScapContent).filter( - ScapContent.os_family == os_family - ) - + query = self.db.query(ScapContent).filter(ScapContent.os_family == os_family) + if os_version_pattern: # Try to find exact version match first version = os_version_pattern.replace("*", "") - content = query.filter( - ScapContent.os_version.like(f"{version}%") - ).first() - + content = query.filter(ScapContent.os_version.like(f"{version}%")).first() + if content: return content - + # Return any content for the OS family return query.first() - + def get_group_compatibility_report(self, group_id: int) -> Dict[str, Any]: """Generate a comprehensive compatibility report for a group""" group = self.db.query(HostGroup).filter(HostGroup.id == group_id).first() @@ -719,14 +723,16 @@ def get_group_compatibility_report(self, group_id: int) -> Dict[str, Any]: message=f"Group {group_id} not found", category=ErrorCategory.NOT_FOUND, severity=ErrorSeverity.ERROR, - context={"group_id": group_id} + context={"group_id": group_id}, ) - + # Get all hosts in the group - memberships = self.db.query(HostGroupMembership).filter( - HostGroupMembership.group_id == group_id - ).all() - + memberships = ( + self.db.query(HostGroupMembership) + .filter(HostGroupMembership.group_id == group_id) + .all() + ) + report = { "group": { "id": group.id, @@ -734,27 +740,27 @@ def get_group_compatibility_report(self, group_id: int) -> Dict[str, Any]: "description": group.description, "os_family": group.os_family, "os_version_pattern": group.os_version_pattern, - "compliance_framework": group.compliance_framework + "compliance_framework": group.compliance_framework, }, "statistics": { "total_hosts": len(memberships), "fully_compatible": 0, "partially_compatible": 0, - "incompatible": 0 + "incompatible": 0, }, "hosts": [], "issues": [], - "recommendations": [] + "recommendations": [], } - + # Check each host for membership in memberships: host = self.db.query(Host).filter(Host.id == membership.host_id).first() if not host: continue - + compatibility = self._check_host_compatibility(host, group) - + host_report = { "id": str(host.id), "hostname": host.hostname, @@ -762,11 +768,11 @@ def get_group_compatibility_report(self, group_id: int) -> Dict[str, Any]: "compatibility_score": compatibility["score"], "is_compatible": compatibility["is_compatible"], "issues": compatibility.get("reasons", []), - "warnings": compatibility.get("warnings", []) + "warnings": compatibility.get("warnings", []), } - + report["hosts"].append(host_report) - + # Update statistics if compatibility["score"] >= 95: report["statistics"]["fully_compatible"] += 1 @@ -774,23 +780,27 @@ def get_group_compatibility_report(self, group_id: int) -> Dict[str, Any]: report["statistics"]["partially_compatible"] += 1 else: report["statistics"]["incompatible"] += 1 - + # Collect issues report["issues"].extend(compatibility.get("reasons", [])) - + # Generate recommendations if report["statistics"]["incompatible"] > 0: - report["recommendations"].append({ - "type": "warning", - "message": f"{report['statistics']['incompatible']} hosts are incompatible with this group", - "action": "Review group requirements or remove incompatible hosts" - }) - + report["recommendations"].append( + { + "type": "warning", + "message": f"{report['statistics']['incompatible']} hosts are incompatible with this group", + "action": "Review group requirements or remove incompatible hosts", + } + ) + if report["statistics"]["partially_compatible"] > report["statistics"]["fully_compatible"]: - report["recommendations"].append({ - "type": "info", - "message": "Most hosts are only partially compatible", - "action": "Consider relaxing group requirements or creating sub-groups" - }) - - return report \ No newline at end of file + report["recommendations"].append( + { + "type": "info", + "message": "Most hosts are only partially compatible", + "action": "Consider relaxing group requirements or creating sub-groups", + } + ) + + return report diff --git a/backend/app/services/host_monitor.py b/backend/app/services/host_monitor.py index cde355a7..58d21b7e 100644 --- a/backend/app/services/host_monitor.py +++ b/backend/app/services/host_monitor.py @@ -2,6 +2,7 @@ Host Monitoring Service Provides various methods to check host availability and status """ + import asyncio import base64 import logging @@ -18,37 +19,41 @@ logger = logging.getLogger(__name__) + class HostMonitor: def __init__(self): self.ssh_timeout = 10 # seconds - self.ping_timeout = 5 # seconds - + self.ping_timeout = 5 # seconds + async def ping_host(self, ip_address: str) -> bool: """ Simple ICMP ping to check basic connectivity with fallback to socket test """ try: # First try actual ping command - cmd = ['ping', '-c', '1', '-W', str(self.ping_timeout), ip_address] - result = subprocess.run(cmd, capture_output=True, text=True, timeout=self.ping_timeout + 2) + cmd = ["ping", "-c", "1", "-W", str(self.ping_timeout), ip_address] + result = subprocess.run( + cmd, capture_output=True, text=True, timeout=self.ping_timeout + 2 + ) if result.returncode == 0: return True - + except FileNotFoundError: logger.debug(f"Ping command not found, using socket fallback for {ip_address}") except (subprocess.TimeoutExpired, subprocess.SubprocessError) as e: logger.debug(f"Ping command failed for {ip_address}: {e}") - + # Fallback to socket connection test try: # Use socket connection test as ping alternative import socket + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.settimeout(self.ping_timeout) - + # Try to connect to common ports ports_to_try = [22, 80, 443, 21, 23, 25] - + for port in ports_to_try: try: result = sock.connect_ex((ip_address, port)) @@ -61,14 +66,14 @@ async def ping_host(self, ip_address: str) -> bool: sock.settimeout(self.ping_timeout) except: continue - + sock.close() return False - + except Exception as e: logger.debug(f"Socket connectivity test failed for {ip_address}: {e}") return False - + async def check_port_connectivity(self, ip_address: str, port: int) -> bool: """ Check if a specific port is reachable @@ -82,12 +87,16 @@ async def check_port_connectivity(self, ip_address: str, port: int) -> bool: except Exception as e: logger.debug(f"Port check failed for {ip_address}:{port}: {e}") return False - - async def check_ssh_connectivity(self, ip_address: str, port: int = 22, - username: Optional[str] = None, - key_path: Optional[str] = None, - private_key_content: Optional[str] = None, - password: Optional[str] = None) -> Tuple[bool, Optional[str]]: + + async def check_ssh_connectivity( + self, + ip_address: str, + port: int = 22, + username: Optional[str] = None, + key_path: Optional[str] = None, + private_key_content: Optional[str] = None, + password: Optional[str] = None, + ) -> Tuple[bool, Optional[str]]: """ Test SSH connectivity to determine if host is accessible for scanning Returns (is_connected, error_message) @@ -99,38 +108,44 @@ async def check_ssh_connectivity(self, ip_address: str, port: int = 22, # Load system and user host keys for validation try: ssh.load_system_host_keys() - ssh.load_host_keys(os.path.expanduser('~/.ssh/known_hosts')) + ssh.load_host_keys(os.path.expanduser("~/.ssh/known_hosts")) except FileNotFoundError: - logger.warning("No known_hosts files found - SSH connections may fail without proper host key management") - + logger.warning( + "No known_hosts files found - SSH connections may fail without proper host key management" + ) + connect_kwargs = { - 'hostname': ip_address, - 'port': port, - 'timeout': self.ssh_timeout, - 'banner_timeout': self.ssh_timeout + "hostname": ip_address, + "port": port, + "timeout": self.ssh_timeout, + "banner_timeout": self.ssh_timeout, } - + if username: - connect_kwargs['username'] = username - + connect_kwargs["username"] = username + if key_path: - connect_kwargs['key_filename'] = key_path + connect_kwargs["key_filename"] = key_path elif private_key_content: # Load private key from content string using new utility try: # Validate key first validation_result = validate_ssh_key(private_key_content) if not validation_result.is_valid: - logger.error(f"Invalid SSH key for {ip_address}: {validation_result.error_message}") + logger.error( + f"Invalid SSH key for {ip_address}: {validation_result.error_message}" + ) return False, f"Invalid SSH key: {validation_result.error_message}" - + # Log any warnings if validation_result.warnings: - logger.warning(f"SSH key warnings for {ip_address}: {'; '.join(validation_result.warnings)}") - + logger.warning( + f"SSH key warnings for {ip_address}: {'; '.join(validation_result.warnings)}" + ) + # Parse key using unified parser private_key = parse_ssh_key(private_key_content) - connect_kwargs['pkey'] = private_key + connect_kwargs["pkey"] = private_key except SSHKeyError as e: logger.error(f"SSH key parsing failed for {ip_address}: {e}") return False, f"SSH key error: {str(e)}" @@ -138,23 +153,25 @@ async def check_ssh_connectivity(self, ip_address: str, port: int = 22, logger.error(f"Failed to load private key for {ip_address}: {e}") return False, f"Invalid private key: {str(e)}" elif password: - connect_kwargs['password'] = password - + connect_kwargs["password"] = password + ssh.connect(**connect_kwargs) - + # Test basic command execution stdin, stdout, stderr = ssh.exec_command('echo "test"', timeout=5) exit_status = stdout.channel.recv_exit_status() - + ssh.close() - + if exit_status == 0: return True, None else: return False, "SSH command execution failed" - + except paramiko.AuthenticationException: - logger.warning(f"SSH authentication failed for {ip_address} - check credentials in Settings") + logger.warning( + f"SSH authentication failed for {ip_address} - check credentials in Settings" + ) return False, "Authentication failed - verify SSH credentials in Settings" except paramiko.SSHException as e: logger.error(f"SSH connection error to {ip_address}: {e}") @@ -165,7 +182,7 @@ async def check_ssh_connectivity(self, ip_address: str, port: int = 22, except Exception as e: logger.error(f"SSH connection failed to {ip_address}: {e}") return False, f"Connection error: {str(e)}" - + async def get_effective_ssh_credentials(self, host_data: Dict, db) -> Dict: """ Get effective SSH credentials for a host using centralized authentication service. @@ -174,104 +191,132 @@ async def get_effective_ssh_credentials(self, host_data: Dict, db) -> Dict: try: # Use centralized authentication service for all credential resolution from ..services.auth_service import get_auth_service + auth_service = get_auth_service(db) - + # Determine if we should use default credentials or host-specific - host_auth_method = host_data.get('auth_method') - use_default = host_auth_method in ['default', 'system_default'] - target_id = None if use_default else host_data.get('id') - - logger.info(f"Resolving credentials for host monitoring {host_data.get('hostname')}: use_default={use_default}, target_id={target_id}") - + host_auth_method = host_data.get("auth_method") + use_default = host_auth_method in ["default", "system_default"] + target_id = None if use_default else host_data.get("id") + + logger.info( + f"Resolving credentials for host monitoring {host_data.get('hostname')}: use_default={use_default}, target_id={target_id}" + ) + # First, try to get host-specific credentials from the hosts table if not use_default and target_id: from sqlalchemy import text - result = db.execute(text(""" + + result = db.execute( + text( + """ SELECT encrypted_credentials, username, auth_method FROM hosts WHERE id = :id AND encrypted_credentials IS NOT NULL - """), {"id": target_id}) - + """ + ), + {"id": target_id}, + ) + row = result.fetchone() if row and row.encrypted_credentials: - logger.info(f"Found host-specific credentials in hosts table for {host_data.get('hostname')}") + logger.info( + f"Found host-specific credentials in hosts table for {host_data.get('hostname')}" + ) # Decrypt the credentials from ..services.crypto import decrypt_credentials import json + try: # Handle memoryview objects from database encrypted_data = row.encrypted_credentials if isinstance(encrypted_data, memoryview): encrypted_data = bytes(encrypted_data) - + decrypted_data = decrypt_credentials(encrypted_data) cred_data = json.loads(decrypted_data) - + credentials = { - 'username': cred_data.get('username', row.username), - 'auth_method': cred_data.get('auth_method', row.auth_method), - 'password': cred_data.get('password'), - 'private_key': cred_data.get('ssh_key'), - 'private_key_passphrase': None, - 'source': 'host_encrypted_credentials' + "username": cred_data.get("username", row.username), + "auth_method": cred_data.get("auth_method", row.auth_method), + "password": cred_data.get("password"), + "private_key": cred_data.get("ssh_key"), + "private_key_passphrase": None, + "source": "host_encrypted_credentials", } - logger.info(f"✅ Decrypted host credentials for {host_data.get('hostname')}") + logger.info( + f"✅ Decrypted host credentials for {host_data.get('hostname')}" + ) return credentials except Exception as e: logger.error(f"Failed to decrypt host credentials: {e}") - + # Try centralized auth service (for system defaults or if host decryption failed) credential_data = auth_service.resolve_credential( - target_id=target_id, - use_default=use_default + target_id=target_id, use_default=use_default ) - + if not credential_data: logger.warning(f"No credentials available for host {host_data.get('hostname')}") - logger.info("Please configure system SSH credentials in Settings to enable remote host monitoring and scanning") + logger.info( + "Please configure system SSH credentials in Settings to enable remote host monitoring and scanning" + ) return None - + # Convert to format expected by host monitoring credentials = { - 'username': credential_data.username, - 'auth_method': credential_data.auth_method.value, - 'password': credential_data.password, - 'private_key': credential_data.private_key, # ✅ Consistent field naming - 'private_key_passphrase': credential_data.private_key_passphrase, - 'source': credential_data.source + "username": credential_data.username, + "auth_method": credential_data.auth_method.value, + "password": credential_data.password, + "private_key": credential_data.private_key, # ✅ Consistent field naming + "private_key_passphrase": credential_data.private_key_passphrase, + "source": credential_data.source, } - - logger.info(f"✅ Resolved {credential_data.source} credentials for host monitoring {host_data.get('hostname')}") + + logger.info( + f"✅ Resolved {credential_data.source} credentials for host monitoring {host_data.get('hostname')}" + ) return credentials - + except Exception as e: - logger.error(f"Failed to resolve credentials for host monitoring {host_data.get('hostname')}: {e}") + logger.error( + f"Failed to resolve credentials for host monitoring {host_data.get('hostname')}: {e}" + ) return None - + def validate_ssh_credentials(self, credentials: Dict) -> Tuple[bool, str]: """ Validate that SSH credentials are configured and not placeholder values Returns (is_valid, error_message) """ if not credentials: - return False, "No SSH credentials available. Please configure system credentials in Settings." - - username = credentials.get('username') - password = credentials.get('password') - private_key = credentials.get('private_key') - auth_method = credentials.get('auth_method', 'password') - + return ( + False, + "No SSH credentials available. Please configure system credentials in Settings.", + ) + + username = credentials.get("username") + password = credentials.get("password") + private_key = credentials.get("private_key") + auth_method = credentials.get("auth_method", "password") + if not username: return False, "SSH username is required. Please update credentials in Settings." - - if auth_method in ['password', 'both']: - if not password or password == 'CHANGE_ME_PLEASE': - return False, "SSH password is required or contains placeholder value. Please update credentials in Settings." - - if auth_method in ['ssh_key', 'both']: - if not private_key or 'CHANGE_ME_PLEASE' in private_key: - return False, "SSH private key is required or contains placeholder value. Please update credentials in Settings." - + + if auth_method in ["password", "both"]: + if not password or password == "CHANGE_ME_PLEASE": + return ( + False, + "SSH password is required or contains placeholder value. Please update credentials in Settings.", + ) + + if auth_method in ["ssh_key", "both"]: + if not private_key or "CHANGE_ME_PLEASE" in private_key: + return ( + False, + "SSH private key is required or contains placeholder value. Please update credentials in Settings.", + ) + return True, "" async def comprehensive_host_check(self, host_data: Dict, db=None) -> Dict: @@ -279,247 +324,306 @@ async def comprehensive_host_check(self, host_data: Dict, db=None) -> Dict: Perform comprehensive host availability check Returns status information """ - ip_address = host_data.get('ip_address') - hostname = host_data.get('hostname') - port = int(host_data.get('port', 22)) - username = host_data.get('username') - - logger.info(f"Starting comprehensive check for {hostname}, db connection: {'available' if db else 'None'}") - + ip_address = host_data.get("ip_address") + hostname = host_data.get("hostname") + port = int(host_data.get("port", 22)) + username = host_data.get("username") + + logger.info( + f"Starting comprehensive check for {hostname}, db connection: {'available' if db else 'None'}" + ) + check_results = { - 'host_id': host_data.get('id'), - 'hostname': hostname, - 'ip_address': ip_address, - 'timestamp': datetime.utcnow().isoformat(), - 'ping_success': False, - 'port_open': False, - 'ssh_accessible': False, - 'status': 'offline', - 'error_message': None, - 'response_time_ms': None, - 'ssh_credentials_source': None, - 'ssh_username': None, - 'credential_details': None + "host_id": host_data.get("id"), + "hostname": hostname, + "ip_address": ip_address, + "timestamp": datetime.utcnow().isoformat(), + "ping_success": False, + "port_open": False, + "ssh_accessible": False, + "status": "offline", + "error_message": None, + "response_time_ms": None, + "ssh_credentials_source": None, + "ssh_username": None, + "credential_details": None, } - + start_time = time.time() - + try: # Step 1: Connectivity test (ping alternative) logger.info(f"Checking connectivity for {hostname} ({ip_address})") - check_results['ping_success'] = await self.ping_host(ip_address) - + check_results["ping_success"] = await self.ping_host(ip_address) + # Step 2: Port connectivity logger.info(f"Checking port {port} connectivity for {hostname}") - check_results['port_open'] = await self.check_port_connectivity(ip_address, port) - + check_results["port_open"] = await self.check_port_connectivity(ip_address, port) + # Step 3: SSH connectivity (with credentials inheritance) ssh_credentials = None if db: - logger.info(f"Database connection available, looking up SSH credentials for {hostname}") + logger.info( + f"Database connection available, looking up SSH credentials for {hostname}" + ) ssh_credentials = await self.get_effective_ssh_credentials(host_data, db) else: - logger.warning(f"No database connection available for SSH credential lookup for {hostname}") - + logger.warning( + f"No database connection available for SSH credential lookup for {hostname}" + ) + if ssh_credentials: # Validate credentials before attempting connection is_valid, validation_error = self.validate_ssh_credentials(ssh_credentials) - - username = ssh_credentials['username'] - password = ssh_credentials.get('password') - private_key = ssh_credentials.get('private_key') - source = ssh_credentials.get('source', 'unknown') - auth_method = ssh_credentials.get('auth_method', 'unknown') - + + username = ssh_credentials["username"] + password = ssh_credentials.get("password") + private_key = ssh_credentials.get("private_key") + source = ssh_credentials.get("source", "unknown") + auth_method = ssh_credentials.get("auth_method", "unknown") + # Store credential details for response - check_results['ssh_credentials_source'] = source - check_results['ssh_username'] = username - + check_results["ssh_credentials_source"] = source + check_results["ssh_username"] = username + if not is_valid: - check_results['ssh_accessible'] = False - check_results['credential_details'] = f"❌ {validation_error}" - check_results['error_message'] = validation_error - logger.warning(f"SSH credentials validation failed for {hostname}: {validation_error}") + check_results["ssh_accessible"] = False + check_results["credential_details"] = f"❌ {validation_error}" + check_results["error_message"] = validation_error + logger.warning( + f"SSH credentials validation failed for {hostname}: {validation_error}" + ) else: - check_results['credential_details'] = f"Using {source} credentials (user: {username}, method: {auth_method})" - - logger.info(f"Checking SSH connectivity for {hostname} using {source} credentials (user: {username}, method: {auth_method})") - + check_results["credential_details"] = ( + f"Using {source} credentials (user: {username}, method: {auth_method})" + ) + + logger.info( + f"Checking SSH connectivity for {hostname} using {source} credentials (user: {username}, method: {auth_method})" + ) + # Try SSH connection with validated credentials ssh_success, ssh_error = await self.check_ssh_connectivity( ip_address, port, username, None, private_key, password ) - check_results['ssh_accessible'] = ssh_success - + check_results["ssh_accessible"] = ssh_success + if ssh_success: - check_results['credential_details'] += " - ✅ SSH authentication successful" - logger.info(f"SSH authentication successful for {hostname} using {source} credentials") + check_results["credential_details"] += " - ✅ SSH authentication successful" + logger.info( + f"SSH authentication successful for {hostname} using {source} credentials" + ) else: - check_results['credential_details'] += f" - ❌ SSH authentication failed: {ssh_error}" - check_results['error_message'] = f"SSH authentication failed with {source} credentials: {ssh_error}" - logger.warning(f"SSH authentication failed for {hostname} using {source} credentials: {ssh_error}") - + check_results[ + "credential_details" + ] += f" - ❌ SSH authentication failed: {ssh_error}" + check_results["error_message"] = ( + f"SSH authentication failed with {source} credentials: {ssh_error}" + ) + logger.warning( + f"SSH authentication failed for {hostname} using {source} credentials: {ssh_error}" + ) + else: - check_results['credential_details'] = "❌ No SSH credentials available (neither host-specific nor system default)" - check_results['error_message'] = "No SSH credentials configured. Please configure system credentials in Settings to enable SSH operations." - logger.warning(f"No SSH credentials available for {hostname} - configure in Settings") - logger.info(f"No SSH credentials available for {hostname} (neither host-specific nor system default)") - + check_results["credential_details"] = ( + "❌ No SSH credentials available (neither host-specific nor system default)" + ) + check_results["error_message"] = ( + "No SSH credentials configured. Please configure system credentials in Settings to enable SSH operations." + ) + logger.warning( + f"No SSH credentials available for {hostname} - configure in Settings" + ) + logger.info( + f"No SSH credentials available for {hostname} (neither host-specific nor system default)" + ) + # Determine overall status - if check_results['ssh_accessible']: - check_results['status'] = 'online' + if check_results["ssh_accessible"]: + check_results["status"] = "online" logger.info(f"Host {hostname} is ONLINE (SSH accessible)") - elif check_results['port_open']: - check_results['status'] = 'reachable' # Port open but can't SSH + elif check_results["port_open"]: + check_results["status"] = "reachable" # Port open but can't SSH logger.info(f"Host {hostname} is REACHABLE (port open, SSH issues)") - elif check_results['ping_success']: - check_results['status'] = 'ping_only' # Responds to connectivity test but port closed + elif check_results["ping_success"]: + check_results["status"] = ( + "ping_only" # Responds to connectivity test but port closed + ) logger.info(f"Host {hostname} responds to connectivity test but port {port} closed") else: - check_results['status'] = 'offline' - check_results['error_message'] = 'Host unreachable - no response on any tested ports' + check_results["status"] = "offline" + check_results["error_message"] = ( + "Host unreachable - no response on any tested ports" + ) logger.info(f"Host {hostname} is OFFLINE (unreachable)") - + # Calculate response time end_time = time.time() - check_results['response_time_ms'] = int((end_time - start_time) * 1000) - + check_results["response_time_ms"] = int((end_time - start_time) * 1000) + except Exception as e: logger.error(f"Error checking host {hostname}: {e}") - check_results['error_message'] = f"Monitoring error: {str(e)}" - check_results['status'] = 'error' - + check_results["error_message"] = f"Monitoring error: {str(e)}" + check_results["status"] = "error" + return check_results - - async def update_host_status(self, db: Session, host_id: str, status: str, - last_seen: Optional[datetime] = None, - error_message: Optional[str] = None) -> bool: + + async def update_host_status( + self, + db: Session, + host_id: str, + status: str, + last_seen: Optional[datetime] = None, + error_message: Optional[str] = None, + ) -> bool: """ Update host status in database with last check timestamp """ try: update_data = { - 'id': host_id, - 'status': status, - 'updated_at': datetime.utcnow(), - 'last_check': datetime.utcnow() + "id": host_id, + "status": status, + "updated_at": datetime.utcnow(), + "last_check": datetime.utcnow(), } - + query = """ UPDATE hosts SET status = :status, updated_at = :updated_at, last_check = :last_check WHERE id = :id """ - + db.execute(text(query), update_data) db.commit() - + logger.info(f"Updated host {host_id} status to {status} with last_check timestamp") return True - + except Exception as e: logger.error(f"Failed to update host status: {e}") db.rollback() return False - + async def monitor_all_hosts(self, db: Session) -> List[Dict]: """ Monitor all hosts in the database """ try: # Get all active hosts - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT id, hostname, ip_address, port, username, auth_method, status, last_check FROM hosts WHERE is_active = true ORDER BY hostname - """)) - + """ + ) + ) + hosts = [] for row in result: - hosts.append({ - 'id': str(row.id), - 'hostname': row.hostname, - 'ip_address': str(row.ip_address), - 'port': row.port or 22, - 'username': row.username, - 'auth_method': row.auth_method, - 'current_status': row.status, - 'last_check': row.last_check - }) - + hosts.append( + { + "id": str(row.id), + "hostname": row.hostname, + "ip_address": str(row.ip_address), + "port": row.port or 22, + "username": row.username, + "auth_method": row.auth_method, + "current_status": row.status, + "last_check": row.last_check, + } + ) + # Check each host check_results = [] for host in hosts: result = await self.comprehensive_host_check(host, db) check_results.append(result) - + # Update database if status changed - if result['status'] != host['current_status']: + if result["status"] != host["current_status"]: # Send alert before updating database - await self.send_status_change_alerts(db, host, host['current_status'], result['status']) - + await self.send_status_change_alerts( + db, host, host["current_status"], result["status"] + ) + await self.update_host_status( - db, host['id'], result['status'], - datetime.utcnow() if result['status'] == 'online' else None + db, + host["id"], + result["status"], + datetime.utcnow() if result["status"] == "online" else None, ) - + return check_results - + except Exception as e: logger.error(f"Error monitoring hosts: {e}") return [] - + async def get_alert_recipients(self, db: Session, alert_type: str) -> List[str]: """Get email recipients for a specific alert type""" try: - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT email_addresses FROM alert_settings WHERE alert_type = :alert_type AND enabled = true AND email_enabled = true AND email_addresses IS NOT NULL - """), {"alert_type": alert_type}) - + """ + ), + {"alert_type": alert_type}, + ) + recipients = [] for row in result: if row.email_addresses: recipients.extend(row.email_addresses) - + return list(set(recipients)) # Remove duplicates - + except Exception as e: logger.error(f"Error getting alert recipients: {e}") return [] - - async def send_status_change_alerts(self, db: Session, host: Dict, old_status: str, new_status: str): + + async def send_status_change_alerts( + self, db: Session, host: Dict, old_status: str, new_status: str + ): """Send email alerts when host status changes""" try: - hostname = host.get('hostname', 'Unknown') - ip_address = host.get('ip_address', 'Unknown') - last_check = host.get('last_check') or datetime.utcnow() - + hostname = host.get("hostname", "Unknown") + ip_address = host.get("ip_address", "Unknown") + last_check = host.get("last_check") or datetime.utcnow() + # Host went offline - if old_status == 'online' and new_status in ['offline', 'error']: - recipients = await self.get_alert_recipients(db, 'host_offline') + if old_status == "online" and new_status in ["offline", "error"]: + recipients = await self.get_alert_recipients(db, "host_offline") if recipients: - logger.info(f"Sending offline alert for {hostname} to {len(recipients)} recipients") + logger.info( + f"Sending offline alert for {hostname} to {len(recipients)} recipients" + ) await email_service.send_host_offline_alert( hostname, ip_address, last_check, recipients ) - + # Host came back online - elif old_status in ['offline', 'error'] and new_status == 'online': - recipients = await self.get_alert_recipients(db, 'host_online') + elif old_status in ["offline", "error"] and new_status == "online": + recipients = await self.get_alert_recipients(db, "host_online") if recipients: - logger.info(f"Sending online alert for {hostname} to {len(recipients)} recipients") + logger.info( + f"Sending online alert for {hostname} to {len(recipients)} recipients" + ) await email_service.send_host_online_alert( hostname, ip_address, last_check, recipients ) - + except Exception as e: logger.error(f"Error sending status change alerts: {e}") + # Global monitor instance -host_monitor = HostMonitor() \ No newline at end of file +host_monitor = HostMonitor() diff --git a/backend/app/services/http_client.py b/backend/app/services/http_client.py index 9213d6b2..4c6e7355 100644 --- a/backend/app/services/http_client.py +++ b/backend/app/services/http_client.py @@ -2,6 +2,7 @@ HTTP Client Infrastructure with Retry Logic Centralized HTTP client for OpenWatch with exponential backoff, circuit breaker, and monitoring """ + import asyncio import time from typing import Dict, Any, Optional, List, Union @@ -17,29 +18,33 @@ class CircuitBreakerState(str, Enum): """Circuit breaker states""" - CLOSED = "closed" # Normal operation - OPEN = "open" # Failing, reject requests + + CLOSED = "closed" # Normal operation + OPEN = "open" # Failing, reject requests HALF_OPEN = "half_open" # Testing if service recovered class RetryPolicy(BaseModel): """Retry policy configuration""" + max_retries: int = 3 base_delay: float = 1.0 # seconds max_delay: float = 60.0 # seconds exponential_base: float = 2.0 jitter: bool = True - - + + class CircuitBreakerConfig(BaseModel): """Circuit breaker configuration""" + failure_threshold: int = 5 # Failures before opening recovery_timeout: int = 60 # Seconds before trying half-open - success_threshold: int = 2 # Successes needed to close - + success_threshold: int = 2 # Successes needed to close + class HTTPClientStats(BaseModel): """HTTP client statistics""" + total_requests: int = 0 successful_requests: int = 0 failed_requests: int = 0 @@ -48,26 +53,28 @@ class HTTPClientStats(BaseModel): last_failure: Optional[datetime] = None consecutive_failures: int = 0 consecutive_successes: int = 0 - + class CircuitBreaker: """Circuit breaker implementation for HTTP requests""" - + def __init__(self, config: CircuitBreakerConfig): self.config = config self.state = CircuitBreakerState.CLOSED self.failure_count = 0 self.success_count = 0 self.last_failure_time = None - + def can_execute(self) -> bool: """Check if request can be executed""" if self.state == CircuitBreakerState.CLOSED: return True elif self.state == CircuitBreakerState.OPEN: # Check if enough time has passed to try half-open - if (self.last_failure_time and - time.time() - self.last_failure_time >= self.config.recovery_timeout): + if ( + self.last_failure_time + and time.time() - self.last_failure_time >= self.config.recovery_timeout + ): self.state = CircuitBreakerState.HALF_OPEN self.success_count = 0 logger.info("Circuit breaker transitioning to half-open") @@ -75,7 +82,7 @@ def can_execute(self) -> bool: return False else: # HALF_OPEN return True - + def record_success(self): """Record successful request""" if self.state == CircuitBreakerState.HALF_OPEN: @@ -86,14 +93,16 @@ def record_success(self): logger.info("Circuit breaker closed after successful recovery") elif self.state == CircuitBreakerState.CLOSED: self.failure_count = 0 - + def record_failure(self): """Record failed request""" self.failure_count += 1 self.last_failure_time = time.time() - - if (self.state == CircuitBreakerState.CLOSED and - self.failure_count >= self.config.failure_threshold): + + if ( + self.state == CircuitBreakerState.CLOSED + and self.failure_count >= self.config.failure_threshold + ): self.state = CircuitBreakerState.OPEN logger.warning(f"Circuit breaker opened after {self.failure_count} failures") elif self.state == CircuitBreakerState.HALF_OPEN: @@ -103,45 +112,44 @@ def record_failure(self): class HttpClient: """Enhanced HTTP client with retry logic, circuit breaker, and monitoring""" - + def __init__( self, retry_policy: Optional[RetryPolicy] = None, circuit_breaker_config: Optional[CircuitBreakerConfig] = None, timeout: float = 30.0, - user_agent: str = "OpenWatch-HttpClient/1.0" + user_agent: str = "OpenWatch-HttpClient/1.0", ): self.retry_policy = retry_policy or RetryPolicy() self.circuit_breaker = CircuitBreaker(circuit_breaker_config or CircuitBreakerConfig()) self.timeout = httpx.Timeout(timeout) self.user_agent = user_agent self.stats = HTTPClientStats() - + # Create httpx client self.client = httpx.AsyncClient( - timeout=self.timeout, - headers={"User-Agent": self.user_agent}, - follow_redirects=True + timeout=self.timeout, headers={"User-Agent": self.user_agent}, follow_redirects=True ) - + async def close(self): """Close the HTTP client""" await self.client.aclose() - + def _calculate_delay(self, attempt: int) -> float: """Calculate delay for exponential backoff""" delay = min( - self.retry_policy.base_delay * (self.retry_policy.exponential_base ** attempt), - self.retry_policy.max_delay + self.retry_policy.base_delay * (self.retry_policy.exponential_base**attempt), + self.retry_policy.max_delay, ) - + # Add jitter to prevent thundering herd if self.retry_policy.jitter: import random - delay *= (0.5 + random.random() * 0.5) - + + delay *= 0.5 + random.random() * 0.5 + return delay - + def _is_retryable_error(self, exception: Exception) -> bool: """Determine if an error is retryable""" if isinstance(exception, httpx.TimeoutException): @@ -152,26 +160,20 @@ def _is_retryable_error(self, exception: Exception) -> bool: # Retry on server errors (5xx) but not client errors (4xx) return 500 <= exception.response.status_code < 600 return False - - async def _execute_request( - self, - method: str, - url: str, - **kwargs - ) -> httpx.Response: + + async def _execute_request(self, method: str, url: str, **kwargs) -> httpx.Response: """Execute HTTP request with retry logic and circuit breaker""" - + # Check circuit breaker if not self.circuit_breaker.can_execute(): self.stats.failed_requests += 1 raise httpx.ConnectError( - "Circuit breaker is open - service unavailable", - request=httpx.Request(method, url) + "Circuit breaker is open - service unavailable", request=httpx.Request(method, url) ) - + self.stats.total_requests += 1 last_exception = None - + for attempt in range(self.retry_policy.max_retries + 1): try: # Log request @@ -180,42 +182,42 @@ async def _execute_request( method=method, url=url, attempt=attempt + 1, - max_attempts=self.retry_policy.max_retries + 1 + max_attempts=self.retry_policy.max_retries + 1, ) - + # Execute request response = await self.client.request(method, url, **kwargs) - + # Check for HTTP errors response.raise_for_status() - + # Success - update stats and circuit breaker self.stats.successful_requests += 1 self.circuit_breaker.record_success() - + logger.debug( "HTTP request successful", method=method, url=url, status_code=response.status_code, - attempt=attempt + 1 + attempt=attempt + 1, ) - + return response - + except Exception as e: last_exception = e self.stats.total_retries += 1 - + logger.warning( "HTTP request failed", method=method, url=url, attempt=attempt + 1, error=str(e), - error_type=type(e).__name__ + error_type=type(e).__name__, ) - + # Check if we should retry if attempt < self.retry_policy.max_retries and self._is_retryable_error(e): delay = self._calculate_delay(attempt) @@ -225,103 +227,87 @@ async def _execute_request( else: # No more retries or non-retryable error break - + # All retries exhausted - record failure self.stats.failed_requests += 1 self.circuit_breaker.record_failure() self.stats.last_failure = datetime.utcnow() self.stats.consecutive_failures += 1 self.stats.consecutive_successes = 0 - + logger.error( "HTTP request failed after all retries", method=method, url=url, total_attempts=self.retry_policy.max_retries + 1, - final_error=str(last_exception) + final_error=str(last_exception), ) - + raise last_exception - + async def get(self, url: str, **kwargs) -> httpx.Response: """Execute GET request""" return await self._execute_request("GET", url, **kwargs) - + async def post(self, url: str, **kwargs) -> httpx.Response: """Execute POST request""" return await self._execute_request("POST", url, **kwargs) - + async def put(self, url: str, **kwargs) -> httpx.Response: """Execute PUT request""" return await self._execute_request("PUT", url, **kwargs) - + async def delete(self, url: str, **kwargs) -> httpx.Response: """Execute DELETE request""" return await self._execute_request("DELETE", url, **kwargs) - + async def patch(self, url: str, **kwargs) -> httpx.Response: """Execute PATCH request""" return await self._execute_request("PATCH", url, **kwargs) - + def get_stats(self) -> Dict[str, Any]: """Get client statistics""" self.stats.circuit_breaker_state = self.circuit_breaker.state return self.stats.dict() - + def reset_stats(self): """Reset client statistics""" self.stats = HTTPClientStats() - + class WebhookHttpClient(HttpClient): """Specialized HTTP client for webhook delivery""" - + def __init__(self): # Webhook-specific configuration retry_policy = RetryPolicy( - max_retries=3, - base_delay=1.0, - max_delay=30.0, - exponential_base=2.0, - jitter=True + max_retries=3, base_delay=1.0, max_delay=30.0, exponential_base=2.0, jitter=True ) - + circuit_breaker_config = CircuitBreakerConfig( - failure_threshold=5, - recovery_timeout=300, # 5 minutes for webhooks - success_threshold=2 + failure_threshold=5, recovery_timeout=300, success_threshold=2 # 5 minutes for webhooks ) - + super().__init__( retry_policy=retry_policy, circuit_breaker_config=circuit_breaker_config, timeout=30.0, - user_agent="OpenWatch-Webhook/1.0" + user_agent="OpenWatch-Webhook/1.0", ) - + async def deliver_webhook( - self, - url: str, - payload: Dict[str, Any], - headers: Dict[str, str] + self, url: str, payload: Dict[str, Any], headers: Dict[str, str] ) -> httpx.Response: """Deliver webhook with specialized handling""" import json - + # Prepare webhook headers - webhook_headers = { - "Content-Type": "application/json", - **headers - } - + webhook_headers = {"Content-Type": "application/json", **headers} + # Convert payload to JSON - json_payload = json.dumps(payload, separators=(',', ':')) - - return await self.post( - url, - headers=webhook_headers, - content=json_payload - ) + json_payload = json.dumps(payload, separators=(",", ":")) + + return await self.post(url, headers=webhook_headers, content=json_payload) # Global client instances @@ -348,11 +334,11 @@ async def get_webhook_client() -> WebhookHttpClient: async def close_all_clients(): """Close all HTTP client instances""" global _default_client, _webhook_client - + if _default_client: await _default_client.close() _default_client = None - + if _webhook_client: await _webhook_client.close() - _webhook_client = None \ No newline at end of file + _webhook_client = None diff --git a/backend/app/services/integration_metrics.py b/backend/app/services/integration_metrics.py index 73ce05aa..23e9b1c7 100644 --- a/backend/app/services/integration_metrics.py +++ b/backend/app/services/integration_metrics.py @@ -2,6 +2,7 @@ Integration Performance Metrics Service Collects and tracks performance metrics for OpenWatch-AEGIS integration """ + import time import logging from datetime import datetime, timedelta @@ -14,9 +15,11 @@ logger = logging.getLogger(__name__) + @dataclass class IntegrationMetric: """Individual metric data point""" + timestamp: float metric_type: str operation: str @@ -25,9 +28,11 @@ class IntegrationMetric: success: bool = True error: Optional[str] = None + @dataclass class MetricsSummary: """Aggregated metrics summary""" + total_requests: int successful_requests: int failed_requests: int @@ -37,22 +42,29 @@ class MetricsSummary: p95_duration: float error_rate: float + class IntegrationMetricsCollector: """Collects and manages integration performance metrics""" - + def __init__(self, retention_hours: int = 24, max_metrics: int = 10000): self.retention_hours = retention_hours self.max_metrics = max_metrics self.metrics: deque = deque(maxlen=max_metrics) self.lock = threading.RLock() - + # In-memory counters for quick access self.counters = defaultdict(int) self.timers = defaultdict(list) - - def record_metric(self, metric_type: str, operation: str, value: float, - success: bool = True, error: Optional[str] = None, - labels: Optional[Dict[str, str]] = None): + + def record_metric( + self, + metric_type: str, + operation: str, + value: float, + success: bool = True, + error: Optional[str] = None, + labels: Optional[Dict[str, str]] = None, + ): """Record a new metric""" with self.lock: metric = IntegrationMetric( @@ -62,34 +74,34 @@ def record_metric(self, metric_type: str, operation: str, value: float, value=value, labels=labels or {}, success=success, - error=error + error=error, ) - + self.metrics.append(metric) - + # Update counters counter_key = f"{metric_type}.{operation}" self.counters[f"{counter_key}.total"] += 1 - + if success: self.counters[f"{counter_key}.success"] += 1 else: self.counters[f"{counter_key}.error"] += 1 - + # Store timing data for percentile calculations if metric_type == "duration": self.timers[counter_key].append(value) # Keep only last 1000 values for percentile calculations if len(self.timers[counter_key]) > 1000: self.timers[counter_key] = self.timers[counter_key][-1000:] - + @contextmanager def time_operation(self, operation: str, labels: Optional[Dict[str, str]] = None): """Context manager to time operations""" start_time = time.time() error = None success = True - + try: yield except Exception as e: @@ -104,49 +116,50 @@ def time_operation(self, operation: str, labels: Optional[Dict[str, str]] = None value=duration, success=success, error=error, - labels=labels + labels=labels, ) - + def cleanup_old_metrics(self): """Remove metrics older than retention period""" with self.lock: cutoff_time = time.time() - (self.retention_hours * 3600) - + # Remove old metrics while self.metrics and self.metrics[0].timestamp < cutoff_time: self.metrics.popleft() - - def get_metrics_summary(self, operation: Optional[str] = None, - hours: int = 1) -> Dict[str, MetricsSummary]: + + def get_metrics_summary( + self, operation: Optional[str] = None, hours: int = 1 + ) -> Dict[str, MetricsSummary]: """Get aggregated metrics summary""" with self.lock: self.cleanup_old_metrics() - + cutoff_time = time.time() - (hours * 3600) relevant_metrics = [m for m in self.metrics if m.timestamp >= cutoff_time] - + if operation: relevant_metrics = [m for m in relevant_metrics if m.operation == operation] - + # Group by operation grouped_metrics = defaultdict(list) for metric in relevant_metrics: if metric.metric_type == "duration": grouped_metrics[metric.operation].append(metric) - + summaries = {} for op, metrics_list in grouped_metrics.items(): if not metrics_list: continue - + durations = [m.value for m in metrics_list] successful = [m for m in metrics_list if m.success] failed = [m for m in metrics_list if not m.success] - + if durations: durations.sort() p95_index = int(len(durations) * 0.95) - + summaries[op] = MetricsSummary( total_requests=len(metrics_list), successful_requests=len(successful), @@ -154,52 +167,52 @@ def get_metrics_summary(self, operation: Optional[str] = None, average_duration=sum(durations) / len(durations), min_duration=min(durations), max_duration=max(durations), - p95_duration=durations[p95_index] if p95_index < len(durations) else max(durations), - error_rate=(len(failed) / len(metrics_list)) * 100 + p95_duration=( + durations[p95_index] if p95_index < len(durations) else max(durations) + ), + error_rate=(len(failed) / len(metrics_list)) * 100, ) - + return summaries - + def get_current_stats(self) -> Dict[str, Any]: """Get current performance statistics""" with self.lock: self.cleanup_old_metrics() - + stats = { "total_metrics": len(self.metrics), "counters": dict(self.counters), "recent_errors": [], - "top_operations": {} + "top_operations": {}, } - + # Recent errors (last hour) cutoff_time = time.time() - 3600 recent_errors = [ { "timestamp": datetime.fromtimestamp(m.timestamp).isoformat(), "operation": m.operation, - "error": m.error + "error": m.error, } for m in self.metrics if m.timestamp >= cutoff_time and not m.success and m.error ] - + stats["recent_errors"] = recent_errors[-10:] # Last 10 errors - + # Top operations by volume operation_counts = defaultdict(int) for metric in self.metrics: if metric.timestamp >= cutoff_time: operation_counts[metric.operation] += 1 - - stats["top_operations"] = dict(sorted( - operation_counts.items(), - key=lambda x: x[1], - reverse=True - )[:10]) - + + stats["top_operations"] = dict( + sorted(operation_counts.items(), key=lambda x: x[1], reverse=True)[:10] + ) + return stats - + def export_metrics(self, format: str = "json") -> str: """Export metrics in specified format""" if format == "json": @@ -208,25 +221,31 @@ def export_metrics(self, format: str = "json") -> str: # Basic Prometheus format lines = [] summaries = self.get_metrics_summary() - + for operation, summary in summaries.items(): safe_op = operation.replace("-", "_").replace(".", "_") lines.append(f"# HELP integration_{safe_op}_duration_seconds Operation duration") lines.append(f"# TYPE integration_{safe_op}_duration_seconds histogram") - lines.append(f"integration_{safe_op}_duration_seconds_sum {summary.average_duration * summary.total_requests}") - lines.append(f"integration_{safe_op}_duration_seconds_count {summary.total_requests}") - + lines.append( + f"integration_{safe_op}_duration_seconds_sum {summary.average_duration * summary.total_requests}" + ) + lines.append( + f"integration_{safe_op}_duration_seconds_count {summary.total_requests}" + ) + lines.append(f"# HELP integration_{safe_op}_error_rate Error rate percentage") lines.append(f"# TYPE integration_{safe_op}_error_rate gauge") lines.append(f"integration_{safe_op}_error_rate {summary.error_rate}") - + return "\n".join(lines) - + raise ValueError(f"Unsupported export format: {format}") + # Global metrics collector instance metrics_collector = IntegrationMetricsCollector() + # Convenience functions for common operations def record_webhook_delivery(success: bool, duration: float, target_service: str, error: str = None): """Record webhook delivery metrics""" @@ -236,11 +255,13 @@ def record_webhook_delivery(success: bool, duration: float, target_service: str, value=duration, success=success, error=error, - labels={"target_service": target_service} + labels={"target_service": target_service}, ) -def record_api_call(operation: str, success: bool, duration: float, - service: str, error: str = None): + +def record_api_call( + operation: str, success: bool, duration: float, service: str, error: str = None +): """Record API call metrics""" metrics_collector.record_metric( metric_type="duration", @@ -248,11 +269,13 @@ def record_api_call(operation: str, success: bool, duration: float, value=duration, success=success, error=error, - labels={"service": service} + labels={"service": service}, ) -def record_remediation_job(job_id: str, status: str, duration: float, - rules_count: int, success_count: int): + +def record_remediation_job( + job_id: str, status: str, duration: float, rules_count: int, success_count: int +): """Record remediation job metrics""" metrics_collector.record_metric( metric_type="duration", @@ -263,11 +286,18 @@ def record_remediation_job(job_id: str, status: str, duration: float, "job_id": job_id, "status": status, "rules_count": str(rules_count), - "success_count": str(success_count) - } + "success_count": str(success_count), + }, ) + # Context managers for easy timing -time_webhook_delivery = lambda target: metrics_collector.time_operation("webhook_delivery", {"target": target}) -time_api_call = lambda operation, service: metrics_collector.time_operation(f"api_call_{operation}", {"service": service}) -time_remediation = lambda job_id: metrics_collector.time_operation("remediation_execution", {"job_id": job_id}) \ No newline at end of file +time_webhook_delivery = lambda target: metrics_collector.time_operation( + "webhook_delivery", {"target": target} +) +time_api_call = lambda operation, service: metrics_collector.time_operation( + f"api_call_{operation}", {"service": service} +) +time_remediation = lambda job_id: metrics_collector.time_operation( + "remediation_execution", {"job_id": job_id} +) diff --git a/backend/app/services/key_lifecycle_manager.py b/backend/app/services/key_lifecycle_manager.py index a59beadf..3eb6e7a0 100644 --- a/backend/app/services/key_lifecycle_manager.py +++ b/backend/app/services/key_lifecycle_manager.py @@ -23,6 +23,7 @@ class KeyStatus(Enum): """RSA key lifecycle status""" + ACTIVE = "active" PENDING = "pending" DEPRECATED = "deprecated" @@ -32,6 +33,7 @@ class KeyStatus(Enum): @dataclass class RSAKeyMetadata: """RSA key metadata for JWT signing""" + key_id: str key_size: int status: KeyStatus @@ -42,23 +44,23 @@ class RSAKeyMetadata: usage_count: int = 0 last_used: Optional[datetime] = None fingerprint: str = "" - + def to_dict(self) -> Dict: """Convert to dictionary for storage""" data = asdict(self) - data['status'] = self.status.value + data["status"] = self.status.value # Convert datetime objects to ISO strings - for field in ['created_at', 'activated_at', 'deprecated_at', 'expires_at', 'last_used']: + for field in ["created_at", "activated_at", "deprecated_at", "expires_at", "last_used"]: if data[field]: data[field] = data[field].isoformat() return data - + @classmethod - def from_dict(cls, data: Dict) -> 'RSAKeyMetadata': + def from_dict(cls, data: Dict) -> "RSAKeyMetadata": """Create from dictionary""" - data['status'] = KeyStatus(data['status']) + data["status"] = KeyStatus(data["status"]) # Convert ISO strings back to datetime objects - for field in ['created_at', 'activated_at', 'deprecated_at', 'expires_at', 'last_used']: + for field in ["created_at", "activated_at", "deprecated_at", "expires_at", "last_used"]: if data[field]: data[field] = datetime.fromisoformat(data[field]) return cls(**data) @@ -66,76 +68,80 @@ def from_dict(cls, data: Dict) -> 'RSAKeyMetadata': class RSAKeyLifecycleManager: """FIPS-compliant RSA key lifecycle management""" - + def __init__(self): self.key_storage_path = Path("/app/security/keys") self.key_storage_path.mkdir(parents=True, exist_ok=True, mode=0o700) self.key_size = 2048 # FIPS minimum self.key_lifetime_days = 365 # 1 year default self.rotation_overlap_days = 7 # 7 days overlap for smooth transition - + def generate_key_id(self) -> str: """Generate unique key identifier""" return f"jwt_key_{datetime.utcnow().strftime('%Y%m%d_%H%M%S')}_{secrets.token_hex(8)}" - + def calculate_fingerprint(self, public_key) -> str: """Calculate SHA-256 fingerprint of RSA public key""" public_der = public_key.public_bytes( encoding=serialization.Encoding.DER, - format=serialization.PublicFormat.SubjectPublicKeyInfo + format=serialization.PublicFormat.SubjectPublicKeyInfo, ) digest = hashes.Hash(hashes.SHA256(), backend=default_backend()) digest.update(public_der) hash_bytes = digest.finalize() return hash_bytes.hex()[:32] - - def generate_rsa_key_pair(self, key_size: int = None) -> Tuple[rsa.RSAPrivateKey, rsa.RSAPublicKey]: + + def generate_rsa_key_pair( + self, key_size: int = None + ) -> Tuple[rsa.RSAPrivateKey, rsa.RSAPublicKey]: """ Generate FIPS-compliant RSA key pair - + Args: key_size: RSA key size (default: 2048, recommended: 4096) - + Returns: Tuple of (private_key, public_key) """ if key_size is None: key_size = self.key_size - + if key_size < 2048: raise ValueError("RSA key size must be at least 2048 bits for FIPS compliance") - + try: # Generate RSA private key with FIPS-approved parameters private_key = rsa.generate_private_key( public_exponent=65537, # Standard public exponent key_size=key_size, - backend=default_backend() + backend=default_backend(), ) - + public_key = private_key.public_key() - + logger.info(f"Generated RSA-{key_size} key pair successfully") return private_key, public_key - + except Exception as e: logger.error(f"Failed to generate RSA key pair: {e}") raise - - def store_key_pair(self, - private_key: rsa.RSAPrivateKey, - public_key: rsa.RSAPublicKey, - key_id: str, - metadata: RSAKeyMetadata) -> Tuple[Path, Path]: + + def store_key_pair( + self, + private_key: rsa.RSAPrivateKey, + public_key: rsa.RSAPublicKey, + key_id: str, + metadata: RSAKeyMetadata, + ) -> Tuple[Path, Path]: """ Store RSA key pair securely with metadata - + Args: private_key: RSA private key public_key: RSA public key key_id: Unique key identifier metadata: Key metadata - + Returns: Tuple of (private_key_path, public_key_path) """ @@ -143,111 +149,113 @@ def store_key_pair(self, # Create key-specific directory key_dir = self.key_storage_path / key_id key_dir.mkdir(mode=0o700, exist_ok=True) - + # Store private key private_key_path = key_dir / "jwt_private.pem" with open(private_key_path, "wb") as f: - f.write(private_key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption() - )) + f.write( + private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + ) private_key_path.chmod(0o600) - + # Store public key public_key_path = key_dir / "jwt_public.pem" with open(public_key_path, "wb") as f: - f.write(public_key.public_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo - )) + f.write( + public_key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + ) public_key_path.chmod(0o644) - + # Store metadata metadata_path = key_dir / "metadata.json" import json - with open(metadata_path, 'w') as f: + + with open(metadata_path, "w") as f: json.dump(metadata.to_dict(), f, indent=2) metadata_path.chmod(0o600) - + logger.info(f"Stored key pair {key_id} securely") return private_key_path, public_key_path - + except Exception as e: logger.error(f"Failed to store key pair {key_id}: {e}") raise - + def load_key_metadata(self, key_id: str) -> Optional[RSAKeyMetadata]: """Load key metadata by ID""" try: metadata_path = self.key_storage_path / key_id / "metadata.json" if not metadata_path.exists(): return None - + import json - with open(metadata_path, 'r') as f: + + with open(metadata_path, "r") as f: metadata_dict = json.load(f) - + return RSAKeyMetadata.from_dict(metadata_dict) - + except Exception as e: logger.error(f"Failed to load metadata for key {key_id}: {e}") return None - + def update_key_metadata(self, key_id: str, metadata: RSAKeyMetadata): """Update key metadata""" try: metadata_path = self.key_storage_path / key_id / "metadata.json" if not metadata_path.parent.exists(): raise ValueError(f"Key directory not found for {key_id}") - + import json - with open(metadata_path, 'w') as f: + + with open(metadata_path, "w") as f: json.dump(metadata.to_dict(), f, indent=2) - + except Exception as e: logger.error(f"Failed to update metadata for key {key_id}: {e}") raise - + def load_private_key(self, key_id: str) -> Optional[rsa.RSAPrivateKey]: """Load RSA private key by ID""" try: private_key_path = self.key_storage_path / key_id / "jwt_private.pem" if not private_key_path.exists(): return None - + with open(private_key_path, "rb") as f: private_key = serialization.load_pem_private_key( - f.read(), - password=None, - backend=default_backend() + f.read(), password=None, backend=default_backend() ) - + return private_key - + except Exception as e: logger.error(f"Failed to load private key {key_id}: {e}") return None - + def load_public_key(self, key_id: str) -> Optional[rsa.RSAPublicKey]: """Load RSA public key by ID""" try: public_key_path = self.key_storage_path / key_id / "jwt_public.pem" if not public_key_path.exists(): return None - + with open(public_key_path, "rb") as f: - public_key = serialization.load_pem_public_key( - f.read(), - backend=default_backend() - ) - + public_key = serialization.load_pem_public_key(f.read(), backend=default_backend()) + return public_key - + except Exception as e: logger.error(f"Failed to load public key {key_id}: {e}") return None - + def get_active_key_id(self) -> Optional[str]: """Get the current active key ID""" try: @@ -257,51 +265,51 @@ def get_active_key_id(self) -> Optional[str]: if metadata and metadata.status == KeyStatus.ACTIVE: return key_dir.name return None - + except Exception as e: logger.error(f"Failed to get active key: {e}") return None - + def create_new_key(self, key_size: int = None) -> str: """ Create new RSA key pair and store it - + Args: key_size: RSA key size (default from settings) - + Returns: New key ID """ try: # Generate key pair private_key, public_key = self.generate_rsa_key_pair(key_size) - + # Create metadata key_id = self.generate_key_id() fingerprint = self.calculate_fingerprint(public_key) - + metadata = RSAKeyMetadata( key_id=key_id, key_size=key_size or self.key_size, status=KeyStatus.PENDING, created_at=datetime.utcnow(), - fingerprint=fingerprint + fingerprint=fingerprint, ) - + # Store key pair self.store_key_pair(private_key, public_key, key_id, metadata) - + logger.info(f"Created new RSA key: {key_id}") return key_id - + except Exception as e: logger.error(f"Failed to create new key: {e}") raise - + def activate_key(self, key_id: str): """ Activate a pending key and deprecate the current active key - + Args: key_id: ID of key to activate """ @@ -310,163 +318,168 @@ def activate_key(self, key_id: str): metadata = self.load_key_metadata(key_id) if not metadata: raise ValueError(f"Key {key_id} not found") - + if metadata.status != KeyStatus.PENDING: raise ValueError(f"Key {key_id} is not in pending status") - + # Deprecate current active key current_active = self.get_active_key_id() if current_active: self.deprecate_key(current_active) - + # Activate new key metadata.status = KeyStatus.ACTIVE metadata.activated_at = datetime.utcnow() metadata.expires_at = datetime.utcnow() + timedelta(days=self.key_lifetime_days) - + self.update_key_metadata(key_id, metadata) - + # Update symlinks for current active key self.update_current_key_symlinks(key_id) - + logger.info(f"Activated RSA key: {key_id}") - + except Exception as e: logger.error(f"Failed to activate key {key_id}: {e}") raise - + def deprecate_key(self, key_id: str): """Mark key as deprecated""" try: metadata = self.load_key_metadata(key_id) if not metadata: raise ValueError(f"Key {key_id} not found") - + metadata.status = KeyStatus.DEPRECATED metadata.deprecated_at = datetime.utcnow() - + self.update_key_metadata(key_id, metadata) - + logger.info(f"Deprecated RSA key: {key_id}") - + except Exception as e: logger.error(f"Failed to deprecate key {key_id}: {e}") raise - + def revoke_key(self, key_id: str, reason: str = None): """Revoke a key immediately""" try: metadata = self.load_key_metadata(key_id) if not metadata: raise ValueError(f"Key {key_id} not found") - + metadata.status = KeyStatus.REVOKED - + self.update_key_metadata(key_id, metadata) - + logger.warning(f"Revoked RSA key: {key_id}. Reason: {reason or 'Not specified'}") - + except Exception as e: logger.error(f"Failed to revoke key {key_id}: {e}") raise - + def update_current_key_symlinks(self, key_id: str): """Update symlinks to point to current active key""" try: # Create symlinks for current active key current_private_link = self.key_storage_path / "jwt_private.pem" current_public_link = self.key_storage_path / "jwt_public.pem" - + # Remove existing symlinks if current_private_link.is_symlink(): current_private_link.unlink() if current_public_link.is_symlink(): current_public_link.unlink() - + # Create new symlinks private_target = self.key_storage_path / key_id / "jwt_private.pem" public_target = self.key_storage_path / key_id / "jwt_public.pem" - + current_private_link.symlink_to(private_target) current_public_link.symlink_to(public_target) - + logger.info(f"Updated symlinks to point to key {key_id}") - + except Exception as e: logger.error(f"Failed to update symlinks for key {key_id}: {e}") raise - + def rotate_keys(self, new_key_size: int = None) -> str: """ Perform key rotation with overlap period - + Args: new_key_size: Size for new key (default: current key size) - + Returns: New key ID """ try: # Create new key new_key_id = self.create_new_key(new_key_size) - + # Activate immediately (deprecates old key) self.activate_key(new_key_id) - + logger.info(f"Key rotation completed: New active key {new_key_id}") return new_key_id - + except Exception as e: logger.error(f"Failed to rotate keys: {e}") raise - + def get_keys_needing_rotation(self) -> List[str]: """Get list of keys that need rotation based on expiration""" keys_needing_rotation = [] now = datetime.utcnow() warning_threshold = now + timedelta(days=self.rotation_overlap_days) - + try: for key_dir in self.key_storage_path.iterdir(): if key_dir.is_dir(): metadata = self.load_key_metadata(key_dir.name) - if (metadata and - metadata.status == KeyStatus.ACTIVE and - metadata.expires_at and - metadata.expires_at <= warning_threshold): + if ( + metadata + and metadata.status == KeyStatus.ACTIVE + and metadata.expires_at + and metadata.expires_at <= warning_threshold + ): keys_needing_rotation.append(key_dir.name) - + return keys_needing_rotation - + except Exception as e: logger.error(f"Failed to check keys for rotation: {e}") return [] - + def cleanup_old_keys(self, retention_days: int = 90): """Clean up deprecated/revoked keys older than retention period""" cutoff_date = datetime.utcnow() - timedelta(days=retention_days) cleaned_count = 0 - + try: for key_dir in self.key_storage_path.iterdir(): if key_dir.is_dir(): metadata = self.load_key_metadata(key_dir.name) - if (metadata and - metadata.status in [KeyStatus.DEPRECATED, KeyStatus.REVOKED] and - metadata.created_at < cutoff_date): - + if ( + metadata + and metadata.status in [KeyStatus.DEPRECATED, KeyStatus.REVOKED] + and metadata.created_at < cutoff_date + ): + # Remove key directory import shutil + shutil.rmtree(key_dir) cleaned_count += 1 logger.info(f"Cleaned up old key: {key_dir.name}") - + if cleaned_count > 0: logger.info(f"Cleaned up {cleaned_count} old keys") - + except Exception as e: logger.error(f"Failed during key cleanup: {e}") - + def record_key_usage(self, key_id: str): """Record key usage for analytics""" try: @@ -475,10 +488,10 @@ def record_key_usage(self, key_id: str): metadata.usage_count += 1 metadata.last_used = datetime.utcnow() self.update_key_metadata(key_id, metadata) - + except Exception as e: logger.debug(f"Failed to record key usage for {key_id}: {e}") - + def get_key_statistics(self) -> Dict: """Get key lifecycle statistics""" stats = { @@ -489,21 +502,21 @@ def get_key_statistics(self) -> Dict: "revoked_keys": 0, "keys_needing_rotation": 0, "oldest_active_key_age_days": 0, - "average_key_usage": 0 + "average_key_usage": 0, } - + try: usage_counts = [] now = datetime.utcnow() oldest_active_age = 0 - + for key_dir in self.key_storage_path.iterdir(): if key_dir.is_dir(): metadata = self.load_key_metadata(key_dir.name) if metadata: stats["total_keys"] += 1 usage_counts.append(metadata.usage_count) - + if metadata.status == KeyStatus.ACTIVE: stats["active_keys"] += 1 age_days = (now - metadata.created_at).days @@ -514,15 +527,15 @@ def get_key_statistics(self) -> Dict: stats["deprecated_keys"] += 1 elif metadata.status == KeyStatus.REVOKED: stats["revoked_keys"] += 1 - + stats["keys_needing_rotation"] = len(self.get_keys_needing_rotation()) stats["oldest_active_key_age_days"] = oldest_active_age - + if usage_counts: stats["average_key_usage"] = sum(usage_counts) / len(usage_counts) - + return stats - + except Exception as e: logger.error(f"Failed to get key statistics: {e}") return stats @@ -531,6 +544,7 @@ def get_key_statistics(self) -> Dict: # Global key lifecycle manager instance _key_lifecycle_manager = None + def get_key_lifecycle_manager() -> RSAKeyLifecycleManager: """Get global key lifecycle manager instance""" global _key_lifecycle_manager @@ -544,10 +558,12 @@ def rotate_jwt_keys() -> str: """Rotate JWT signing keys""" return get_key_lifecycle_manager().rotate_keys() + def get_active_jwt_key_id() -> Optional[str]: """Get current active JWT key ID""" return get_key_lifecycle_manager().get_active_key_id() + def check_keys_for_rotation() -> List[str]: """Check for keys needing rotation""" - return get_key_lifecycle_manager().get_keys_needing_rotation() \ No newline at end of file + return get_key_lifecycle_manager().get_keys_needing_rotation() diff --git a/backend/app/services/mfa_service.py b/backend/app/services/mfa_service.py index a8d56a77..aa9b1c5a 100644 --- a/backend/app/services/mfa_service.py +++ b/backend/app/services/mfa_service.py @@ -25,6 +25,7 @@ class MFAMethod(Enum): """Supported MFA methods""" + TOTP = "totp" BACKUP_CODES = "backup_codes" FIDO2 = "fido2" # Framework ready for future implementation @@ -33,6 +34,7 @@ class MFAMethod(Enum): @dataclass class MFAEnrollmentResult: """Result of MFA enrollment process""" + success: bool secret_key: Optional[str] = None qr_code_data: Optional[str] = None @@ -43,6 +45,7 @@ class MFAEnrollmentResult: @dataclass class MFAValidationResult: """Result of MFA validation""" + valid: bool method_used: Optional[MFAMethod] = None backup_code_used: Optional[str] = None @@ -51,43 +54,44 @@ class MFAValidationResult: class MFAService: """FIPS-compliant Multi-Factor Authentication service""" - + def __init__(self): self.issuer_name = "OpenWatch Security Platform" self.backup_code_length = 8 self.backup_code_count = 10 self.totp_window = 1 # Allow 1 time window before/after current - + def generate_totp_secret(self) -> str: """Generate a FIPS-compliant TOTP secret key""" # Generate 160-bit (20 bytes) secret for TOTP (FIPS recommended) secret_bytes = secrets.token_bytes(20) - return base64.b32encode(secret_bytes).decode('utf-8') - + return base64.b32encode(secret_bytes).decode("utf-8") + def generate_backup_codes(self) -> List[str]: """Generate cryptographically secure backup codes""" codes = [] for _ in range(self.backup_code_count): # Generate 8-character alphanumeric backup codes - code = ''.join(secrets.choice('ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789') - for _ in range(self.backup_code_length)) + code = "".join( + secrets.choice("ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") + for _ in range(self.backup_code_length) + ) codes.append(code) return codes - + def hash_backup_code(self, code: str) -> str: """Hash backup code for secure storage""" # Use SHA-256 for backup code hashing (FIPS approved) - return hashlib.sha256(code.encode('utf-8')).hexdigest() - + return hashlib.sha256(code.encode("utf-8")).hexdigest() + def generate_qr_code(self, username: str, secret: str) -> str: """Generate QR code for TOTP setup""" try: # Create TOTP URI totp_uri = pyotp.totp.TOTP(secret).provisioning_uri( - name=username, - issuer_name=self.issuer_name + name=username, issuer_name=self.issuer_name ) - + # Generate QR code qr = qrcode.QRCode( version=1, @@ -97,173 +101,170 @@ def generate_qr_code(self, username: str, secret: str) -> str: ) qr.add_data(totp_uri) qr.make(fit=True) - + # Create QR code image img = qr.make_image(fill_color="black", back_color="white") - + # Convert to base64 for embedding in HTML buffer = BytesIO() - img.save(buffer, format='PNG') + img.save(buffer, format="PNG") img_data = buffer.getvalue() - - return base64.b64encode(img_data).decode('utf-8') - + + return base64.b64encode(img_data).decode("utf-8") + except Exception as e: logger.error(f"Failed to generate QR code: {e}") return None - + def encrypt_mfa_secret(self, secret: str) -> str: """Encrypt MFA secret for database storage""" try: - encrypted = encrypt_data(secret.encode('utf-8')) + encrypted = encrypt_data(secret.encode("utf-8")) return encrypted except Exception as e: logger.error(f"Failed to encrypt MFA secret: {e}") raise - + def decrypt_mfa_secret(self, encrypted_secret: str) -> str: """Decrypt MFA secret from database""" try: decrypted_bytes = decrypt_data(encrypted_secret) - return decrypted_bytes.decode('utf-8') + return decrypted_bytes.decode("utf-8") except Exception as e: logger.error(f"Failed to decrypt MFA secret: {e}") raise - + def validate_totp_code(self, secret: str, user_code: str, used_codes_cache: set = None) -> bool: """ Validate TOTP code with replay protection - + Args: secret: User's TOTP secret user_code: Code provided by user used_codes_cache: Set of recently used codes to prevent replay - + Returns: True if code is valid and not replayed """ try: totp = pyotp.TOTP(secret) - + # Check current time window and adjacent windows current_time = datetime.now() for i in range(-self.totp_window, self.totp_window + 1): test_time = current_time + timedelta(seconds=i * 30) expected_code = totp.at(test_time) - + if user_code == expected_code: # Check for replay attack code_key = f"{user_code}_{int(test_time.timestamp() // 30)}" if used_codes_cache and code_key in used_codes_cache: logger.warning(f"TOTP replay attack detected: {code_key}") return False - + # Add to used codes cache if provided if used_codes_cache is not None: used_codes_cache.add(code_key) - + return True - + return False - + except Exception as e: logger.error(f"TOTP validation error: {e}") return False - - def validate_backup_code(self, hashed_backup_codes: List[str], user_code: str) -> Tuple[bool, str]: + + def validate_backup_code( + self, hashed_backup_codes: List[str], user_code: str + ) -> Tuple[bool, str]: """ Validate backup code against stored hashes - + Args: hashed_backup_codes: List of hashed backup codes from database user_code: Code provided by user - + Returns: Tuple of (is_valid, code_hash_if_valid) """ try: user_code_hash = self.hash_backup_code(user_code.upper()) - + if user_code_hash in hashed_backup_codes: return True, user_code_hash - + return False, None - + except Exception as e: logger.error(f"Backup code validation error: {e}") return False, None - + def enroll_user_mfa(self, username: str) -> MFAEnrollmentResult: """ Enroll user in MFA with TOTP and backup codes - + Args: username: Username for MFA enrollment - + Returns: MFAEnrollmentResult with secret, QR code, and backup codes """ try: # Generate TOTP secret secret = self.generate_totp_secret() - + # Generate QR code qr_code_data = self.generate_qr_code(username, secret) if not qr_code_data: return MFAEnrollmentResult( - success=False, - error_message="Failed to generate QR code" + success=False, error_message="Failed to generate QR code" ) - + # Generate backup codes backup_codes = self.generate_backup_codes() - + return MFAEnrollmentResult( success=True, secret_key=secret, qr_code_data=qr_code_data, - backup_codes=backup_codes + backup_codes=backup_codes, ) - + except Exception as e: logger.error(f"MFA enrollment failed for {username}: {e}") - return MFAEnrollmentResult( - success=False, - error_message=f"Enrollment failed: {str(e)}" - ) - - def validate_mfa_code(self, - encrypted_secret: str, - hashed_backup_codes: List[str], - user_code: str, - used_codes_cache: set = None) -> MFAValidationResult: + return MFAEnrollmentResult(success=False, error_message=f"Enrollment failed: {str(e)}") + + def validate_mfa_code( + self, + encrypted_secret: str, + hashed_backup_codes: List[str], + user_code: str, + used_codes_cache: set = None, + ) -> MFAValidationResult: """ Validate MFA code (TOTP or backup code) - + Args: encrypted_secret: Encrypted TOTP secret from database hashed_backup_codes: List of hashed backup codes user_code: Code provided by user used_codes_cache: Cache of recently used TOTP codes - + Returns: MFAValidationResult with validation status and method used """ try: - user_code = user_code.strip().replace(' ', '').upper() - + user_code = user_code.strip().replace(" ", "").upper() + # First try TOTP validation if len(user_code) == 6 and user_code.isdigit(): try: secret = self.decrypt_mfa_secret(encrypted_secret) if self.validate_totp_code(secret, user_code, used_codes_cache): - return MFAValidationResult( - valid=True, - method_used=MFAMethod.TOTP - ) + return MFAValidationResult(valid=True, method_used=MFAMethod.TOTP) except Exception as e: logger.warning(f"TOTP validation failed: {e}") - + # Try backup code validation if len(user_code) == self.backup_code_length: is_valid, used_code_hash = self.validate_backup_code(hashed_backup_codes, user_code) @@ -271,28 +272,24 @@ def validate_mfa_code(self, return MFAValidationResult( valid=True, method_used=MFAMethod.BACKUP_CODES, - backup_code_used=used_code_hash + backup_code_used=used_code_hash, ) - + return MFAValidationResult( - valid=False, - error_message="Invalid MFA code format or value" + valid=False, error_message="Invalid MFA code format or value" ) - + except Exception as e: logger.error(f"MFA validation error: {e}") - return MFAValidationResult( - valid=False, - error_message=f"Validation failed: {str(e)}" - ) - + return MFAValidationResult(valid=False, error_message=f"Validation failed: {str(e)}") + def regenerate_backup_codes(self, username: str) -> List[str]: """ Generate new backup codes for a user - + Args: username: Username for backup code regeneration - + Returns: List of new backup codes """ @@ -300,18 +297,18 @@ def regenerate_backup_codes(self, username: str) -> List[str]: backup_codes = self.generate_backup_codes() logger.info(f"Regenerated backup codes for user: {username}") return backup_codes - + except Exception as e: logger.error(f"Failed to regenerate backup codes for {username}: {e}") raise - + def get_mfa_status(self, user_data: Dict) -> Dict[str, any]: """ Get user's MFA status and capabilities - + Args: user_data: User data from database - + Returns: Dictionary with MFA status information """ @@ -321,13 +318,14 @@ def get_mfa_status(self, user_data: Dict) -> Dict[str, any]: "backup_codes_available": len(user_data.get("backup_codes", [])), "fido2_enabled": False, # Future implementation "last_mfa_use": user_data.get("last_mfa_use"), - "supported_methods": [MFAMethod.TOTP.value, MFAMethod.BACKUP_CODES.value] + "supported_methods": [MFAMethod.TOTP.value, MFAMethod.BACKUP_CODES.value], } # Global MFA service instance _mfa_service = None + def get_mfa_service() -> MFAService: """Get global MFA service instance""" global _mfa_service @@ -341,10 +339,14 @@ def enroll_mfa(username: str) -> MFAEnrollmentResult: """Enroll user in MFA""" return get_mfa_service().enroll_user_mfa(username) -def validate_mfa(encrypted_secret: str, backup_codes: List[str], code: str, used_codes: set = None) -> MFAValidationResult: + +def validate_mfa( + encrypted_secret: str, backup_codes: List[str], code: str, used_codes: set = None +) -> MFAValidationResult: """Validate MFA code""" return get_mfa_service().validate_mfa_code(encrypted_secret, backup_codes, code, used_codes) + def regenerate_backup_codes(username: str) -> List[str]: """Regenerate backup codes""" - return get_mfa_service().regenerate_backup_codes(username) \ No newline at end of file + return get_mfa_service().regenerate_backup_codes(username) diff --git a/backend/app/services/prometheus_metrics.py b/backend/app/services/prometheus_metrics.py index b5ecd738..308e41f1 100644 --- a/backend/app/services/prometheus_metrics.py +++ b/backend/app/services/prometheus_metrics.py @@ -3,11 +3,17 @@ Comprehensive metrics for monitoring and observability Author: Noah Chen - nc9010@hanalyx.com """ + import time from prometheus_client import ( - Counter, Histogram, Gauge, Info, - CollectorRegistry, generate_latest, - multiprocess, values + Counter, + Histogram, + Gauge, + Info, + CollectorRegistry, + generate_latest, + multiprocess, + values, ) from typing import Dict, Optional import logging @@ -23,389 +29,332 @@ # HTTP Metrics http_requests_total = Counter( - 'secureops_http_requests_total', - 'Total HTTP requests', - ['method', 'endpoint', 'status', 'service'], - registry=registry + "secureops_http_requests_total", + "Total HTTP requests", + ["method", "endpoint", "status", "service"], + registry=registry, ) http_request_duration_seconds = Histogram( - 'secureops_http_request_duration_seconds', - 'HTTP request duration in seconds', - ['method', 'endpoint', 'service'], + "secureops_http_request_duration_seconds", + "HTTP request duration in seconds", + ["method", "endpoint", "service"], buckets=[0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0], - registry=registry + registry=registry, ) # Service Health Metrics -service_info = Info( - 'secureops_service', - 'Service information', - registry=registry -) +service_info = Info("secureops_service", "Service information", registry=registry) -service_up = Gauge( - 'secureops_service_up', - 'Service up status', - ['service'], - registry=registry -) +service_up = Gauge("secureops_service_up", "Service up status", ["service"], registry=registry) # SCAP Scanning Metrics scans_total = Counter( - 'secureops_scans_total', - 'Total SCAP scans performed', - ['status', 'profile', 'framework'], - registry=registry + "secureops_scans_total", + "Total SCAP scans performed", + ["status", "profile", "framework"], + registry=registry, ) -scans_active = Gauge( - 'secureops_scans_active', - 'Currently active scans', - registry=registry -) +scans_active = Gauge("secureops_scans_active", "Currently active scans", registry=registry) scan_duration_seconds = Histogram( - 'secureops_scan_duration_seconds', - 'SCAP scan duration in seconds', - ['profile', 'framework'], + "secureops_scan_duration_seconds", + "SCAP scan duration in seconds", + ["profile", "framework"], buckets=[10, 30, 60, 120, 300, 600, 1200, 1800, 3600], - registry=registry + registry=registry, ) scan_rules_processed = Counter( - 'secureops_scan_rules_processed_total', - 'Total SCAP rules processed', - ['status', 'severity'], - registry=registry + "secureops_scan_rules_processed_total", + "Total SCAP rules processed", + ["status", "severity"], + registry=registry, ) # Compliance Metrics compliance_score = Gauge( - 'secureops_compliance_score', - 'Compliance score for hosts', - ['host_id', 'framework'], - registry=registry + "secureops_compliance_score", + "Compliance score for hosts", + ["host_id", "framework"], + registry=registry, ) compliance_rules_failed = Gauge( - 'secureops_compliance_rules_failed', - 'Number of failed compliance rules', - ['host_id', 'severity', 'framework'], - registry=registry + "secureops_compliance_rules_failed", + "Number of failed compliance rules", + ["host_id", "severity", "framework"], + registry=registry, ) # Host Management Metrics hosts_total = Gauge( - 'secureops_hosts_total', - 'Total number of managed hosts', - ['status'], - registry=registry + "secureops_hosts_total", "Total number of managed hosts", ["status"], registry=registry ) host_connectivity_checks = Counter( - 'secureops_host_connectivity_checks_total', - 'Total host connectivity checks', - ['result'], - registry=registry + "secureops_host_connectivity_checks_total", + "Total host connectivity checks", + ["result"], + registry=registry, ) # Integration Metrics integration_calls_total = Counter( - 'secureops_integration_calls_total', - 'Total integration calls', - ['target', 'endpoint', 'status'], - registry=registry + "secureops_integration_calls_total", + "Total integration calls", + ["target", "endpoint", "status"], + registry=registry, ) integration_call_duration_seconds = Histogram( - 'secureops_integration_call_duration_seconds', - 'Integration call duration in seconds', - ['target', 'endpoint'], + "secureops_integration_call_duration_seconds", + "Integration call duration in seconds", + ["target", "endpoint"], buckets=[0.1, 0.25, 0.5, 1.0, 2.0, 5.0, 10.0], - registry=registry + registry=registry, ) # Remediation Metrics (for AEGIS integration) remediations_total = Counter( - 'secureops_remediations_total', - 'Total remediation attempts', - ['status', 'severity'], - registry=registry + "secureops_remediations_total", + "Total remediation attempts", + ["status", "severity"], + registry=registry, ) remediation_duration_seconds = Histogram( - 'secureops_remediation_duration_seconds', - 'Remediation duration in seconds', - ['severity'], + "secureops_remediation_duration_seconds", + "Remediation duration in seconds", + ["severity"], buckets=[1, 5, 10, 30, 60, 120, 300, 600], - registry=registry + registry=registry, ) # Database Metrics database_connections_active = Gauge( - 'secureops_database_connections_active', - 'Active database connections', - registry=registry + "secureops_database_connections_active", "Active database connections", registry=registry ) database_query_duration_seconds = Histogram( - 'secureops_database_query_duration_seconds', - 'Database query duration in seconds', - ['operation'], + "secureops_database_query_duration_seconds", + "Database query duration in seconds", + ["operation"], buckets=[0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0], - registry=registry + registry=registry, ) # Security Event Metrics security_events_total = Counter( - 'secureops_security_events_total', - 'Total security events', - ['event_type', 'severity'], - registry=registry + "secureops_security_events_total", + "Total security events", + ["event_type", "severity"], + registry=registry, ) authentication_attempts_total = Counter( - 'secureops_authentication_attempts_total', - 'Total authentication attempts', - ['result', 'method'], - registry=registry + "secureops_authentication_attempts_total", + "Total authentication attempts", + ["result", "method"], + registry=registry, ) # System Resource Metrics system_cpu_usage_percent = Gauge( - 'secureops_system_cpu_usage_percent', - 'System CPU usage percentage', - registry=registry + "secureops_system_cpu_usage_percent", "System CPU usage percentage", registry=registry ) system_memory_usage_bytes = Gauge( - 'secureops_system_memory_usage_bytes', - 'System memory usage in bytes', - ['type'], # available, used, total - registry=registry + "secureops_system_memory_usage_bytes", + "System memory usage in bytes", + ["type"], # available, used, total + registry=registry, ) system_disk_usage_bytes = Gauge( - 'secureops_system_disk_usage_bytes', - 'System disk usage in bytes', - ['device', 'type'], # type: free, used, total - registry=registry + "secureops_system_disk_usage_bytes", + "System disk usage in bytes", + ["device", "type"], # type: free, used, total + registry=registry, ) # Business Logic Metrics workflow_duration_seconds = Histogram( - 'secureops_workflow_duration_seconds', - 'End-to-end workflow duration', - ['workflow_type'], + "secureops_workflow_duration_seconds", + "End-to-end workflow duration", + ["workflow_type"], buckets=[60, 300, 600, 1200, 1800, 3600, 7200], - registry=registry + registry=registry, ) queue_size = Gauge( - 'secureops_queue_size', - 'Queue size for background tasks', - ['queue_name'], - registry=registry + "secureops_queue_size", "Queue size for background tasks", ["queue_name"], registry=registry ) # Error Rate Metrics error_rate = Gauge( - 'secureops_error_rate', - 'Error rate for various operations', - ['operation'], - registry=registry + "secureops_error_rate", "Error rate for various operations", ["operation"], registry=registry ) class PrometheusMetrics: """Centralized metrics collection and management""" - + def __init__(self): self.start_time = time.time() # Initialize service info - service_info.info({ - 'version': '1.0.0', - 'service': 'openwatch', - 'environment': 'production' - }) - - def record_http_request(self, method: str, endpoint: str, status_code: int, - duration: float, service: str = "openwatch"): + service_info.info({"version": "1.0.0", "service": "openwatch", "environment": "production"}) + + def record_http_request( + self, + method: str, + endpoint: str, + status_code: int, + duration: float, + service: str = "openwatch", + ): """Record HTTP request metrics""" http_requests_total.labels( - method=method, - endpoint=endpoint, - status=str(status_code), - service=service + method=method, endpoint=endpoint, status=str(status_code), service=service ).inc() - + http_request_duration_seconds.labels( - method=method, - endpoint=endpoint, - service=service + method=method, endpoint=endpoint, service=service ).observe(duration) - - def record_scan_metrics(self, status: str, profile: str, framework: str, - duration: Optional[float] = None, rules_processed: Dict[str, int] = None): + + def record_scan_metrics( + self, + status: str, + profile: str, + framework: str, + duration: Optional[float] = None, + rules_processed: Dict[str, int] = None, + ): """Record SCAP scan metrics""" - scans_total.labels( - status=status, - profile=profile, - framework=framework - ).inc() - + scans_total.labels(status=status, profile=profile, framework=framework).inc() + if duration is not None: - scan_duration_seconds.labels( - profile=profile, - framework=framework - ).observe(duration) - + scan_duration_seconds.labels(profile=profile, framework=framework).observe(duration) + if rules_processed: for rule_status, count in rules_processed.items(): - severity = rules_processed.get('severity', 'medium') - scan_rules_processed.labels( - status=rule_status, - severity=severity - ).inc(count) - + severity = rules_processed.get("severity", "medium") + scan_rules_processed.labels(status=rule_status, severity=severity).inc(count) + def update_compliance_score(self, host_id: str, framework: str, score: float): """Update compliance score for a host""" - compliance_score.labels( - host_id=host_id, - framework=framework - ).set(score) - - def update_compliance_failures(self, host_id: str, framework: str, - severity_counts: Dict[str, int]): + compliance_score.labels(host_id=host_id, framework=framework).set(score) + + def update_compliance_failures( + self, host_id: str, framework: str, severity_counts: Dict[str, int] + ): """Update compliance failure counts by severity""" for severity, count in severity_counts.items(): compliance_rules_failed.labels( - host_id=host_id, - severity=severity, - framework=framework + host_id=host_id, severity=severity, framework=framework ).set(count) - + def record_host_connectivity(self, result: str): """Record host connectivity check result""" host_connectivity_checks.labels(result=result).inc() - + def update_host_counts(self, status_counts: Dict[str, int]): """Update host count metrics by status""" for status, count in status_counts.items(): hosts_total.labels(status=status).set(count) - - def record_integration_call(self, target: str, endpoint: str, status: str, - duration: float): + + def record_integration_call(self, target: str, endpoint: str, status: str, duration: float): """Record integration call metrics""" - integration_calls_total.labels( - target=target, - endpoint=endpoint, - status=status - ).inc() - - integration_call_duration_seconds.labels( - target=target, - endpoint=endpoint - ).observe(duration) - + integration_calls_total.labels(target=target, endpoint=endpoint, status=status).inc() + + integration_call_duration_seconds.labels(target=target, endpoint=endpoint).observe(duration) + def record_remediation(self, status: str, severity: str, duration: Optional[float] = None): """Record remediation metrics""" - remediations_total.labels( - status=status, - severity=severity - ).inc() - + remediations_total.labels(status=status, severity=severity).inc() + if duration is not None: - remediation_duration_seconds.labels( - severity=severity - ).observe(duration) - + remediation_duration_seconds.labels(severity=severity).observe(duration) + def record_security_event(self, event_type: str, severity: str = "medium"): """Record security event""" - security_events_total.labels( - event_type=event_type, - severity=severity - ).inc() - + security_events_total.labels(event_type=event_type, severity=severity).inc() + def record_authentication_attempt(self, result: str, method: str = "jwt"): """Record authentication attempt""" - authentication_attempts_total.labels( - result=result, - method=method - ).inc() - + authentication_attempts_total.labels(result=result, method=method).inc() + def update_system_metrics(self): """Update system resource metrics""" try: # CPU usage cpu_percent = psutil.cpu_percent(interval=1) system_cpu_usage_percent.set(cpu_percent) - + # Memory usage memory = psutil.virtual_memory() system_memory_usage_bytes.labels(type="total").set(memory.total) system_memory_usage_bytes.labels(type="used").set(memory.used) system_memory_usage_bytes.labels(type="available").set(memory.available) - + # Disk usage for partition in psutil.disk_partitions(): try: usage = psutil.disk_usage(partition.mountpoint) - system_disk_usage_bytes.labels( - device=partition.device, - type="total" - ).set(usage.total) - system_disk_usage_bytes.labels( - device=partition.device, - type="used" - ).set(usage.used) - system_disk_usage_bytes.labels( - device=partition.device, - type="free" - ).set(usage.free) + system_disk_usage_bytes.labels(device=partition.device, type="total").set( + usage.total + ) + system_disk_usage_bytes.labels(device=partition.device, type="used").set( + usage.used + ) + system_disk_usage_bytes.labels(device=partition.device, type="free").set( + usage.free + ) except PermissionError: # Skip inaccessible partitions continue - + except Exception as e: logger.error(f"Error updating system metrics: {e}") - + async def update_database_metrics(self, db: Session): """Update database-related metrics""" try: # Active connections - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT count(*) as active_connections FROM pg_stat_activity WHERE state = 'active' - """)) - + """ + ) + ) + row = result.fetchone() if row: database_connections_active.set(row.active_connections) - + except Exception as e: logger.error(f"Error updating database metrics: {e}") - + def update_queue_metrics(self, queue_name: str, size: int): """Update queue size metrics""" queue_size.labels(queue_name=queue_name).set(size) - + def record_workflow_duration(self, workflow_type: str, duration: float): """Record workflow duration""" - workflow_duration_seconds.labels( - workflow_type=workflow_type - ).observe(duration) - + workflow_duration_seconds.labels(workflow_type=workflow_type).observe(duration) + def set_active_scans(self, count: int): """Set number of active scans""" scans_active.set(count) - + def set_service_up(self, service: str, is_up: bool): """Set service up status""" service_up.labels(service=service).set(1 if is_up else 0) - + def get_metrics(self) -> str: """Get all metrics in Prometheus format""" try: @@ -423,4 +372,4 @@ def get_metrics(self) -> str: def get_metrics_instance() -> PrometheusMetrics: """Get the global metrics instance""" - return metrics \ No newline at end of file + return metrics diff --git a/backend/app/services/rule_specific_scanner.py b/backend/app/services/rule_specific_scanner.py index 12f98c83..4fa0e4a8 100644 --- a/backend/app/services/rule_specific_scanner.py +++ b/backend/app/services/rule_specific_scanner.py @@ -2,6 +2,7 @@ Rule-Specific Scanner Service Enables targeted scanning of specific SCAP rules for efficient remediation verification """ + import logging import subprocess import tempfile @@ -20,35 +21,41 @@ class RuleSpecificScanner: """Service for scanning specific SCAP rules""" - + def __init__(self, results_dir: str = "/app/data/results/rule_scans"): self.results_dir = Path(results_dir) self.results_dir.mkdir(parents=True, exist_ok=True) self.scanner = SCAPScanner() self.framework_mapper = ComplianceFrameworkMapper() self.executor = ThreadPoolExecutor(max_workers=5) - + def _sanitize_identifier(self, identifier: str) -> str: """ Security Fix: Sanitize identifiers to prevent path injection Only allow alphanumeric characters, hyphens, and underscores """ import re + # Remove any characters that aren't alphanumeric, hyphens, or underscores - sanitized = re.sub(r'[^a-zA-Z0-9\-_]', '_', identifier) + sanitized = re.sub(r"[^a-zA-Z0-9\-_]", "_", identifier) # Limit length to prevent excessively long paths return sanitized[:50] - - async def scan_specific_rules(self, host_id: str, content_path: str, - profile_id: str, rule_ids: List[str], - connection_params: Optional[Dict] = None) -> Dict: + + async def scan_specific_rules( + self, + host_id: str, + content_path: str, + profile_id: str, + rule_ids: List[str], + connection_params: Optional[Dict] = None, + ) -> Dict: """Scan specific rules on a host""" try: # Security Fix: Sanitize host_id to prevent path injection sanitized_host_id = self._sanitize_identifier(host_id) scan_id = f"rule_scan_{sanitized_host_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}" logger.info(f"Starting rule-specific scan {scan_id} for {len(rule_ids)} rules") - + # Create scan results structure results = { "scan_id": scan_id, @@ -62,11 +69,11 @@ async def scan_specific_rules(self, host_id: str, content_path: str, "error_rules": 0, "rule_results": [], "scan_type": "rule_specific", - "duration_seconds": 0 + "duration_seconds": 0, } - + start_time = datetime.now() - + # Determine if local or remote scan if connection_params: results["scan_mode"] = "remote" @@ -78,14 +85,14 @@ async def scan_specific_rules(self, host_id: str, content_path: str, scan_results = await self._scan_rules_local( scan_id, content_path, profile_id, rule_ids ) - + # Process results for rule_id, rule_result in scan_results.items(): results["scanned_rules"] += 1 - + # Get compliance framework mappings framework_info = self.framework_mapper.get_unified_control(rule_id) - + rule_entry = { "rule_id": rule_id, "result": rule_result.get("result", "error"), @@ -93,20 +100,24 @@ async def scan_specific_rules(self, host_id: str, content_path: str, "severity": rule_result.get("severity", "unknown"), "scan_output": rule_result.get("output", ""), "error": rule_result.get("error", None), - "compliance_frameworks": [] + "compliance_frameworks": [], } - + # Add framework mappings if framework_info: for mapping in framework_info.frameworks: - rule_entry["compliance_frameworks"].append({ - "framework": mapping.framework.value, - "control_id": mapping.control_id, - "control_title": mapping.control_title - }) - rule_entry["automated_remediation_available"] = framework_info.automated_remediation + rule_entry["compliance_frameworks"].append( + { + "framework": mapping.framework.value, + "control_id": mapping.control_id, + "control_title": mapping.control_title, + } + ) + rule_entry["automated_remediation_available"] = ( + framework_info.automated_remediation + ) rule_entry["aegis_rule_id"] = framework_info.aegis_rule_id - + # Count results if rule_result.get("result") == "pass": results["passed_rules"] += 1 @@ -114,82 +125,93 @@ async def scan_specific_rules(self, host_id: str, content_path: str, results["failed_rules"] += 1 else: results["error_rules"] += 1 - + results["rule_results"].append(rule_entry) - + # Calculate duration end_time = datetime.now() results["duration_seconds"] = (end_time - start_time).total_seconds() - + # Calculate compliance score if results["scanned_rules"] > 0: - results["compliance_score"] = (results["passed_rules"] / results["scanned_rules"]) * 100 + results["compliance_score"] = ( + results["passed_rules"] / results["scanned_rules"] + ) * 100 else: results["compliance_score"] = 0 - + # Save results await self._save_scan_results(results) - + logger.info(f"Rule-specific scan completed: {scan_id}") return results - + except Exception as e: logger.error(f"Error in rule-specific scan: {e}") raise ScanExecutionError(f"Rule scan failed: {str(e)}") - - async def scan_failed_rules_from_previous_scan(self, previous_scan_id: str, - content_path: str, - connection_params: Optional[Dict] = None) -> Dict: + + async def scan_failed_rules_from_previous_scan( + self, previous_scan_id: str, content_path: str, connection_params: Optional[Dict] = None + ) -> Dict: """Re-scan only failed rules from a previous scan""" try: # Load previous scan results previous_results = await self._load_scan_results(previous_scan_id) - + if not previous_results: raise ValueError(f"Previous scan {previous_scan_id} not found") - + # Extract failed rule IDs failed_rules = [] for rule in previous_results.get("failed_rules", []): failed_rules.append(rule["rule_id"]) - + if not failed_rules: return { "message": "No failed rules to re-scan", - "previous_scan_id": previous_scan_id + "previous_scan_id": previous_scan_id, } - - logger.info(f"Re-scanning {len(failed_rules)} failed rules from scan {previous_scan_id}") - + + logger.info( + f"Re-scanning {len(failed_rules)} failed rules from scan {previous_scan_id}" + ) + # Perform targeted scan return await self.scan_specific_rules( host_id=previous_results.get("host_id"), content_path=content_path, profile_id=previous_results.get("profile_id"), rule_ids=failed_rules, - connection_params=connection_params + connection_params=connection_params, ) - + except Exception as e: logger.error(f"Error re-scanning failed rules: {e}") raise - - async def verify_remediation(self, host_id: str, content_path: str, - aegis_remediation_id: str, remediated_rules: List[str], - connection_params: Optional[Dict] = None) -> Dict: + + async def verify_remediation( + self, + host_id: str, + content_path: str, + aegis_remediation_id: str, + remediated_rules: List[str], + connection_params: Optional[Dict] = None, + ) -> Dict: """Verify specific rules after AEGIS remediation""" try: - logger.info(f"Verifying remediation {aegis_remediation_id} for {len(remediated_rules)} rules") - + logger.info( + f"Verifying remediation {aegis_remediation_id} for {len(remediated_rules)} rules" + ) + # Create verification scan scan_results = await self.scan_specific_rules( host_id=host_id, content_path=content_path, profile_id="remediation_verification", rule_ids=remediated_rules, - connection_params=connection_params + connection_params=connection_params, ) - + # Analyze remediation effectiveness verification_report = { "remediation_id": aegis_remediation_id, @@ -200,88 +222,93 @@ async def verify_remediation(self, host_id: str, content_path: str, "failed_remediation": scan_results["failed_rules"], "remediation_success_rate": 0, "failed_rules": [], - "successful_rules": [] + "successful_rules": [], } - + # Calculate success rate if remediation_report["total_rules_remediated"] > 0: verification_report["remediation_success_rate"] = ( - verification_report["successfully_remediated"] / - verification_report["total_rules_remediated"] + verification_report["successfully_remediated"] + / verification_report["total_rules_remediated"] ) * 100 - + # Categorize results for rule_result in scan_results["rule_results"]: if rule_result["result"] == "pass": - verification_report["successful_rules"].append({ - "rule_id": rule_result["rule_id"], - "title": rule_result["title"] - }) + verification_report["successful_rules"].append( + {"rule_id": rule_result["rule_id"], "title": rule_result["title"]} + ) else: - verification_report["failed_rules"].append({ - "rule_id": rule_result["rule_id"], - "title": rule_result["title"], - "error": rule_result.get("error", "Remediation not effective") - }) - + verification_report["failed_rules"].append( + { + "rule_id": rule_result["rule_id"], + "title": rule_result["title"], + "error": rule_result.get("error", "Remediation not effective"), + } + ) + return verification_report - + except Exception as e: logger.error(f"Error verifying remediation: {e}") raise - - async def get_rule_scan_history(self, rule_id: str, host_id: Optional[str] = None, - limit: int = 10) -> List[Dict]: + + async def get_rule_scan_history( + self, rule_id: str, host_id: Optional[str] = None, limit: int = 10 + ) -> List[Dict]: """Get scan history for a specific rule""" try: history = [] - + # Search through recent scan results scan_files = sorted(self.results_dir.glob("*.json"), reverse=True)[:100] - + for scan_file in scan_files: try: - with open(scan_file, 'r') as f: + with open(scan_file, "r") as f: scan_data = json.load(f) - + # Filter by host if specified if host_id and scan_data.get("host_id") != host_id: continue - + # Look for the rule in results for rule_result in scan_data.get("rule_results", []): if rule_result["rule_id"] == rule_id: - history.append({ - "scan_id": scan_data["scan_id"], - "timestamp": scan_data["timestamp"], - "host_id": scan_data["host_id"], - "result": rule_result["result"], - "severity": rule_result["severity"] - }) + history.append( + { + "scan_id": scan_data["scan_id"], + "timestamp": scan_data["timestamp"], + "host_id": scan_data["host_id"], + "result": rule_result["result"], + "severity": rule_result["severity"], + } + ) break - + if len(history) >= limit: break - + except Exception as e: logger.warning(f"Error reading scan file {scan_file}: {e}") continue - + return history - + except Exception as e: logger.error(f"Error getting rule scan history: {e}") return [] - - async def _scan_rules_local(self, scan_id: str, content_path: str, - profile_id: str, rule_ids: List[str]) -> Dict[str, Dict]: + + async def _scan_rules_local( + self, scan_id: str, content_path: str, profile_id: str, rule_ids: List[str] + ) -> Dict[str, Dict]: """Scan specific rules locally""" results = {} - + # Create temporary directory for individual rule scans with tempfile.TemporaryDirectory() as temp_dir: temp_path = Path(temp_dir) - + # Scan each rule individually for detailed results tasks = [] for rule_id in rule_ids: @@ -289,40 +316,42 @@ async def _scan_rules_local(self, scan_id: str, content_path: str, scan_id, content_path, profile_id, rule_id, temp_path ) tasks.append(task) - + # Execute scans concurrently rule_results = await asyncio.gather(*tasks, return_exceptions=True) - + # Process results for rule_id, result in zip(rule_ids, rule_results): if isinstance(result, Exception): - results[rule_id] = { - "result": "error", - "error": str(result) - } + results[rule_id] = {"result": "error", "error": str(result)} else: results[rule_id] = result - + return results - - async def _scan_single_rule_local(self, scan_id: str, content_path: str, - profile_id: str, rule_id: str, - temp_dir: Path) -> Dict: + + async def _scan_single_rule_local( + self, scan_id: str, content_path: str, profile_id: str, rule_id: str, temp_dir: Path + ) -> Dict: """Scan a single rule locally""" try: # Create unique result files for this rule rule_scan_id = f"{scan_id}_{rule_id.replace(':', '_')}" xml_result = temp_dir / f"{rule_scan_id}.xml" - + # Run oscap with specific rule cmd = [ - 'oscap', 'xccdf', 'eval', - '--profile', profile_id, - '--rule', rule_id, - '--results', str(xml_result), - content_path + "oscap", + "xccdf", + "eval", + "--profile", + profile_id, + "--rule", + rule_id, + "--results", + str(xml_result), + content_path, ] - + # Execute in thread pool to avoid blocking loop = asyncio.get_event_loop() result = await loop.run_in_executor( @@ -332,13 +361,13 @@ async def _scan_single_rule_local(self, scan_id: str, content_path: str, subprocess.PIPE, subprocess.PIPE, True, # capture_output - 300 # timeout + 300, # timeout ) - + # Parse result if xml_result.exists(): scan_result = self.scanner._parse_scan_results(str(xml_result)) - + # Extract rule-specific result for rule_detail in scan_result.get("rule_details", []): if rule_detail["rule_id"] == rule_id: @@ -346,67 +375,74 @@ async def _scan_single_rule_local(self, scan_id: str, content_path: str, "result": rule_detail["result"], "title": rule_detail.get("title", ""), "severity": rule_detail.get("severity", "unknown"), - "output": result.stdout + "output": result.stdout, } - + # If we couldn't find the result, check exit code if result.returncode == 0: return {"result": "pass", "output": result.stdout} else: return {"result": "fail", "output": result.stdout, "error": result.stderr} - + except subprocess.TimeoutExpired: return {"result": "error", "error": "Scan timeout"} except Exception as e: return {"result": "error", "error": str(e)} - - async def _scan_rules_remote(self, scan_id: str, content_path: str, - profile_id: str, rule_ids: List[str], - connection_params: Dict) -> Dict[str, Dict]: + + async def _scan_rules_remote( + self, + scan_id: str, + content_path: str, + profile_id: str, + rule_ids: List[str], + connection_params: Dict, + ) -> Dict[str, Dict]: """Scan specific rules on remote host""" results = {} - + # For remote scanning, we'll batch rules for efficiency # but still provide individual results batch_size = 10 - + for i in range(0, len(rule_ids), batch_size): - batch_rules = rule_ids[i:i + batch_size] - + batch_rules = rule_ids[i : i + batch_size] + try: # Perform batch scan batch_results = await self._scan_rule_batch_remote( scan_id, content_path, profile_id, batch_rules, connection_params ) - + results.update(batch_results) - + except Exception as e: # If batch fails, mark all rules in batch as error for rule_id in batch_rules: - results[rule_id] = { - "result": "error", - "error": f"Batch scan failed: {str(e)}" - } - + results[rule_id] = {"result": "error", "error": f"Batch scan failed: {str(e)}"} + return results - - async def _scan_rule_batch_remote(self, scan_id: str, content_path: str, - profile_id: str, rule_ids: List[str], - connection_params: Dict) -> Dict[str, Dict]: + + async def _scan_rule_batch_remote( + self, + scan_id: str, + content_path: str, + profile_id: str, + rule_ids: List[str], + connection_params: Dict, + ) -> Dict[str, Dict]: """Scan a batch of rules on remote host""" try: # Use the main scanner for remote execution # This will use oscap-ssh or paramiko depending on auth method - + batch_scan_id = f"{scan_id}_batch_{datetime.now().strftime('%H%M%S%f')}" - + # Create a custom command that includes all rules # Note: OpenSCAP doesn't support multiple --rule flags, # so we need to run separate scans or use a custom profile - + results = {} - + for rule_id in rule_ids: result = self.scanner.execute_remote_scan( hostname=connection_params["hostname"], @@ -417,9 +453,9 @@ async def _scan_rule_batch_remote(self, scan_id: str, content_path: str, content_path=content_path, profile_id=profile_id, scan_id=f"{batch_scan_id}_{rule_id.replace(':', '_')}", - rule_id=rule_id + rule_id=rule_id, ) - + # Extract rule-specific result if "rule_details" in result: for rule_detail in result["rule_details"]: @@ -428,36 +464,36 @@ async def _scan_rule_batch_remote(self, scan_id: str, content_path: str, "result": rule_detail["result"], "title": rule_detail.get("title", ""), "severity": rule_detail.get("severity", "unknown"), - "output": result.get("stdout", "") + "output": result.get("stdout", ""), } break else: # Fallback based on exit code results[rule_id] = { "result": "pass" if result.get("exit_code") == 0 else "fail", - "output": result.get("stdout", "") + "output": result.get("stdout", ""), } - + return results - + except Exception as e: logger.error(f"Error in remote rule batch scan: {e}") raise - + async def _save_scan_results(self, results: Dict): """Save scan results to file""" try: result_file = self.results_dir / f"{results['scan_id']}.json" - + async with asyncio.Lock(): - with open(result_file, 'w') as f: + with open(result_file, "w") as f: json.dump(results, f, indent=2) - + logger.info(f"Saved scan results to {result_file}") - + except Exception as e: logger.error(f"Error saving scan results: {e}") - + async def _load_scan_results(self, scan_id: str) -> Optional[Dict]: """Load scan results from file""" try: @@ -465,33 +501,33 @@ async def _load_scan_results(self, scan_id: str) -> Optional[Dict]: sanitized_scan_id = self._sanitize_identifier(scan_id) # First try exact match result_file = self.results_dir / f"{sanitized_scan_id}.json" - + if not result_file.exists(): # Try searching in main results directory main_results = Path("/app/data/results") / sanitized_scan_id if main_results.exists(): # Look for results.json in scan directory result_file = main_results / "results.json" - + if result_file.exists(): - with open(result_file, 'r') as f: + with open(result_file, "r") as f: return json.load(f) - + return None - + except Exception as e: logger.error(f"Error loading scan results: {e}") return None - + def get_rule_remediation_guidance(self, rule_id: str) -> Optional[Dict]: """Get remediation guidance for a specific rule""" try: # Get framework mappings control = self.framework_mapper.get_unified_control(rule_id) - + if not control: return None - + guidance = { "rule_id": rule_id, "title": control.title, @@ -499,30 +535,31 @@ def get_rule_remediation_guidance(self, rule_id: str) -> Optional[Dict]: "aegis_rule_id": control.aegis_rule_id, "implementation_guidance": [], "assessment_objectives": [], - "references": [] + "references": [], } - + # Collect guidance from all frameworks for mapping in control.frameworks: - guidance["implementation_guidance"].append({ - "framework": mapping.framework.value, - "guidance": mapping.implementation_guidance - }) - + guidance["implementation_guidance"].append( + { + "framework": mapping.framework.value, + "guidance": mapping.implementation_guidance, + } + ) + guidance["assessment_objectives"].extend(mapping.assessment_objectives) - + if mapping.related_controls: - guidance["references"].extend([ - f"{mapping.framework.value}: {ctrl}" - for ctrl in mapping.related_controls - ]) - + guidance["references"].extend( + [f"{mapping.framework.value}: {ctrl}" for ctrl in mapping.related_controls] + ) + # Remove duplicates guidance["assessment_objectives"] = list(set(guidance["assessment_objectives"])) guidance["references"] = list(set(guidance["references"])) - + return guidance - + except Exception as e: logger.error(f"Error getting remediation guidance: {e}") - return None \ No newline at end of file + return None diff --git a/backend/app/services/scan_intelligence.py b/backend/app/services/scan_intelligence.py index 3ba3a3ba..9c2da9aa 100644 --- a/backend/app/services/scan_intelligence.py +++ b/backend/app/services/scan_intelligence.py @@ -2,6 +2,7 @@ Scan Intelligence Service Provides intelligent scanning capabilities including profile suggestion and optimization """ + import logging from typing import Dict, List, Optional, Tuple from enum import Enum @@ -23,6 +24,7 @@ class ScanPriority(Enum): @dataclass class HostInfo: """Host information for intelligent scanning decisions""" + id: str hostname: str ip_address: str @@ -38,6 +40,7 @@ class HostInfo: @dataclass class ProfileSuggestion: """Suggested scan profile with reasoning""" + profile_id: str content_id: int name: str @@ -50,37 +53,31 @@ class ProfileSuggestion: class ScanIntelligenceService: """Service for intelligent scan decision making""" - + # Default profile mappings by OS OS_DEFAULT_PROFILES = { "rhel": "xccdf_org.ssgproject.content_profile_cui", - "centos": "xccdf_org.ssgproject.content_profile_cui", + "centos": "xccdf_org.ssgproject.content_profile_cui", "fedora": "xccdf_org.ssgproject.content_profile_cui", "ubuntu": "xccdf_org.ssgproject.content_profile_cis_level1_server", "debian": "xccdf_org.ssgproject.content_profile_cis_level1_server", "sles": "xccdf_org.ssgproject.content_profile_stig", - "windows": "xccdf_org.ssgproject.content_profile_cui" + "windows": "xccdf_org.ssgproject.content_profile_cui", } - + # Compliance profiles by environment type ENVIRONMENT_PROFILES = { "production": { "federal": "xccdf_org.ssgproject.content_profile_stig", "healthcare": "xccdf_org.ssgproject.content_profile_hipaa", "financial": "xccdf_org.ssgproject.content_profile_pci", - "default": "xccdf_org.ssgproject.content_profile_cui" - }, - "staging": { - "default": "xccdf_org.ssgproject.content_profile_cis_level1_server" + "default": "xccdf_org.ssgproject.content_profile_cui", }, - "development": { - "default": "xccdf_org.ssgproject.content_profile_essential" - }, - "test": { - "default": "xccdf_org.ssgproject.content_profile_essential" - } + "staging": {"default": "xccdf_org.ssgproject.content_profile_cis_level1_server"}, + "development": {"default": "xccdf_org.ssgproject.content_profile_essential"}, + "test": {"default": "xccdf_org.ssgproject.content_profile_essential"}, } - + # Tag-based profile mappings TAG_PROFILE_MAPPINGS = { "web": "xccdf_org.ssgproject.content_profile_cui", @@ -89,7 +86,7 @@ class ScanIntelligenceService: "medical": "xccdf_org.ssgproject.content_profile_hipaa", "public": "xccdf_org.ssgproject.content_profile_cui", "dmz": "xccdf_org.ssgproject.content_profile_stig", - "critical": "xccdf_org.ssgproject.content_profile_stig" + "critical": "xccdf_org.ssgproject.content_profile_stig", } def __init__(self, db: Session): @@ -98,51 +95,53 @@ def __init__(self, db: Session): async def suggest_scan_profile(self, host_id: str) -> ProfileSuggestion: """ Intelligently suggest the best scan profile for a host - + Args: host_id: UUID of the host to analyze - + Returns: ProfileSuggestion with recommended profile and reasoning """ try: # Get host information host_info = await self._get_host_info(host_id) - + if not host_info: raise ValueError(f"Host {host_id} not found") - + # Analyze host characteristics suggestions = [] - + # 1. Environment-based suggestion env_suggestion = self._suggest_by_environment(host_info) if env_suggestion: suggestions.append(env_suggestion) - + # 2. Tag-based suggestion tag_suggestion = self._suggest_by_tags(host_info) if tag_suggestion: suggestions.append(tag_suggestion) - + # 3. Owner-based suggestion owner_suggestion = self._suggest_by_owner(host_info) if owner_suggestion: suggestions.append(owner_suggestion) - + # 4. OS-based fallback os_suggestion = self._suggest_by_os(host_info) suggestions.append(os_suggestion) - + # Select the best suggestion best_suggestion = self._select_best_suggestion(suggestions, host_info) - + # Enhance with content metadata enhanced_suggestion = await self._enhance_suggestion_with_content(best_suggestion) - - logger.info(f"Profile suggested for host {host_id}: {enhanced_suggestion.profile_id} (confidence: {enhanced_suggestion.confidence})") + + logger.info( + f"Profile suggested for host {host_id}: {enhanced_suggestion.profile_id} (confidence: {enhanced_suggestion.confidence})" + ) return enhanced_suggestion - + except Exception as e: logger.error(f"Error suggesting profile for host {host_id}: {e}") # Return safe fallback @@ -151,7 +150,9 @@ async def suggest_scan_profile(self, host_id: str) -> ProfileSuggestion: async def _get_host_info(self, host_id: str) -> Optional[HostInfo]: """Retrieve comprehensive host information""" try: - result = self.db.execute(text(""" + result = self.db.execute( + text( + """ SELECT h.id, h.hostname, h.ip_address, h.operating_system, h.environment, h.tags, h.owner, h.port, @@ -170,16 +171,19 @@ async def _get_host_info(self, host_id: str) -> Optional[HostInfo]: ORDER BY started_at DESC LIMIT 1 ) WHERE h.id = :host_id AND h.is_active = true - """), {"host_id": host_id}).fetchone() - + """ + ), + {"host_id": host_id}, + ).fetchone() + if not result: return None - + # Parse tags tags = [] if result.tags: - tags = [tag.strip().lower() for tag in result.tags.split(',')] - + tags = [tag.strip().lower() for tag in result.tags.split(",")] + return HostInfo( id=result.id, hostname=result.hostname, @@ -190,9 +194,9 @@ async def _get_host_info(self, host_id: str) -> Optional[HostInfo]: owner=result.owner, port=result.port or 22, last_scan=result.last_scan.isoformat() if result.last_scan else None, - compliance_score=result.compliance_score + compliance_score=result.compliance_score, ) - + except Exception as e: logger.error(f"Error getting host info for {host_id}: {e}") return None @@ -201,11 +205,11 @@ def _suggest_by_environment(self, host_info: HostInfo) -> Optional[ProfileSugges """Suggest profile based on environment and owner characteristics""" env = host_info.environment.lower() owner = (host_info.owner or "").lower() - + # Check for specific compliance requirements if env in self.ENVIRONMENT_PROFILES: env_profiles = self.ENVIRONMENT_PROFILES[env] - + # Check for federal/government if "federal" in owner or "gov" in owner or "dod" in owner: if "federal" in env_profiles: @@ -217,9 +221,9 @@ def _suggest_by_environment(self, host_info: HostInfo) -> Optional[ProfileSugges reasoning=[f"Federal/government owner detected", f"Environment: {env}"], estimated_duration="15-25 min", rule_count=340, - priority=ScanPriority.HIGH + priority=ScanPriority.HIGH, ) - + # Check for healthcare if any(keyword in owner for keyword in ["health", "medical", "hospital"]): if "healthcare" in env_profiles: @@ -231,9 +235,9 @@ def _suggest_by_environment(self, host_info: HostInfo) -> Optional[ProfileSugges reasoning=["Healthcare organization detected", f"Environment: {env}"], estimated_duration="12-18 min", rule_count=280, - priority=ScanPriority.HIGH + priority=ScanPriority.HIGH, ) - + # Check for financial services if any(keyword in owner for keyword in ["bank", "financial", "payment", "finance"]): if "financial" in env_profiles: @@ -242,12 +246,15 @@ def _suggest_by_environment(self, host_info: HostInfo) -> Optional[ProfileSugges content_id=1, name="PCI DSS Compliance", confidence=0.85, - reasoning=["Financial services organization detected", f"Environment: {env}"], + reasoning=[ + "Financial services organization detected", + f"Environment: {env}", + ], estimated_duration="10-15 min", rule_count=250, - priority=ScanPriority.HIGH + priority=ScanPriority.HIGH, ) - + # Use environment default return ProfileSuggestion( profile_id=env_profiles["default"], @@ -257,9 +264,9 @@ def _suggest_by_environment(self, host_info: HostInfo) -> Optional[ProfileSugges reasoning=[f"Environment-based selection: {env}"], estimated_duration="8-12 min", rule_count=180, - priority=ScanPriority.NORMAL + priority=ScanPriority.NORMAL, ) - + return None def _suggest_by_tags(self, host_info: HostInfo) -> Optional[ProfileSuggestion]: @@ -267,12 +274,12 @@ def _suggest_by_tags(self, host_info: HostInfo) -> Optional[ProfileSuggestion]: for tag in host_info.tags: if tag in self.TAG_PROFILE_MAPPINGS: profile_id = self.TAG_PROFILE_MAPPINGS[tag] - + # Determine priority based on tag criticality priority = ScanPriority.NORMAL if tag in ["database", "payment", "medical", "critical", "dmz"]: priority = ScanPriority.HIGH - + return ProfileSuggestion( profile_id=profile_id, content_id=1, @@ -281,18 +288,18 @@ def _suggest_by_tags(self, host_info: HostInfo) -> Optional[ProfileSuggestion]: reasoning=[f"Host tagged as '{tag}'"], estimated_duration="10-15 min", rule_count=220, - priority=priority + priority=priority, ) - + return None def _suggest_by_owner(self, host_info: HostInfo) -> Optional[ProfileSuggestion]: """Suggest profile based on owner characteristics""" if not host_info.owner: return None - + owner = host_info.owner.lower() - + # Security team hosts get comprehensive scans if any(keyword in owner for keyword in ["security", "infosec", "cyber"]): return ProfileSuggestion( @@ -303,15 +310,15 @@ def _suggest_by_owner(self, host_info: HostInfo) -> Optional[ProfileSuggestion]: reasoning=["Security team ownership detected"], estimated_duration="20-30 min", rule_count=380, - priority=ScanPriority.HIGH + priority=ScanPriority.HIGH, ) - + return None def _suggest_by_os(self, host_info: HostInfo) -> ProfileSuggestion: """Fallback suggestion based on operating system""" os_name = host_info.operating_system.lower() - + # Map OS variants for os_key in self.OS_DEFAULT_PROFILES: if os_key in os_name: @@ -323,9 +330,9 @@ def _suggest_by_os(self, host_info: HostInfo) -> ProfileSuggestion: reasoning=[f"Operating system: {host_info.operating_system}"], estimated_duration="8-12 min", rule_count=160, - priority=ScanPriority.NORMAL + priority=ScanPriority.NORMAL, ) - + # Unknown OS fallback return ProfileSuggestion( profile_id="xccdf_org.ssgproject.content_profile_cui", @@ -335,43 +342,53 @@ def _suggest_by_os(self, host_info: HostInfo) -> ProfileSuggestion: reasoning=["Unknown OS - using universal profile"], estimated_duration="10-15 min", rule_count=180, - priority=ScanPriority.NORMAL + priority=ScanPriority.NORMAL, ) - def _select_best_suggestion(self, suggestions: List[ProfileSuggestion], host_info: HostInfo) -> ProfileSuggestion: + def _select_best_suggestion( + self, suggestions: List[ProfileSuggestion], host_info: HostInfo + ) -> ProfileSuggestion: """Select the best suggestion from multiple options""" if not suggestions: return self._suggest_by_os(host_info) - + # Sort by confidence and priority suggestions.sort(key=lambda s: (s.confidence, s.priority.value == "high"), reverse=True) - + best = suggestions[0] - + # Combine reasoning from top suggestions if they're close in confidence if len(suggestions) > 1 and suggestions[1].confidence >= best.confidence - 0.1: best.reasoning.extend(suggestions[1].reasoning) - + return best - async def _enhance_suggestion_with_content(self, suggestion: ProfileSuggestion) -> ProfileSuggestion: + async def _enhance_suggestion_with_content( + self, suggestion: ProfileSuggestion + ) -> ProfileSuggestion: """Enhance suggestion with actual SCAP content metadata""" try: # Find matching content and profile - result = self.db.execute(text(""" + result = self.db.execute( + text( + """ SELECT c.id, c.name, c.profiles FROM scap_content c WHERE c.profiles LIKE :profile_pattern ORDER BY c.created_at DESC LIMIT 1 - """), {"profile_pattern": f"%{suggestion.profile_id}%"}).fetchone() - + """ + ), + {"profile_pattern": f"%{suggestion.profile_id}%"}, + ).fetchone() + if result: suggestion.content_id = result.id - + # Parse profiles to get accurate metadata try: import json + profiles = json.loads(result.profiles or "[]") for profile in profiles: if profile.get("id") == suggestion.profile_id: @@ -380,9 +397,9 @@ async def _enhance_suggestion_with_content(self, suggestion: ProfileSuggestion) break except: pass - + return suggestion - + except Exception as e: logger.warning(f"Failed to enhance suggestion with content metadata: {e}") return suggestion @@ -397,7 +414,7 @@ async def _get_fallback_suggestion(self, host_id: str) -> ProfileSuggestion: reasoning=["Fallback suggestion - analysis failed"], estimated_duration="8-12 min", rule_count=150, - priority=ScanPriority.NORMAL + priority=ScanPriority.NORMAL, ) async def analyze_bulk_scan_feasibility(self, host_ids: List[str]) -> Dict: @@ -409,44 +426,44 @@ async def analyze_bulk_scan_feasibility(self, host_ids: List[str]) -> Dict: host_info = await self._get_host_info(host_id) if host_info: hosts_info.append(host_info) - + if not hosts_info: - return { - "feasible": False, - "reason": "No valid hosts found", - "recommendations": [] - } - + return {"feasible": False, "reason": "No valid hosts found", "recommendations": []} + # Group by OS and environment for batching analysis os_groups = {} env_groups = {} - + for host in hosts_info: os_key = host.operating_system.lower() os_groups.setdefault(os_key, []).append(host) - + env_key = host.environment.lower() env_groups.setdefault(env_key, []).append(host) - + # Calculate estimated time and resource usage total_estimated_time = len(hosts_info) * 10 # Base 10 minutes per host max_parallel = min(5, len(hosts_info)) # Limit concurrent scans actual_time = total_estimated_time / max_parallel - + recommendations = [] - + # OS diversity recommendation if len(os_groups) > 3: recommendations.append("Consider grouping by OS for better content optimization") - + # Environment mixing warning if "production" in env_groups and len(env_groups) > 1: - recommendations.append("Production and non-production hosts mixed - consider separate scans") - + recommendations.append( + "Production and non-production hosts mixed - consider separate scans" + ) + # Large batch warning if len(hosts_info) > 20: - recommendations.append("Large batch detected - consider splitting into smaller groups") - + recommendations.append( + "Large batch detected - consider splitting into smaller groups" + ) + return { "feasible": True, "total_hosts": len(hosts_info), @@ -454,13 +471,13 @@ async def analyze_bulk_scan_feasibility(self, host_ids: List[str]) -> Dict: "max_parallel_scans": max_parallel, "os_groups": {k: len(v) for k, v in os_groups.items()}, "environment_groups": {k: len(v) for k, v in env_groups.items()}, - "recommendations": recommendations + "recommendations": recommendations, } - + except Exception as e: logger.error(f"Error analyzing bulk scan feasibility: {e}") return { "feasible": False, "reason": f"Analysis failed: {str(e)}", - "recommendations": ["Review host selection and try again"] - } \ No newline at end of file + "recommendations": ["Review host selection and try again"], + } diff --git a/backend/app/services/scap_aegis_mapper.py b/backend/app/services/scap_aegis_mapper.py index 3dc8677e..464dc7f4 100644 --- a/backend/app/services/scap_aegis_mapper.py +++ b/backend/app/services/scap_aegis_mapper.py @@ -2,6 +2,7 @@ SCAP to AEGIS Mapper Service Maps SCAP rules to AEGIS remediation actions and manages remediation workflows """ + import logging import json from typing import Dict, List, Optional, Set, Tuple @@ -17,6 +18,7 @@ @dataclass class AEGISMapping: """AEGIS remediation mapping for SCAP rule""" + scap_rule_id: str aegis_rule_id: str rule_category: str # authentication, audit, network, etc. @@ -28,11 +30,12 @@ class AEGISMapping: requires_reboot: bool dependencies: List[str] platforms: List[str] # rhel8, rhel9, ubuntu20, ubuntu22 - + @dataclass class RemediationPlan: """Remediation plan for failed SCAP rules""" + plan_id: str scan_id: str host_id: str @@ -47,48 +50,48 @@ class RemediationPlan: class SCAPAEGISMapper: """Service for mapping SCAP rules to AEGIS remediation actions""" - + def __init__(self, mappings_dir: str = "/app/data/mappings"): self.mappings_dir = Path(mappings_dir) self.mappings_dir.mkdir(parents=True, exist_ok=True) self.rule_mappings = self._load_mappings() self.category_priorities = self._initialize_category_priorities() - + def _load_mappings(self) -> Dict[str, AEGISMapping]: """Load SCAP to AEGIS mappings from configuration""" mappings = {} - + # Load from built-in mappings first mappings.update(self._load_builtin_mappings()) - + # Load from custom mappings directory for mapping_file in self.mappings_dir.glob("*.yml"): try: - with open(mapping_file, 'r') as f: + with open(mapping_file, "r") as f: custom_mappings = yaml.safe_load(f) - + for rule_id, mapping_data in custom_mappings.items(): mappings[rule_id] = AEGISMapping( scap_rule_id=rule_id, - aegis_rule_id=mapping_data.get('aegis_rule_id', ''), - rule_category=mapping_data.get('category', 'system'), - remediation_type=mapping_data.get('type', 'configuration'), - implementation_commands=mapping_data.get('commands', []), - verification_commands=mapping_data.get('verify', []), - rollback_commands=mapping_data.get('rollback', []), - estimated_duration=mapping_data.get('duration', 60), - requires_reboot=mapping_data.get('requires_reboot', False), - dependencies=mapping_data.get('dependencies', []), - platforms=mapping_data.get('platforms', ['rhel8', 'rhel9']) + aegis_rule_id=mapping_data.get("aegis_rule_id", ""), + rule_category=mapping_data.get("category", "system"), + remediation_type=mapping_data.get("type", "configuration"), + implementation_commands=mapping_data.get("commands", []), + verification_commands=mapping_data.get("verify", []), + rollback_commands=mapping_data.get("rollback", []), + estimated_duration=mapping_data.get("duration", 60), + requires_reboot=mapping_data.get("requires_reboot", False), + dependencies=mapping_data.get("dependencies", []), + platforms=mapping_data.get("platforms", ["rhel8", "rhel9"]), ) - + logger.info(f"Loaded {len(custom_mappings)} mappings from {mapping_file}") - + except Exception as e: logger.error(f"Error loading mappings from {mapping_file}: {e}") - + return mappings - + def _load_builtin_mappings(self) -> Dict[str, AEGISMapping]: """Load built-in SCAP to AEGIS mappings""" return { @@ -101,21 +104,18 @@ def _load_builtin_mappings(self) -> Dict[str, AEGISMapping]: implementation_commands=[ "sed -i 's/^#*PermitRootLogin.*/PermitRootLogin no/' /etc/ssh/sshd_config", "grep -q '^PermitRootLogin' /etc/ssh/sshd_config || echo 'PermitRootLogin no' >> /etc/ssh/sshd_config", - "systemctl restart sshd" - ], - verification_commands=[ - "grep -E '^PermitRootLogin\\s+no' /etc/ssh/sshd_config" + "systemctl restart sshd", ], + verification_commands=["grep -E '^PermitRootLogin\\s+no' /etc/ssh/sshd_config"], rollback_commands=[ "sed -i 's/^PermitRootLogin.*/PermitRootLogin yes/' /etc/ssh/sshd_config", - "systemctl restart sshd" + "systemctl restart sshd", ], estimated_duration=30, requires_reboot=False, dependencies=[], - platforms=["rhel8", "rhel9", "ubuntu20", "ubuntu22"] + platforms=["rhel8", "rhel9", "ubuntu20", "ubuntu22"], ), - # Password Policy "xccdf_mil.disa.stig_rule_SV-230365r792936_rule": AEGISMapping( scap_rule_id="xccdf_mil.disa.stig_rule_SV-230365r792936_rule", @@ -124,20 +124,17 @@ def _load_builtin_mappings(self) -> Dict[str, AEGISMapping]: remediation_type="configuration", implementation_commands=[ "sed -i 's/^#*\\s*minlen.*/minlen = 15/' /etc/security/pwquality.conf", - "grep -q '^minlen' /etc/security/pwquality.conf || echo 'minlen = 15' >> /etc/security/pwquality.conf" + "grep -q '^minlen' /etc/security/pwquality.conf || echo 'minlen = 15' >> /etc/security/pwquality.conf", ], verification_commands=[ "grep -E '^minlen\\s*=\\s*(1[5-9]|[2-9][0-9])' /etc/security/pwquality.conf" ], - rollback_commands=[ - "sed -i 's/^minlen.*/minlen = 8/' /etc/security/pwquality.conf" - ], + rollback_commands=["sed -i 's/^minlen.*/minlen = 8/' /etc/security/pwquality.conf"], estimated_duration=20, requires_reboot=False, dependencies=["libpwquality"], - platforms=["rhel8", "rhel9", "ubuntu20", "ubuntu22"] + platforms=["rhel8", "rhel9", "ubuntu20", "ubuntu22"], ), - # Audit Daemon "xccdf_mil.disa.stig_rule_SV-230423r793041_rule": AEGISMapping( scap_rule_id="xccdf_mil.disa.stig_rule_SV-230423r793041_rule", @@ -147,22 +144,15 @@ def _load_builtin_mappings(self) -> Dict[str, AEGISMapping]: implementation_commands=[ "systemctl enable auditd", "systemctl start auditd", - "augenrules --load" - ], - verification_commands=[ - "systemctl is-enabled auditd", - "systemctl is-active auditd" - ], - rollback_commands=[ - "systemctl stop auditd", - "systemctl disable auditd" + "augenrules --load", ], + verification_commands=["systemctl is-enabled auditd", "systemctl is-active auditd"], + rollback_commands=["systemctl stop auditd", "systemctl disable auditd"], estimated_duration=45, requires_reboot=False, dependencies=["audit"], - platforms=["rhel8", "rhel9"] + platforms=["rhel8", "rhel9"], ), - # Firewall Configuration "xccdf_mil.disa.stig_rule_SV-230515r793185_rule": AEGISMapping( scap_rule_id="xccdf_mil.disa.stig_rule_SV-230515r793185_rule", @@ -173,23 +163,19 @@ def _load_builtin_mappings(self) -> Dict[str, AEGISMapping]: "systemctl enable firewalld", "systemctl start firewalld", "firewall-cmd --set-default-zone=public", - "firewall-cmd --reload" + "firewall-cmd --reload", ], verification_commands=[ "systemctl is-enabled firewalld", "systemctl is-active firewalld", - "firewall-cmd --state" - ], - rollback_commands=[ - "systemctl stop firewalld", - "systemctl disable firewalld" + "firewall-cmd --state", ], + rollback_commands=["systemctl stop firewalld", "systemctl disable firewalld"], estimated_duration=60, requires_reboot=False, dependencies=["firewalld"], - platforms=["rhel8", "rhel9"] + platforms=["rhel8", "rhel9"], ), - # File Permissions "xccdf_mil.disa.stig_rule_SV-230279r792861_rule": AEGISMapping( scap_rule_id="xccdf_mil.disa.stig_rule_SV-230279r792861_rule", @@ -200,22 +186,20 @@ def _load_builtin_mappings(self) -> Dict[str, AEGISMapping]: "find /etc -type f -name '*.conf' -exec chmod 644 {} \\;", "find /etc -type d -exec chmod 755 {} \\;", "chmod 600 /etc/shadow", - "chmod 644 /etc/passwd" + "chmod 644 /etc/passwd", ], verification_commands=[ "stat -c '%a' /etc/shadow | grep -q '600'", - "stat -c '%a' /etc/passwd | grep -q '644'" - ], - rollback_commands=[ - "# No rollback for security permissions" + "stat -c '%a' /etc/passwd | grep -q '644'", ], + rollback_commands=["# No rollback for security permissions"], estimated_duration=120, requires_reboot=False, dependencies=[], - platforms=["rhel8", "rhel9", "ubuntu20", "ubuntu22"] - ) + platforms=["rhel8", "rhel9", "ubuntu20", "ubuntu22"], + ), } - + def _initialize_category_priorities(self) -> Dict[str, int]: """Initialize remediation category priorities""" return { @@ -226,52 +210,56 @@ def _initialize_category_priorities(self) -> Dict[str, int]: "system": 5, "service": 6, "permission": 7, - "configuration": 8 # Lowest priority + "configuration": 8, # Lowest priority } - + def get_aegis_mapping(self, scap_rule_id: str) -> Optional[AEGISMapping]: """Get AEGIS mapping for a SCAP rule""" return self.rule_mappings.get(scap_rule_id) - - def create_remediation_plan(self, scan_id: str, host_id: str, - failed_rules: List[Dict[str, str]], - platform: str = "rhel9") -> RemediationPlan: + + def create_remediation_plan( + self, + scan_id: str, + host_id: str, + failed_rules: List[Dict[str, str]], + platform: str = "rhel9", + ) -> RemediationPlan: """Create remediation plan for failed SCAP rules""" try: plan_id = f"plan_{scan_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}" logger.info(f"Creating remediation plan {plan_id} for {len(failed_rules)} failed rules") - + # Categorize rules and check for mappings rule_groups = {} remediable_rules = 0 total_duration = 0 requires_reboot = False all_dependencies = set() - + for rule in failed_rules: rule_id = rule.get("rule_id", "") mapping = self.get_aegis_mapping(rule_id) - + if mapping and platform in mapping.platforms: remediable_rules += 1 - + # Group by category if mapping.rule_category not in rule_groups: rule_groups[mapping.rule_category] = [] rule_groups[mapping.rule_category].append(mapping) - + # Accumulate requirements total_duration += mapping.estimated_duration if mapping.requires_reboot: requires_reboot = True all_dependencies.update(mapping.dependencies) - + # Determine execution order based on dependencies and priorities execution_order = self._determine_execution_order(rule_groups, all_dependencies) - + # Check if all dependencies can be resolved dependencies_resolved = self._check_dependencies(all_dependencies, platform) - + plan = RemediationPlan( plan_id=plan_id, scan_id=scan_id, @@ -282,76 +270,78 @@ def create_remediation_plan(self, scan_id: str, host_id: str, requires_reboot=requires_reboot, rule_groups=rule_groups, execution_order=execution_order, - dependencies_resolved=dependencies_resolved + dependencies_resolved=dependencies_resolved, ) - + # Save plan for tracking self._save_remediation_plan(plan) - - logger.info(f"Created remediation plan with {remediable_rules}/{len(failed_rules)} remediable rules") + + logger.info( + f"Created remediation plan with {remediable_rules}/{len(failed_rules)} remediable rules" + ) return plan - + except Exception as e: logger.error(f"Error creating remediation plan: {e}") raise - - def _determine_execution_order(self, rule_groups: Dict[str, List[AEGISMapping]], - dependencies: Set[str]) -> List[str]: + + def _determine_execution_order( + self, rule_groups: Dict[str, List[AEGISMapping]], dependencies: Set[str] + ) -> List[str]: """Determine optimal execution order for remediation""" execution_order = [] - + # Sort categories by priority sorted_categories = sorted( - rule_groups.keys(), - key=lambda x: self.category_priorities.get(x, 999) + rule_groups.keys(), key=lambda x: self.category_priorities.get(x, 999) ) - + # Process rules in priority order for category in sorted_categories: # Within category, sort by dependencies category_rules = rule_groups[category] - + # Rules with no dependencies first no_deps = [r for r in category_rules if not r.dependencies] with_deps = [r for r in category_rules if r.dependencies] - + # Add to execution order for rule in no_deps: execution_order.append(rule.scap_rule_id) - + for rule in with_deps: execution_order.append(rule.scap_rule_id) - + return execution_order - + def _check_dependencies(self, dependencies: Set[str], platform: str) -> bool: """Check if all dependencies can be resolved""" # This would check against package manager or system state # For now, we'll assume common dependencies are available - + common_packages = { "rhel8": ["audit", "firewalld", "libpwquality", "openssh-server"], "rhel9": ["audit", "firewalld", "libpwquality", "openssh-server"], "ubuntu20": ["auditd", "ufw", "libpam-pwquality", "openssh-server"], - "ubuntu22": ["auditd", "ufw", "libpam-pwquality", "openssh-server"] + "ubuntu22": ["auditd", "ufw", "libpam-pwquality", "openssh-server"], } - + platform_packages = set(common_packages.get(platform, [])) - + # Check if all dependencies are in common packages unresolved = dependencies - platform_packages - + if unresolved: logger.warning(f"Unresolved dependencies for {platform}: {unresolved}") return False - + return True - + def _save_remediation_plan(self, plan: RemediationPlan): """Save remediation plan to file""" try: plan_file = self.mappings_dir / f"{plan.plan_id}.json" - + plan_data = { "plan_id": plan.plan_id, "scan_id": plan.scan_id, @@ -363,9 +353,9 @@ def _save_remediation_plan(self, plan: RemediationPlan): "requires_reboot": plan.requires_reboot, "execution_order": plan.execution_order, "dependencies_resolved": plan.dependencies_resolved, - "rule_groups": {} + "rule_groups": {}, } - + # Convert AEGISMapping objects to dicts for category, mappings in plan.rule_groups.items(): plan_data["rule_groups"][category] = [ @@ -373,30 +363,30 @@ def _save_remediation_plan(self, plan: RemediationPlan): "scap_rule_id": m.scap_rule_id, "aegis_rule_id": m.aegis_rule_id, "estimated_duration": m.estimated_duration, - "requires_reboot": m.requires_reboot + "requires_reboot": m.requires_reboot, } for m in mappings ] - - with open(plan_file, 'w') as f: + + with open(plan_file, "w") as f: json.dump(plan_data, f, indent=2) - + logger.info(f"Saved remediation plan to {plan_file}") - + except Exception as e: logger.error(f"Error saving remediation plan: {e}") - + def generate_aegis_job_request(self, plan: RemediationPlan) -> Dict: """Generate AEGIS job request from remediation plan""" try: # Extract all AEGIS rule IDs in execution order aegis_rules = [] - + for scap_rule_id in plan.execution_order: mapping = self.get_aegis_mapping(scap_rule_id) if mapping: aegis_rules.append(mapping.aegis_rule_id) - + # Create AEGIS job request job_request = { "host_id": plan.host_id, @@ -406,7 +396,7 @@ def generate_aegis_job_request(self, plan: RemediationPlan) -> Dict: "force": False, "parallel": False, # Execute in order "continue_on_error": True, - "create_restore_point": True + "create_restore_point": True, }, "metadata": { "source": "openwatch", @@ -414,50 +404,49 @@ def generate_aegis_job_request(self, plan: RemediationPlan) -> Dict: "plan_id": plan.plan_id, "total_rules": plan.remediable_rules, "estimated_duration": plan.estimated_duration, - "requires_reboot": plan.requires_reboot - } + "requires_reboot": plan.requires_reboot, + }, } - + return job_request - + except Exception as e: logger.error(f"Error generating AEGIS job request: {e}") raise - - def map_aegis_results_to_scap(self, aegis_job_id: str, - aegis_results: Dict) -> Dict[str, str]: + + def map_aegis_results_to_scap(self, aegis_job_id: str, aegis_results: Dict) -> Dict[str, str]: """Map AEGIS remediation results back to SCAP rules""" try: scap_results = {} - + # Get job executions from AEGIS results executions = aegis_results.get("executions", []) - + for execution in executions: aegis_rule_id = execution.get("rule_id", "") status = execution.get("status", "unknown") - + # Find corresponding SCAP rule for scap_id, mapping in self.rule_mappings.items(): if mapping.aegis_rule_id == aegis_rule_id: scap_results[scap_id] = "pass" if status == "completed" else "fail" break - + return scap_results - + except Exception as e: logger.error(f"Error mapping AEGIS results: {e}") return {} - + def get_manual_remediation_steps(self, scap_rule_id: str) -> Optional[Dict]: """Get manual remediation steps for rules without AEGIS mapping""" try: # Check if we have a mapping mapping = self.get_aegis_mapping(scap_rule_id) - + if not mapping: return None - + return { "rule_id": scap_rule_id, "category": mapping.rule_category, @@ -465,7 +454,7 @@ def get_manual_remediation_steps(self, scap_rule_id: str) -> Optional[Dict]: { "description": f"Execute command: {cmd}", "command": cmd, - "type": "implementation" + "type": "implementation", } for cmd in mapping.implementation_commands ], @@ -473,28 +462,29 @@ def get_manual_remediation_steps(self, scap_rule_id: str) -> Optional[Dict]: { "description": f"Verify with: {cmd}", "command": cmd, - "expected_result": "Command should return 0 exit code" + "expected_result": "Command should return 0 exit code", } for cmd in mapping.verification_commands ], - "rollback": [ - { - "description": f"Rollback with: {cmd}", - "command": cmd - } - for cmd in mapping.rollback_commands - ] if mapping.rollback_commands else None + "rollback": ( + [ + {"description": f"Rollback with: {cmd}", "command": cmd} + for cmd in mapping.rollback_commands + ] + if mapping.rollback_commands + else None + ), } - + except Exception as e: logger.error(f"Error getting manual remediation steps: {e}") return None - + def export_mappings(self, format: str = "yaml") -> str: """Export all SCAP to AEGIS mappings""" try: export_data = {} - + for scap_id, mapping in self.rule_mappings.items(): export_data[scap_id] = { "aegis_rule_id": mapping.aegis_rule_id, @@ -506,16 +496,16 @@ def export_mappings(self, format: str = "yaml") -> str: "duration": mapping.estimated_duration, "requires_reboot": mapping.requires_reboot, "dependencies": mapping.dependencies, - "platforms": mapping.platforms + "platforms": mapping.platforms, } - + if format == "yaml": return yaml.dump(export_data, default_flow_style=False) elif format == "json": return json.dumps(export_data, indent=2) else: raise ValueError(f"Unsupported format: {format}") - + except Exception as e: logger.error(f"Error exporting mappings: {e}") - raise \ No newline at end of file + raise diff --git a/backend/app/services/scap_cli_scanner.py b/backend/app/services/scap_cli_scanner.py index 24e2f0e3..65276d51 100644 --- a/backend/app/services/scap_cli_scanner.py +++ b/backend/app/services/scap_cli_scanner.py @@ -2,6 +2,7 @@ OpenWatch CLI SCAP Scanner Service Enhanced SCAP scanning engine for command-line operations supporting 100+ parallel hosts """ + import os import asyncio import concurrent.futures @@ -19,47 +20,52 @@ class CLIScannerError(Exception): """Exception raised for CLI scanner specific errors""" + pass class SCAPCLIScanner: """Enhanced SCAP scanner optimized for CLI operations and parallel scanning""" - - def __init__(self, content_dir: str = "/app/data/scap", - results_dir: str = "/app/data/results", - max_parallel_scans: int = 100): + + def __init__( + self, + content_dir: str = "/app/data/scap", + results_dir: str = "/app/data/results", + max_parallel_scans: int = 100, + ): self.base_scanner = SCAPScanner(content_dir, results_dir) self.content_dir = Path(content_dir) self.results_dir = Path(results_dir) self.max_parallel_scans = max_parallel_scans - + # Ensure directories exist self.content_dir.mkdir(parents=True, exist_ok=True) self.results_dir.mkdir(parents=True, exist_ok=True) - - async def scan_single_host(self, host_config: Dict, profile_id: str, - content_path: str, rule_id: str = None) -> Dict: + + async def scan_single_host( + self, host_config: Dict, profile_id: str, content_path: str, rule_id: str = None + ) -> Dict: """ Scan a single host with given configuration - + Args: host_config: Host configuration dict with keys: hostname, port, username, auth_method, credential profile_id: SCAP profile identifier content_path: Path to SCAP content file rule_id: Optional specific rule to scan - + Returns: Scan results dictionary """ try: scan_id = str(uuid.uuid4()) - + logger.info(f"Starting scan {scan_id} for {host_config.get('hostname', 'localhost')}") - + # Determine if this is a local or remote scan - hostname = host_config.get('hostname', 'localhost') - - if hostname in ['localhost', '127.0.0.1', '::1']: + hostname = host_config.get("hostname", "localhost") + + if hostname in ["localhost", "127.0.0.1", "::1"]: # Local scan return await self._execute_local_scan_async( content_path, profile_id, scan_id, rule_id @@ -67,117 +73,136 @@ async def scan_single_host(self, host_config: Dict, profile_id: str, else: # Remote scan return await self._execute_remote_scan_async( - hostname, - host_config.get('port', 22), - host_config.get('username', 'root'), - host_config.get('auth_method', 'password'), - host_config.get('credential', ''), - content_path, - profile_id, - scan_id, - rule_id + hostname, + host_config.get("port", 22), + host_config.get("username", "root"), + host_config.get("auth_method", "password"), + host_config.get("credential", ""), + content_path, + profile_id, + scan_id, + rule_id, ) - + except Exception as e: logger.error(f"Scan failed for {host_config.get('hostname', 'unknown')}: {e}") return { - "scan_id": scan_id if 'scan_id' in locals() else str(uuid.uuid4()), - "hostname": host_config.get('hostname', 'unknown'), + "scan_id": scan_id if "scan_id" in locals() else str(uuid.uuid4()), + "hostname": host_config.get("hostname", "unknown"), "status": "error", "error": str(e), - "timestamp": datetime.now().isoformat() + "timestamp": datetime.now().isoformat(), } - - async def scan_multiple_hosts(self, hosts_configs: List[Dict], profile_id: str, - content_path: str, rule_id: str = None, - progress_callback = None) -> AsyncGenerator[Dict, None]: + + async def scan_multiple_hosts( + self, + hosts_configs: List[Dict], + profile_id: str, + content_path: str, + rule_id: str = None, + progress_callback=None, + ) -> AsyncGenerator[Dict, None]: """ Scan multiple hosts in parallel with configurable concurrency - + Args: hosts_configs: List of host configuration dictionaries - profile_id: SCAP profile identifier + profile_id: SCAP profile identifier content_path: Path to SCAP content file rule_id: Optional specific rule to scan progress_callback: Optional callback function for progress updates - + Yields: Individual scan results as they complete """ total_hosts = len(hosts_configs) completed_scans = 0 - - logger.info(f"Starting parallel scan of {total_hosts} hosts (max parallel: {self.max_parallel_scans})") - + + logger.info( + f"Starting parallel scan of {total_hosts} hosts (max parallel: {self.max_parallel_scans})" + ) + # Create semaphore to limit concurrent scans semaphore = asyncio.Semaphore(self.max_parallel_scans) - + async def scan_with_semaphore(host_config: Dict) -> Dict: async with semaphore: return await self.scan_single_host(host_config, profile_id, content_path, rule_id) - + # Create all scan tasks - scan_tasks = [ - scan_with_semaphore(host_config) - for host_config in hosts_configs - ] - + scan_tasks = [scan_with_semaphore(host_config) for host_config in hosts_configs] + # Process results as they complete for coro in asyncio.as_completed(scan_tasks): result = await coro completed_scans += 1 - + # Call progress callback if provided if progress_callback: progress_callback(completed_scans, total_hosts, result) - + # Log progress logger.info(f"Scan progress: {completed_scans}/{total_hosts} completed") - + yield result - - async def batch_scan_from_targets(self, targets: List[str], profile_id: str, - content_path: str, rule_id: str = None, - default_credentials: Dict = None) -> List[Dict]: + + async def batch_scan_from_targets( + self, + targets: List[str], + profile_id: str, + content_path: str, + rule_id: str = None, + default_credentials: Dict = None, + ) -> List[Dict]: """ Perform batch scan from a list of target hostnames/IPs - + Args: targets: List of hostnames or IP addresses profile_id: SCAP profile identifier - content_path: Path to SCAP content file + content_path: Path to SCAP content file rule_id: Optional specific rule to scan default_credentials: Default SSH credentials to use - + Returns: List of all scan results """ # Convert targets to host configs hosts_configs = [] - + for target in targets: host_config = { - 'hostname': target, - 'port': 22, - 'username': default_credentials.get('username', 'root') if default_credentials else 'root', - 'auth_method': default_credentials.get('auth_method', 'password') if default_credentials else 'password', - 'credential': default_credentials.get('credential', '') if default_credentials else '' + "hostname": target, + "port": 22, + "username": ( + default_credentials.get("username", "root") if default_credentials else "root" + ), + "auth_method": ( + default_credentials.get("auth_method", "password") + if default_credentials + else "password" + ), + "credential": ( + default_credentials.get("credential", "") if default_credentials else "" + ), } hosts_configs.append(host_config) - + # Collect all results results = [] - + def progress_callback(completed, total, result): - print(f"[OpenWatch] Progress: {completed}/{total} - {result.get('hostname', 'unknown')} completed") - + print( + f"[OpenWatch] Progress: {completed}/{total} - {result.get('hostname', 'unknown')} completed" + ) + async for result in self.scan_multiple_hosts( hosts_configs, profile_id, content_path, rule_id, progress_callback ): results.append(result) - + return results - + def get_available_profiles(self, content_path: str) -> List[Dict]: """Get available SCAP profiles from content file""" try: @@ -185,7 +210,7 @@ def get_available_profiles(self, content_path: str) -> List[Dict]: except Exception as e: logger.error(f"Failed to extract profiles: {e}") return [] - + def validate_content_file(self, content_path: str) -> bool: """Validate SCAP content file""" try: @@ -193,129 +218,155 @@ def validate_content_file(self, content_path: str) -> bool: return True except SCAPContentError: return False - + def get_default_content_path(self) -> str: """Get path to default SCAP content file""" # Look for common SCAP content files potential_files = [ self.content_dir / "ssg-rhel8-ds.xml", - self.content_dir / "ssg-ubuntu2004-ds.xml", + self.content_dir / "ssg-ubuntu2004-ds.xml", self.content_dir / "default-content.xml", - self.content_dir / "scap-content.xml" + self.content_dir / "scap-content.xml", ] - + for file_path in potential_files: if file_path.exists(): logger.info(f"Using default content file: {file_path}") return str(file_path) - + # If no content found, log warning logger.warning("No default SCAP content file found") return str(self.content_dir / "default-content.xml") - - async def _execute_local_scan_async(self, content_path: str, profile_id: str, - scan_id: str, rule_id: str = None) -> Dict: + + async def _execute_local_scan_async( + self, content_path: str, profile_id: str, scan_id: str, rule_id: str = None + ) -> Dict: """Execute local scan asynchronously""" loop = asyncio.get_event_loop() - + # Run the blocking scan in a thread pool with concurrent.futures.ThreadPoolExecutor() as executor: result = await loop.run_in_executor( executor, self.base_scanner.execute_local_scan, - content_path, profile_id, scan_id, rule_id + content_path, + profile_id, + scan_id, + rule_id, ) - + # Add CLI-specific metadata - result.update({ - "hostname": "localhost", - "status": "completed" if result.get("exit_code") == 0 else "failed", - "cli_scan": True - }) - + result.update( + { + "hostname": "localhost", + "status": "completed" if result.get("exit_code") == 0 else "failed", + "cli_scan": True, + } + ) + return result - - async def _execute_remote_scan_async(self, hostname: str, port: int, username: str, - auth_method: str, credential: str, content_path: str, - profile_id: str, scan_id: str, rule_id: str = None) -> Dict: + + async def _execute_remote_scan_async( + self, + hostname: str, + port: int, + username: str, + auth_method: str, + credential: str, + content_path: str, + profile_id: str, + scan_id: str, + rule_id: str = None, + ) -> Dict: """Execute remote scan asynchronously""" loop = asyncio.get_event_loop() - + # Run the blocking scan in a thread pool with concurrent.futures.ThreadPoolExecutor() as executor: result = await loop.run_in_executor( executor, self.base_scanner.execute_remote_scan, - hostname, port, username, auth_method, credential, - content_path, profile_id, scan_id, rule_id + hostname, + port, + username, + auth_method, + credential, + content_path, + profile_id, + scan_id, + rule_id, ) - + # Add CLI-specific metadata - result.update({ - "hostname": hostname, - "status": "completed" if result.get("exit_code") == 0 else "failed", - "cli_scan": True - }) - + result.update( + { + "hostname": hostname, + "status": "completed" if result.get("exit_code") == 0 else "failed", + "cli_scan": True, + } + ) + return result - + def generate_scan_summary(self, results: List[Dict]) -> Dict: """Generate summary statistics from scan results""" if not results: return {"error": "No scan results provided"} - + total_hosts = len(results) successful_scans = len([r for r in results if r.get("status") == "completed"]) failed_scans = len([r for r in results if r.get("status") == "failed"]) error_scans = len([r for r in results if r.get("status") == "error"]) - + # Aggregate rule statistics total_rules = sum(r.get("rules_total", 0) for r in results if "rules_total" in r) total_passed = sum(r.get("rules_passed", 0) for r in results if "rules_passed" in r) total_failed = sum(r.get("rules_failed", 0) for r in results if "rules_failed" in r) - + # Calculate average score scores = [r.get("score", 0) for r in results if "score" in r and r.get("score") is not None] avg_score = sum(scores) / len(scores) if scores else 0 - + return { "scan_summary": { "total_hosts": total_hosts, "successful_scans": successful_scans, "failed_scans": failed_scans, "error_scans": error_scans, - "success_rate": (successful_scans / total_hosts * 100) if total_hosts > 0 else 0 + "success_rate": (successful_scans / total_hosts * 100) if total_hosts > 0 else 0, }, "compliance_summary": { "total_rules_checked": total_rules, "total_rules_passed": total_passed, "total_rules_failed": total_failed, "average_compliance_score": avg_score, - "overall_compliance_rate": (total_passed / total_rules * 100) if total_rules > 0 else 0 + "overall_compliance_rate": ( + (total_passed / total_rules * 100) if total_rules > 0 else 0 + ), }, - "timestamp": datetime.now().isoformat() + "timestamp": datetime.now().isoformat(), } - + def export_results_json(self, results: List[Dict], output_file: str) -> bool: """Export scan results to JSON file""" try: output_path = Path(output_file) output_path.parent.mkdir(parents=True, exist_ok=True) - + # Include summary in export export_data = { "scan_results": results, "summary": self.generate_scan_summary(results), "exported_at": datetime.now().isoformat(), - "total_scans": len(results) + "total_scans": len(results), } - - with open(output_path, 'w') as f: + + with open(output_path, "w") as f: json.dump(export_data, f, indent=2) - + logger.info(f"Scan results exported to: {output_path}") return True - + except Exception as e: logger.error(f"Failed to export results: {e}") - return False \ No newline at end of file + return False diff --git a/backend/app/services/scap_datastream_processor.py b/backend/app/services/scap_datastream_processor.py index 1cd6c9a8..bf24f3d3 100644 --- a/backend/app/services/scap_datastream_processor.py +++ b/backend/app/services/scap_datastream_processor.py @@ -2,6 +2,7 @@ SCAP Data-Stream Processor Service Handles modern SCAP data-stream format processing with profile extraction """ + import os import subprocess import tempfile @@ -20,73 +21,77 @@ class DataStreamError(Exception): """Exception for data-stream processing errors""" + pass class SCAPDataStreamProcessor: """Process SCAP data-stream format content""" - + def __init__(self, content_dir: str = "/app/data/scap"): self.content_dir = Path(content_dir) self.content_dir.mkdir(parents=True, exist_ok=True) - + # Namespaces for SCAP data-stream self.namespaces = { - 'ds': 'http://scap.nist.gov/schema/scap/source/1.2', - 'xccdf': 'http://checklists.nist.gov/xccdf/1.2', - 'cpe': 'http://cpe.mitre.org/language/2.0', - 'oval': 'http://oval.mitre.org/XMLSchema/oval-definitions-5', - 'xlink': 'http://www.w3.org/1999/xlink' + "ds": "http://scap.nist.gov/schema/scap/source/1.2", + "xccdf": "http://checklists.nist.gov/xccdf/1.2", + "cpe": "http://cpe.mitre.org/language/2.0", + "oval": "http://oval.mitre.org/XMLSchema/oval-definitions-5", + "xlink": "http://www.w3.org/1999/xlink", } - + def validate_datastream(self, file_path: str) -> Dict: """Validate SCAP data-stream file and extract metadata""" try: logger.info(f"Validating SCAP data-stream: {file_path}") - + # First check if it's a ZIP file (common for DISA distributions) if zipfile.is_zipfile(file_path): return self._process_zip_content(file_path) - + # Use oscap to validate data-stream - result = subprocess.run([ - 'oscap', 'ds', 'sds-validate', file_path - ], capture_output=True, text=True, timeout=30) - + result = subprocess.run( + ["oscap", "ds", "sds-validate", file_path], + capture_output=True, + text=True, + timeout=30, + ) + if result.returncode != 0: # Try as XCCDF file if data-stream validation fails return self._validate_xccdf_file(file_path) - + # Extract data-stream info - info_result = subprocess.run([ - 'oscap', 'info', file_path - ], capture_output=True, text=True, timeout=30) - + info_result = subprocess.run( + ["oscap", "info", file_path], capture_output=True, text=True, timeout=30 + ) + if info_result.returncode != 0: raise DataStreamError(f"Failed to extract info: {info_result.stderr}") - + metadata = self._parse_oscap_info(info_result.stdout) - metadata['format'] = 'data-stream' - metadata['validation_status'] = 'valid' - + metadata["format"] = "data-stream" + metadata["validation_status"] = "valid" + # Extract additional metadata from XML xml_metadata = self._extract_xml_metadata(file_path) metadata.update(xml_metadata) - + logger.info(f"Data-stream validated successfully: {metadata.get('title', 'Unknown')}") return metadata - + except subprocess.TimeoutExpired: raise DataStreamError("Timeout validating data-stream") except Exception as e: logger.error(f"Error validating data-stream: {e}") raise DataStreamError(f"Validation failed: {str(e)}") - + def extract_profiles_with_metadata(self, file_path: str) -> List[Dict]: """Extract profiles with full metadata using oscap info --profiles""" try: logger.info(f"Extracting profiles from: {file_path}") - + # Handle ZIP files if zipfile.is_zipfile(file_path): with tempfile.TemporaryDirectory() as temp_dir: @@ -95,194 +100,206 @@ def extract_profiles_with_metadata(self, file_path: str) -> List[Dict]: return self.extract_profiles_with_metadata(extracted_file) else: return [] - + # Use oscap info --profiles for detailed profile information - result = subprocess.run([ - 'oscap', 'info', '--profiles', file_path - ], capture_output=True, text=True, timeout=30) - + result = subprocess.run( + ["oscap", "info", "--profiles", file_path], + capture_output=True, + text=True, + timeout=30, + ) + if result.returncode != 0: logger.warning(f"Failed to extract profiles: {result.stderr}") return [] - + profiles = self._parse_detailed_profiles(result.stdout) - + # Enhance with XML parsing for additional metadata enhanced_profiles = self._enhance_profiles_from_xml(file_path, profiles) - + logger.info(f"Extracted {len(enhanced_profiles)} profiles with metadata") return enhanced_profiles - + except Exception as e: logger.error(f"Error extracting profiles: {e}") return [] - + def extract_content_components(self, file_path: str) -> Dict: """Extract all components from SCAP content (data-streams, benchmarks, checks)""" try: components = { - 'data_streams': [], - 'benchmarks': [], - 'profiles': [], - 'cpe_lists': [], - 'oval_definitions': [], - 'rules': [] + "data_streams": [], + "benchmarks": [], + "profiles": [], + "cpe_lists": [], + "oval_definitions": [], + "rules": [], } - + # Parse XML to extract components tree = etree.parse(file_path) root = tree.getroot() - + # Check if it's a data-stream collection - if root.tag.endswith('data-stream-collection'): - components['format'] = 'data-stream-collection' - components['data_streams'] = self._extract_datastreams(root) - elif root.tag.endswith('Benchmark'): - components['format'] = 'xccdf-benchmark' - components['benchmarks'] = [self._extract_benchmark_info(root)] + if root.tag.endswith("data-stream-collection"): + components["format"] = "data-stream-collection" + components["data_streams"] = self._extract_datastreams(root) + elif root.tag.endswith("Benchmark"): + components["format"] = "xccdf-benchmark" + components["benchmarks"] = [self._extract_benchmark_info(root)] else: - components['format'] = 'unknown' - + components["format"] = "unknown" + # Extract profiles - components['profiles'] = self._extract_profiles_from_tree(root) - + components["profiles"] = self._extract_profiles_from_tree(root) + # Extract rules with metadata - components['rules'] = self._extract_rules_with_metadata(root) - + components["rules"] = self._extract_rules_with_metadata(root) + # Extract CPE and OVAL references - components['cpe_lists'] = self._extract_cpe_references(root) - components['oval_definitions'] = self._extract_oval_references(root) - + components["cpe_lists"] = self._extract_cpe_references(root) + components["oval_definitions"] = self._extract_oval_references(root) + return components - + except Exception as e: logger.error(f"Error extracting content components: {e}") raise DataStreamError(f"Failed to extract components: {str(e)}") - + def create_content_validation_report(self, file_path: str) -> Dict: """Create comprehensive validation report for SCAP content""" report = { - 'file_path': file_path, - 'timestamp': datetime.now().isoformat(), - 'validation_status': 'unknown', - 'errors': [], - 'warnings': [], - 'info': {}, - 'recommendations': [] + "file_path": file_path, + "timestamp": datetime.now().isoformat(), + "validation_status": "unknown", + "errors": [], + "warnings": [], + "info": {}, + "recommendations": [], } - + try: # Basic file checks file_stats = os.stat(file_path) - report['info']['file_size'] = file_stats.st_size - report['info']['file_hash'] = self._calculate_file_hash(file_path) - + report["info"]["file_size"] = file_stats.st_size + report["info"]["file_hash"] = self._calculate_file_hash(file_path) + # Validate with oscap - validation_result = subprocess.run([ - 'oscap', 'ds', 'sds-validate', file_path - ], capture_output=True, text=True, timeout=60) - + validation_result = subprocess.run( + ["oscap", "ds", "sds-validate", file_path], + capture_output=True, + text=True, + timeout=60, + ) + if validation_result.returncode == 0: - report['validation_status'] = 'valid_datastream' + report["validation_status"] = "valid_datastream" else: # Try XCCDF validation - xccdf_result = subprocess.run([ - 'oscap', 'xccdf', 'validate', file_path - ], capture_output=True, text=True, timeout=60) - + xccdf_result = subprocess.run( + ["oscap", "xccdf", "validate", file_path], + capture_output=True, + text=True, + timeout=60, + ) + if xccdf_result.returncode == 0: - report['validation_status'] = 'valid_xccdf' + report["validation_status"] = "valid_xccdf" else: - report['validation_status'] = 'invalid' - report['errors'].append(validation_result.stderr) - + report["validation_status"] = "invalid" + report["errors"].append(validation_result.stderr) + # Extract content info - info_result = subprocess.run([ - 'oscap', 'info', file_path - ], capture_output=True, text=True, timeout=30) - + info_result = subprocess.run( + ["oscap", "info", file_path], capture_output=True, text=True, timeout=30 + ) + if info_result.returncode == 0: - report['info']['content_metadata'] = self._parse_oscap_info(info_result.stdout) - + report["info"]["content_metadata"] = self._parse_oscap_info(info_result.stdout) + # Check for common issues self._check_common_issues(file_path, report) - + # Generate recommendations self._generate_recommendations(report) - + return report - + except Exception as e: - report['validation_status'] = 'error' - report['errors'].append(f"Validation error: {str(e)}") + report["validation_status"] = "error" + report["errors"].append(f"Validation error: {str(e)}") return report - + def _process_zip_content(self, zip_path: str) -> Dict: """Process SCAP content from ZIP file""" try: with tempfile.TemporaryDirectory() as temp_dir: - with zipfile.ZipFile(zip_path, 'r') as zip_file: + with zipfile.ZipFile(zip_path, "r") as zip_file: # Extract all files zip_file.extractall(temp_dir) - + # Find SCAP content files scap_files = [] for root, dirs, files in os.walk(temp_dir): for file in files: - if file.endswith(('.xml', '.scap')): + if file.endswith((".xml", ".scap")): full_path = os.path.join(root, file) # Skip small files (likely metadata) if os.path.getsize(full_path) > 1000: scap_files.append(full_path) - + if not scap_files: raise DataStreamError("No SCAP content found in ZIP file") - + # Process the main SCAP file (usually the largest) main_file = max(scap_files, key=os.path.getsize) - + # Validate the extracted file metadata = self.validate_datastream(main_file) - metadata['source_format'] = 'zip' - metadata['extracted_from'] = os.path.basename(zip_path) - + metadata["source_format"] = "zip" + metadata["extracted_from"] = os.path.basename(zip_path) + return metadata - + except Exception as e: logger.error(f"Error processing ZIP content: {e}") raise DataStreamError(f"Failed to process ZIP: {str(e)}") - + def _validate_xccdf_file(self, file_path: str) -> Dict: """Validate as XCCDF file if not a data-stream""" try: - result = subprocess.run([ - 'oscap', 'xccdf', 'validate', file_path - ], capture_output=True, text=True, timeout=30) - + result = subprocess.run( + ["oscap", "xccdf", "validate", file_path], + capture_output=True, + text=True, + timeout=30, + ) + if result.returncode != 0: raise DataStreamError(f"Invalid XCCDF content: {result.stderr}") - + # Extract XCCDF info - info_result = subprocess.run([ - 'oscap', 'info', file_path - ], capture_output=True, text=True, timeout=30) - + info_result = subprocess.run( + ["oscap", "info", file_path], capture_output=True, text=True, timeout=30 + ) + metadata = self._parse_oscap_info(info_result.stdout) - metadata['format'] = 'xccdf' - metadata['validation_status'] = 'valid' - + metadata["format"] = "xccdf" + metadata["validation_status"] = "valid" + return metadata - + except Exception as e: raise DataStreamError(f"XCCDF validation failed: {str(e)}") - + def _extract_scap_from_zip(self, zip_path: str, extract_dir: str) -> Optional[str]: """Extract SCAP content file from ZIP""" try: - with zipfile.ZipFile(zip_path, 'r') as zip_file: + with zipfile.ZipFile(zip_path, "r") as zip_file: # Look for SCAP data-stream or XCCDF files - scap_patterns = ['-scap_', '_datastream', '-xccdf', '.xml', '.scap'] - + scap_patterns = ["-scap_", "_datastream", "-xccdf", ".xml", ".scap"] + for file_info in zip_file.filelist: filename = file_info.filename.lower() if any(pattern in filename for pattern in scap_patterns): @@ -290,343 +307,356 @@ def _extract_scap_from_zip(self, zip_path: str, extract_dir: str) -> Optional[st if not file_info.is_dir() and file_info.file_size > 1000: extracted_path = zip_file.extract(file_info, extract_dir) return extracted_path - + return None - + except Exception as e: logger.error(f"Error extracting from ZIP: {e}") return None - + def _parse_oscap_info(self, info_output: str) -> Dict: """Parse oscap info command output""" info = {} - lines = info_output.split('\n') - + lines = info_output.split("\n") + for line in lines: line = line.strip() - if ':' in line: - key, value = line.split(':', 1) - key = key.strip().lower().replace(' ', '_').replace('-', '_') + if ":" in line: + key, value = line.split(":", 1) + key = key.strip().lower().replace(" ", "_").replace("-", "_") value = value.strip() - + # Handle special cases - if key == 'profiles': + if key == "profiles": continue # Profiles are parsed separately - elif key == 'referenced_check_files': - info[key] = [v.strip() for v in value.split(',') if v.strip()] + elif key == "referenced_check_files": + info[key] = [v.strip() for v in value.split(",") if v.strip()] else: info[key] = value - + return info - + def _parse_detailed_profiles(self, profiles_output: str) -> List[Dict]: """Parse detailed profiles from oscap info --profiles output""" profiles = [] current_profile = None - - lines = profiles_output.split('\n') - + + lines = profiles_output.split("\n") + for line in lines: line = line.strip() - - if line.startswith('Profile:'): + + if line.startswith("Profile:"): # Save previous profile if exists if current_profile: profiles.append(current_profile) - + # Extract profile ID (format: "Profile: profile_id") - profile_id = line.split(':', 1)[1].strip() + profile_id = line.split(":", 1)[1].strip() current_profile = { - 'id': profile_id, - 'title': '', - 'description': '', - 'extends': None, - 'selected_rules': [], - 'metadata': {} + "id": profile_id, + "title": "", + "description": "", + "extends": None, + "selected_rules": [], + "metadata": {}, } - - elif line.startswith('Title:') and current_profile: - current_profile['title'] = line.split(':', 1)[1].strip() - - elif line.startswith('Description:') and current_profile: + + elif line.startswith("Title:") and current_profile: + current_profile["title"] = line.split(":", 1)[1].strip() + + elif line.startswith("Description:") and current_profile: # Description might span multiple lines - desc_start = line.split(':', 1)[1].strip() - current_profile['description'] = desc_start - - elif line.startswith('Extends:') and current_profile: - current_profile['extends'] = line.split(':', 1)[1].strip() - - elif line and current_profile and not any(line.startswith(prefix) - for prefix in ['Profile:', 'Title:', 'Description:', 'Extends:']): + desc_start = line.split(":", 1)[1].strip() + current_profile["description"] = desc_start + + elif line.startswith("Extends:") and current_profile: + current_profile["extends"] = line.split(":", 1)[1].strip() + + elif ( + line + and current_profile + and not any( + line.startswith(prefix) + for prefix in ["Profile:", "Title:", "Description:", "Extends:"] + ) + ): # Continue description if no new field - if current_profile['description']: - current_profile['description'] += ' ' + line - + if current_profile["description"]: + current_profile["description"] += " " + line + # Don't forget the last profile if current_profile: profiles.append(current_profile) - + return profiles - + def _enhance_profiles_from_xml(self, file_path: str, profiles: List[Dict]) -> List[Dict]: """Enhance profile information by parsing XML directly""" try: tree = etree.parse(file_path) root = tree.getroot() - + # Create profile lookup - profile_lookup = {p['id']: p for p in profiles} - + profile_lookup = {p["id"]: p for p in profiles} + # Find all Profile elements - profile_elements = root.xpath('.//xccdf:Profile', namespaces=self.namespaces) - + profile_elements = root.xpath(".//xccdf:Profile", namespaces=self.namespaces) + for profile_elem in profile_elements: - profile_id = profile_elem.get('id', '') - + profile_id = profile_elem.get("id", "") + if profile_id in profile_lookup: profile = profile_lookup[profile_id] - + # Extract additional metadata - profile['metadata']['severity'] = profile_elem.get('severity', 'unknown') - + profile["metadata"]["severity"] = profile_elem.get("severity", "unknown") + # Extract platform information - platforms = profile_elem.xpath('.//xccdf:platform', namespaces=self.namespaces) - profile['metadata']['platforms'] = [p.get('idref', '') for p in platforms] - + platforms = profile_elem.xpath(".//xccdf:platform", namespaces=self.namespaces) + profile["metadata"]["platforms"] = [p.get("idref", "") for p in platforms] + # Count selected rules - selections = profile_elem.xpath('.//xccdf:select', namespaces=self.namespaces) - profile['metadata']['rule_count'] = len([s for s in selections if s.get('selected') == 'true']) - + selections = profile_elem.xpath(".//xccdf:select", namespaces=self.namespaces) + profile["metadata"]["rule_count"] = len( + [s for s in selections if s.get("selected") == "true"] + ) + # Extract profile notes or remarks - remarks = profile_elem.xpath('.//xccdf:remark', namespaces=self.namespaces) + remarks = profile_elem.xpath(".//xccdf:remark", namespaces=self.namespaces) if remarks: - profile['metadata']['remarks'] = [r.text for r in remarks if r.text] + profile["metadata"]["remarks"] = [r.text for r in remarks if r.text] else: # Profile found in XML but not in oscap output new_profile = self._extract_profile_from_element(profile_elem) profiles.append(new_profile) - + return profiles - + except Exception as e: logger.warning(f"Could not enhance profiles from XML: {e}") return profiles - + def _extract_profile_from_element(self, profile_elem) -> Dict: """Extract profile information from XML element""" profile = { - 'id': profile_elem.get('id', ''), - 'title': '', - 'description': '', - 'extends': profile_elem.get('extends', None), - 'selected_rules': [], - 'metadata': {} + "id": profile_elem.get("id", ""), + "title": "", + "description": "", + "extends": profile_elem.get("extends", None), + "selected_rules": [], + "metadata": {}, } - + # Extract title - title_elem = profile_elem.find('xccdf:title', self.namespaces) + title_elem = profile_elem.find("xccdf:title", self.namespaces) if title_elem is not None and title_elem.text: - profile['title'] = title_elem.text - + profile["title"] = title_elem.text + # Extract description - desc_elem = profile_elem.find('xccdf:description', self.namespaces) + desc_elem = profile_elem.find("xccdf:description", self.namespaces) if desc_elem is not None: - profile['description'] = self._extract_text_content(desc_elem) - + profile["description"] = self._extract_text_content(desc_elem) + # Extract selected rules - selections = profile_elem.xpath('.//xccdf:select[@selected="true"]', - namespaces=self.namespaces) - profile['selected_rules'] = [s.get('idref', '') for s in selections] - + selections = profile_elem.xpath( + './/xccdf:select[@selected="true"]', namespaces=self.namespaces + ) + profile["selected_rules"] = [s.get("idref", "") for s in selections] + return profile - + def _extract_text_content(self, element) -> str: """Extract clean text content from XML element""" if element is None: return "" - + # Get text content, handling HTML tags - text = etree.tostring(element, method='text', encoding='unicode').strip() - + text = etree.tostring(element, method="text", encoding="unicode").strip() + # Clean up whitespace import re - text = re.sub(r'\s+', ' ', text).strip() - + + text = re.sub(r"\s+", " ", text).strip() + return text - + def _extract_xml_metadata(self, file_path: str) -> Dict: """Extract additional metadata from XML structure""" metadata = {} - + try: tree = etree.parse(file_path) root = tree.getroot() - + # Determine content type - if root.tag.endswith('data-stream-collection'): - metadata['content_type'] = 'SCAP Data Stream Collection' - metadata['scap_version'] = root.get('schematron-version', '1.2') - + if root.tag.endswith("data-stream-collection"): + metadata["content_type"] = "SCAP Data Stream Collection" + metadata["scap_version"] = root.get("schematron-version", "1.2") + # Count data streams - streams = root.xpath('.//ds:data-stream', namespaces=self.namespaces) - metadata['data_stream_count'] = len(streams) - - elif root.tag.endswith('Benchmark'): - metadata['content_type'] = 'XCCDF Benchmark' - metadata['benchmark_id'] = root.get('id', '') - metadata['benchmark_version'] = root.get('version', '') - + streams = root.xpath(".//ds:data-stream", namespaces=self.namespaces) + metadata["data_stream_count"] = len(streams) + + elif root.tag.endswith("Benchmark"): + metadata["content_type"] = "XCCDF Benchmark" + metadata["benchmark_id"] = root.get("id", "") + metadata["benchmark_version"] = root.get("version", "") + # Extract status - status_elem = root.find('.//xccdf:status', self.namespaces) + status_elem = root.find(".//xccdf:status", self.namespaces) if status_elem is not None: - metadata['status'] = status_elem.text - metadata['status_date'] = status_elem.get('date', '') - + metadata["status"] = status_elem.text + metadata["status_date"] = status_elem.get("date", "") + # Extract metadata elements - metadata_elem = root.find('.//xccdf:metadata', self.namespaces) + metadata_elem = root.find(".//xccdf:metadata", self.namespaces) if metadata_elem is not None: # Extract DC metadata if present - dc_elements = metadata_elem.xpath('.//*[namespace-uri()="http://purl.org/dc/elements/1.1/"]') + dc_elements = metadata_elem.xpath( + './/*[namespace-uri()="http://purl.org/dc/elements/1.1/"]' + ) for dc_elem in dc_elements: - tag_name = dc_elem.tag.split('}')[-1] - metadata[f'dc_{tag_name}'] = dc_elem.text - + tag_name = dc_elem.tag.split("}")[-1] + metadata[f"dc_{tag_name}"] = dc_elem.text + return metadata - + except Exception as e: logger.warning(f"Could not extract XML metadata: {e}") return metadata - + def _extract_datastreams(self, root) -> List[Dict]: """Extract data-stream information""" datastreams = [] - - ds_elements = root.xpath('.//ds:data-stream', namespaces=self.namespaces) + + ds_elements = root.xpath(".//ds:data-stream", namespaces=self.namespaces) for ds_elem in ds_elements: ds_info = { - 'id': ds_elem.get('id', ''), - 'timestamp': ds_elem.get('timestamp', ''), - 'version': ds_elem.get('scap-version', '1.2'), - 'components': [] + "id": ds_elem.get("id", ""), + "timestamp": ds_elem.get("timestamp", ""), + "version": ds_elem.get("scap-version", "1.2"), + "components": [], } - + # Extract component references - components = ds_elem.xpath('.//ds:component-ref', namespaces=self.namespaces) + components = ds_elem.xpath(".//ds:component-ref", namespaces=self.namespaces) for comp in components: - ds_info['components'].append({ - 'id': comp.get('id', ''), - 'href': comp.get('{http://www.w3.org/1999/xlink}href', '') - }) - + ds_info["components"].append( + { + "id": comp.get("id", ""), + "href": comp.get("{http://www.w3.org/1999/xlink}href", ""), + } + ) + datastreams.append(ds_info) - + return datastreams - + def _extract_benchmark_info(self, benchmark_elem) -> Dict: """Extract benchmark information""" benchmark = { - 'id': benchmark_elem.get('id', ''), - 'version': benchmark_elem.get('version', ''), - 'status': '', - 'title': '', - 'description': '' + "id": benchmark_elem.get("id", ""), + "version": benchmark_elem.get("version", ""), + "status": "", + "title": "", + "description": "", } - + # Extract title - title_elem = benchmark_elem.find('.//xccdf:title', self.namespaces) + title_elem = benchmark_elem.find(".//xccdf:title", self.namespaces) if title_elem is not None: - benchmark['title'] = title_elem.text or '' - + benchmark["title"] = title_elem.text or "" + # Extract description - desc_elem = benchmark_elem.find('.//xccdf:description', self.namespaces) + desc_elem = benchmark_elem.find(".//xccdf:description", self.namespaces) if desc_elem is not None: - benchmark['description'] = self._extract_text_content(desc_elem) - + benchmark["description"] = self._extract_text_content(desc_elem) + # Extract status - status_elem = benchmark_elem.find('.//xccdf:status', self.namespaces) + status_elem = benchmark_elem.find(".//xccdf:status", self.namespaces) if status_elem is not None: - benchmark['status'] = status_elem.text or '' - + benchmark["status"] = status_elem.text or "" + return benchmark - + def _extract_profiles_from_tree(self, root) -> List[Dict]: """Extract all profiles from XML tree""" profiles = [] - - profile_elements = root.xpath('.//xccdf:Profile', namespaces=self.namespaces) + + profile_elements = root.xpath(".//xccdf:Profile", namespaces=self.namespaces) for profile_elem in profile_elements: profiles.append(self._extract_profile_from_element(profile_elem)) - + return profiles - + def _extract_rules_with_metadata(self, root) -> List[Dict]: """Extract rules with compliance metadata""" rules = [] - - rule_elements = root.xpath('.//xccdf:Rule', namespaces=self.namespaces) + + rule_elements = root.xpath(".//xccdf:Rule", namespaces=self.namespaces) for rule_elem in rule_elements[:10]: # Limit to first 10 for performance rule = { - 'id': rule_elem.get('id', ''), - 'severity': rule_elem.get('severity', 'unknown'), - 'title': '', - 'description': '', - 'rationale': '', - 'references': [] + "id": rule_elem.get("id", ""), + "severity": rule_elem.get("severity", "unknown"), + "title": "", + "description": "", + "rationale": "", + "references": [], } - + # Extract title - title_elem = rule_elem.find('.//xccdf:title', self.namespaces) + title_elem = rule_elem.find(".//xccdf:title", self.namespaces) if title_elem is not None: - rule['title'] = title_elem.text or '' - + rule["title"] = title_elem.text or "" + # Extract description - desc_elem = rule_elem.find('.//xccdf:description', self.namespaces) + desc_elem = rule_elem.find(".//xccdf:description", self.namespaces) if desc_elem is not None: - rule['description'] = self._extract_text_content(desc_elem)[:200] + '...' - + rule["description"] = self._extract_text_content(desc_elem)[:200] + "..." + # Extract rationale - rat_elem = rule_elem.find('.//xccdf:rationale', self.namespaces) + rat_elem = rule_elem.find(".//xccdf:rationale", self.namespaces) if rat_elem is not None: - rule['rationale'] = self._extract_text_content(rat_elem)[:200] + '...' - + rule["rationale"] = self._extract_text_content(rat_elem)[:200] + "..." + # Extract references (CCE, CCI, etc.) - ref_elements = rule_elem.xpath('.//xccdf:reference', namespaces=self.namespaces) + ref_elements = rule_elem.xpath(".//xccdf:reference", namespaces=self.namespaces) for ref_elem in ref_elements: - rule['references'].append({ - 'href': ref_elem.get('href', ''), - 'text': ref_elem.text or '' - }) - + rule["references"].append( + {"href": ref_elem.get("href", ""), "text": ref_elem.text or ""} + ) + rules.append(rule) - + return rules - + def _extract_cpe_references(self, root) -> List[str]: """Extract CPE (platform) references""" cpe_refs = set() - + # Look for platform elements - platform_elements = root.xpath('.//xccdf:platform', namespaces=self.namespaces) + platform_elements = root.xpath(".//xccdf:platform", namespaces=self.namespaces) for platform in platform_elements: - cpe_ref = platform.get('idref', '') + cpe_ref = platform.get("idref", "") if cpe_ref: cpe_refs.add(cpe_ref) - + return list(cpe_refs) - + def _extract_oval_references(self, root) -> List[str]: """Extract OVAL definition references""" oval_refs = set() - + # Look for check-content-ref elements - check_refs = root.xpath('.//xccdf:check-content-ref', namespaces=self.namespaces) + check_refs = root.xpath(".//xccdf:check-content-ref", namespaces=self.namespaces) for check_ref in check_refs: - href = check_ref.get('href', '') - if 'oval' in href.lower(): + href = check_ref.get("href", "") + if "oval" in href.lower(): oval_refs.add(href) - + return list(oval_refs) - + def _calculate_file_hash(self, file_path: str) -> str: """Calculate SHA-256 hash of file""" sha256_hash = hashlib.sha256() @@ -634,51 +664,61 @@ def _calculate_file_hash(self, file_path: str) -> str: for byte_block in iter(lambda: f.read(4096), b""): sha256_hash.update(byte_block) return sha256_hash.hexdigest() - + def _check_common_issues(self, file_path: str, report: Dict): """Check for common SCAP content issues""" try: tree = etree.parse(file_path) root = tree.getroot() - + # Check for missing profiles - profiles = root.xpath('.//xccdf:Profile', namespaces=self.namespaces) + profiles = root.xpath(".//xccdf:Profile", namespaces=self.namespaces) if not profiles: - report['warnings'].append("No profiles found in content") - + report["warnings"].append("No profiles found in content") + # Check for platform specifications - platforms = root.xpath('.//xccdf:platform', namespaces=self.namespaces) + platforms = root.xpath(".//xccdf:platform", namespaces=self.namespaces) if not platforms: - report['warnings'].append("No platform specifications found") - + report["warnings"].append("No platform specifications found") + # Check for large rule sets - rules = root.xpath('.//xccdf:Rule', namespaces=self.namespaces) + rules = root.xpath(".//xccdf:Rule", namespaces=self.namespaces) if len(rules) > 1000: - report['info']['rule_count'] = len(rules) - report['warnings'].append(f"Large rule set ({len(rules)} rules) may impact performance") - + report["info"]["rule_count"] = len(rules) + report["warnings"].append( + f"Large rule set ({len(rules)} rules) may impact performance" + ) + # Check for OVAL content references - oval_refs = root.xpath('.//xccdf:check-content-ref[@href]', namespaces=self.namespaces) + oval_refs = root.xpath(".//xccdf:check-content-ref[@href]", namespaces=self.namespaces) if oval_refs: - report['info']['has_oval_content'] = True - report['info']['oval_ref_count'] = len(oval_refs) - + report["info"]["has_oval_content"] = True + report["info"]["oval_ref_count"] = len(oval_refs) + except Exception as e: - report['warnings'].append(f"Could not perform content checks: {str(e)}") - + report["warnings"].append(f"Could not perform content checks: {str(e)}") + def _generate_recommendations(self, report: Dict): """Generate recommendations based on validation report""" - if report['validation_status'] == 'valid_datastream': - report['recommendations'].append("Content is valid SCAP data-stream format") - elif report['validation_status'] == 'valid_xccdf': - report['recommendations'].append("Consider converting to SCAP data-stream format for better tool support") - - if report.get('warnings'): - if "No profiles found" in str(report['warnings']): - report['recommendations'].append("Define profiles to group rules for different use cases") - - if "Large rule set" in str(report['warnings']): - report['recommendations'].append("Consider creating focused profiles for different compliance requirements") - - if report['info'].get('has_oval_content'): - report['recommendations'].append("Ensure OVAL definitions are accessible for automated checking") \ No newline at end of file + if report["validation_status"] == "valid_datastream": + report["recommendations"].append("Content is valid SCAP data-stream format") + elif report["validation_status"] == "valid_xccdf": + report["recommendations"].append( + "Consider converting to SCAP data-stream format for better tool support" + ) + + if report.get("warnings"): + if "No profiles found" in str(report["warnings"]): + report["recommendations"].append( + "Define profiles to group rules for different use cases" + ) + + if "Large rule set" in str(report["warnings"]): + report["recommendations"].append( + "Consider creating focused profiles for different compliance requirements" + ) + + if report["info"].get("has_oval_content"): + report["recommendations"].append( + "Ensure OVAL definitions are accessible for automated checking" + ) diff --git a/backend/app/services/scap_repository.py b/backend/app/services/scap_repository.py index 92dddfe2..af9ed93e 100644 --- a/backend/app/services/scap_repository.py +++ b/backend/app/services/scap_repository.py @@ -2,6 +2,7 @@ SCAP Repository Management Service Handles automatic downloading and synchronization of SCAP content from various repositories """ + import asyncio import aiohttp import logging @@ -18,6 +19,7 @@ logger = logging.getLogger(__name__) + @dataclass class RepositoryConfig: id: str @@ -29,6 +31,7 @@ class RepositoryConfig: last_sync: Optional[datetime] = None credentials: Optional[Dict[str, str]] = None + @dataclass class ContentMetadata: name: str @@ -44,19 +47,20 @@ class ContentMetadata: size_bytes: int last_modified: datetime + class SCAPRepositoryManager: """Manages SCAP content repositories and automatic synchronization""" - + def __init__(self): self.repositories: Dict[str, RepositoryConfig] = {} self.sync_running = False self.last_global_sync: Optional[datetime] = None self.content_cache_dir = Path("/app/data/scap_cache") self.content_cache_dir.mkdir(parents=True, exist_ok=True) - + # Initialize default repositories self._setup_default_repositories() - + def _setup_default_repositories(self): """Setup default SCAP content repositories""" default_repos = [ @@ -66,15 +70,15 @@ def _setup_default_repositories(self): url="https://ncp.nist.gov/repository", type="official", enabled=True, - os_families=["rhel", "ubuntu", "windows"] + os_families=["rhel", "ubuntu", "windows"], ), RepositoryConfig( id="redhat_security", name="Red Hat Security Data", url="https://access.redhat.com/security/data/oval", - type="official", + type="official", enabled=True, - os_families=["rhel", "centos"] + os_families=["rhel", "centos"], ), RepositoryConfig( id="ubuntu_security", @@ -82,33 +86,35 @@ def _setup_default_repositories(self): url="https://people.canonical.com/~ubuntu-security/oval", type="official", enabled=True, - os_families=["ubuntu", "debian"] - ) + os_families=["ubuntu", "debian"], + ), ] - + for repo in default_repos: self.repositories[repo.id] = repo - - async def sync_repositories(self, db: Session, repository_ids: Optional[List[str]] = None) -> Dict[str, str]: + + async def sync_repositories( + self, db: Session, repository_ids: Optional[List[str]] = None + ) -> Dict[str, str]: """ Synchronize content from repositories Returns dict of repository_id -> status """ if self.sync_running: return {"error": "Sync already in progress"} - + self.sync_running = True results = {} - + try: repos_to_sync = ( [self.repositories[rid] for rid in repository_ids if rid in self.repositories] if repository_ids else [repo for repo in self.repositories.values() if repo.enabled] ) - + logger.info(f"Starting sync for {len(repos_to_sync)} repositories") - + for repo in repos_to_sync: try: result = await self._sync_repository(db, repo) @@ -117,28 +123,28 @@ async def sync_repositories(self, db: Session, repository_ids: Optional[List[str except Exception as e: logger.error(f"Failed to sync repository {repo.name}: {e}") results[repo.id] = f"error: {str(e)}" - + self.last_global_sync = datetime.utcnow() - + finally: self.sync_running = False - + return results - + async def _sync_repository(self, db: Session, repo: RepositoryConfig) -> str: """Sync a single repository""" logger.info(f"Syncing repository: {repo.name}") - + # Get repository catalog/index content_list = await self._fetch_repository_catalog(repo) - + new_content = 0 updated_content = 0 - + for content_meta in content_list: # Check if content already exists existing = await self._get_existing_content(db, content_meta) - + if not existing: # Download and import new content if await self._download_and_import_content(db, repo, content_meta): @@ -147,16 +153,16 @@ async def _sync_repository(self, db: Session, repo: RepositoryConfig) -> str: # Update existing content if await self._update_existing_content(db, repo, content_meta, existing): updated_content += 1 - + return f"synced: {new_content} new, {updated_content} updated" - + async def _fetch_repository_catalog(self, repo: RepositoryConfig) -> List[ContentMetadata]: """Fetch the catalog/index of available content from repository""" - + # This is a simplified implementation - real repositories would have # standardized APIs or catalog formats catalog_url = f"{repo.url}/catalog.json" - + try: async with aiohttp.ClientSession() as session: async with session.get(catalog_url, timeout=30) as response: @@ -169,34 +175,38 @@ async def _fetch_repository_catalog(self, repo: RepositoryConfig) -> List[Conten except Exception as e: logger.warning(f"Failed to fetch catalog for {repo.name}: {e}") return await self._discover_content(repo) - - def _parse_catalog_data(self, catalog_data: Dict, repo: RepositoryConfig) -> List[ContentMetadata]: + + def _parse_catalog_data( + self, catalog_data: Dict, repo: RepositoryConfig + ) -> List[ContentMetadata]: """Parse catalog JSON into ContentMetadata objects""" content_list = [] - - for item in catalog_data.get('content', []): + + for item in catalog_data.get("content", []): # Skip content not matching repository's OS families - if item.get('os_family') not in repo.os_families: + if item.get("os_family") not in repo.os_families: continue - + content_meta = ContentMetadata( - name=item['name'], - filename=item['filename'], - content_type=item.get('content_type', 'datastream'), - description=item.get('description', ''), - version=item.get('version', '1.0'), - os_family=item['os_family'], - os_version=item.get('os_version', ''), - compliance_framework=item.get('compliance_framework', 'unknown'), + name=item["name"], + filename=item["filename"], + content_type=item.get("content_type", "datastream"), + description=item.get("description", ""), + version=item.get("version", "1.0"), + os_family=item["os_family"], + os_version=item.get("os_version", ""), + compliance_framework=item.get("compliance_framework", "unknown"), url=f"{repo.url}/{item['filename']}", - checksum=item.get('checksum', ''), - size_bytes=item.get('size_bytes', 0), - last_modified=datetime.fromisoformat(item.get('last_modified', datetime.utcnow().isoformat())) + checksum=item.get("checksum", ""), + size_bytes=item.get("size_bytes", 0), + last_modified=datetime.fromisoformat( + item.get("last_modified", datetime.utcnow().isoformat()) + ), ) content_list.append(content_meta) - + return content_list - + async def _discover_content(self, repo: RepositoryConfig) -> List[ContentMetadata]: """Fallback content discovery for repositories without catalogs""" # This would implement various discovery methods: @@ -204,9 +214,9 @@ async def _discover_content(self, repo: RepositoryConfig) -> List[ContentMetadat # - RSS/Atom feeds # - API endpoints # - File pattern matching - + logger.info(f"Discovering content for {repo.name} (no catalog available)") - + # Mock discovery for demonstration if "nist" in repo.url.lower(): return await self._discover_nist_content(repo) @@ -214,9 +224,9 @@ async def _discover_content(self, repo: RepositoryConfig) -> List[ContentMetadat return await self._discover_redhat_content(repo) elif "ubuntu" in repo.url.lower(): return await self._discover_ubuntu_content(repo) - + return [] - + async def _discover_nist_content(self, repo: RepositoryConfig) -> List[ContentMetadata]: """Discover NIST SCAP content""" # Mock NIST content discovery @@ -233,7 +243,7 @@ async def _discover_nist_content(self, repo: RepositoryConfig) -> List[ContentMe url=f"{repo.url}/rhel9/U_RHEL_9_STIG_V1R5_Manual-xccdf.xml", checksum="abc123...", size_bytes=2048576, - last_modified=datetime.utcnow() - timedelta(days=7) + last_modified=datetime.utcnow() - timedelta(days=7), ), ContentMetadata( name="Ubuntu 22.04 CIS Benchmark", @@ -247,35 +257,42 @@ async def _discover_nist_content(self, repo: RepositoryConfig) -> List[ContentMe url=f"{repo.url}/ubuntu/ubuntu2204-cis-v1.0.0-xccdf.xml", checksum="def456...", size_bytes=1536000, - last_modified=datetime.utcnow() - timedelta(days=14) - ) + last_modified=datetime.utcnow() - timedelta(days=14), + ), ] - + return [c for c in content_list if c.os_family in repo.os_families] - + async def _discover_redhat_content(self, repo: RepositoryConfig) -> List[ContentMetadata]: """Discover Red Hat security content""" # Mock Red Hat content discovery return [] - + async def _discover_ubuntu_content(self, repo: RepositoryConfig) -> List[ContentMetadata]: - """Discover Ubuntu security content""" + """Discover Ubuntu security content""" # Mock Ubuntu content discovery return [] - - async def _get_existing_content(self, db: Session, content_meta: ContentMetadata) -> Optional[Dict]: + + async def _get_existing_content( + self, db: Session, content_meta: ContentMetadata + ) -> Optional[Dict]: """Check if content already exists in database""" try: - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT id, name, version, checksum, updated_at FROM scap_content WHERE name = :name AND os_family = :os_family AND os_version = :os_version - """), { - "name": content_meta.name, - "os_family": content_meta.os_family, - "os_version": content_meta.os_version - }) - + """ + ), + { + "name": content_meta.name, + "os_family": content_meta.os_family, + "os_version": content_meta.os_version, + }, + ) + row = result.fetchone() if row: return { @@ -283,55 +300,58 @@ async def _get_existing_content(self, db: Session, content_meta: ContentMetadata "name": row.name, "version": row.version, "checksum": row.checksum, - "updated_at": row.updated_at + "updated_at": row.updated_at, } return None except Exception as e: logger.error(f"Error checking existing content: {e}") return None - + def _should_update_content(self, existing: Dict, content_meta: ContentMetadata) -> bool: """Determine if existing content should be updated""" # Check version if content_meta.version != existing.get("version"): return True - + # Check checksum if available if content_meta.checksum and content_meta.checksum != existing.get("checksum"): return True - + # Check age (update if repository content is newer than 30 days) if existing.get("updated_at"): age = datetime.utcnow() - existing["updated_at"] if age > timedelta(days=30): return True - + return False - - async def _download_and_import_content(self, db: Session, repo: RepositoryConfig, - content_meta: ContentMetadata) -> bool: + + async def _download_and_import_content( + self, db: Session, repo: RepositoryConfig, content_meta: ContentMetadata + ) -> bool: """Download content file and import to database""" try: # Download content content_data = await self._download_content_file(content_meta.url) - + # Validate checksum if provided if content_meta.checksum: actual_checksum = hashlib.sha256(content_data).hexdigest() if actual_checksum != content_meta.checksum: logger.warning(f"Checksum mismatch for {content_meta.name}") return False - + # Parse and validate SCAP content profiles = await self._extract_profiles_from_content(content_data) - + # Save to cache directory cache_file = self.content_cache_dir / f"{content_meta.filename}" cache_file.write_bytes(content_data) - + # Insert into database current_time = datetime.utcnow() - result = db.execute(text(""" + result = db.execute( + text( + """ INSERT INTO scap_content (name, filename, content_type, description, version, profiles, os_family, os_version, compliance_framework, source, status, @@ -340,75 +360,84 @@ async def _download_and_import_content(self, db: Session, repo: RepositoryConfig :os_family, :os_version, :compliance_framework, 'repository', 'current', :checksum, :file_size, :file_path, :uploaded_at, :uploaded_by) RETURNING id - """), { - "name": content_meta.name, - "filename": content_meta.filename, - "content_type": content_meta.content_type, - "description": content_meta.description, - "version": content_meta.version, - "profiles": json.dumps(profiles), - "os_family": content_meta.os_family, - "os_version": content_meta.os_version, - "compliance_framework": content_meta.compliance_framework, - "checksum": content_meta.checksum, - "file_size": len(content_data), - "file_path": str(cache_file), - "uploaded_at": current_time, - "uploaded_by": 1 # System user - }) - + """ + ), + { + "name": content_meta.name, + "filename": content_meta.filename, + "content_type": content_meta.content_type, + "description": content_meta.description, + "version": content_meta.version, + "profiles": json.dumps(profiles), + "os_family": content_meta.os_family, + "os_version": content_meta.os_version, + "compliance_framework": content_meta.compliance_framework, + "checksum": content_meta.checksum, + "file_size": len(content_data), + "file_path": str(cache_file), + "uploaded_at": current_time, + "uploaded_by": 1, # System user + }, + ) + content_id = result.fetchone().id db.commit() - + logger.info(f"Imported new content: {content_meta.name} (ID: {content_id})") return True - + except Exception as e: logger.error(f"Failed to download and import {content_meta.name}: {e}") db.rollback() return False - - async def _update_existing_content(self, db: Session, repo: RepositoryConfig, - content_meta: ContentMetadata, existing: Dict) -> bool: + + async def _update_existing_content( + self, db: Session, repo: RepositoryConfig, content_meta: ContentMetadata, existing: Dict + ) -> bool: """Update existing content with new version""" try: # Download updated content content_data = await self._download_content_file(content_meta.url) - + # Parse profiles profiles = await self._extract_profiles_from_content(content_data) - + # Update cache file cache_file = self.content_cache_dir / f"{content_meta.filename}" cache_file.write_bytes(content_data) - + # Update database record - db.execute(text(""" + db.execute( + text( + """ UPDATE scap_content SET version = :version, profiles = :profiles, checksum = :checksum, file_size = :file_size, file_path = :file_path, updated_at = :updated_at, status = 'current' WHERE id = :id - """), { - "id": existing["id"], - "version": content_meta.version, - "profiles": json.dumps(profiles), - "checksum": content_meta.checksum, - "file_size": len(content_data), - "file_path": str(cache_file), - "updated_at": datetime.utcnow() - }) - + """ + ), + { + "id": existing["id"], + "version": content_meta.version, + "profiles": json.dumps(profiles), + "checksum": content_meta.checksum, + "file_size": len(content_data), + "file_path": str(cache_file), + "updated_at": datetime.utcnow(), + }, + ) + db.commit() - + logger.info(f"Updated content: {content_meta.name} to version {content_meta.version}") return True - + except Exception as e: logger.error(f"Failed to update {content_meta.name}: {e}") db.rollback() return False - + async def _download_content_file(self, url: str) -> bytes: """Download content file from URL""" async with aiohttp.ClientSession() as session: @@ -417,35 +446,43 @@ async def _download_content_file(self, url: str) -> bytes: return await response.read() else: raise Exception(f"Failed to download: HTTP {response.status}") - + async def _extract_profiles_from_content(self, content_data: bytes) -> List[Dict]: """Extract profile information from SCAP content""" try: # Parse XML content root = ET.fromstring(content_data) - + # Find profiles (simplified - real implementation would handle various formats) profiles = [] - + # XCCDF profiles - for profile in root.findall('.//{http://checklists.nist.gov/xccdf/1.2}Profile'): - profile_id = profile.get('id', '') - title_elem = profile.find('.//{http://checklists.nist.gov/xccdf/1.2}title') - desc_elem = profile.find('.//{http://checklists.nist.gov/xccdf/1.2}description') - - profiles.append({ - 'id': profile_id, - 'title': title_elem.text if title_elem is not None else profile_id, - 'description': desc_elem.text if desc_elem is not None else '' - }) - + for profile in root.findall(".//{http://checklists.nist.gov/xccdf/1.2}Profile"): + profile_id = profile.get("id", "") + title_elem = profile.find(".//{http://checklists.nist.gov/xccdf/1.2}title") + desc_elem = profile.find(".//{http://checklists.nist.gov/xccdf/1.2}description") + + profiles.append( + { + "id": profile_id, + "title": title_elem.text if title_elem is not None else profile_id, + "description": desc_elem.text if desc_elem is not None else "", + } + ) + return profiles - + except Exception as e: logger.warning(f"Failed to extract profiles: {e}") # Return basic profile if parsing fails - return [{'id': 'default', 'title': 'Default Profile', 'description': 'Default security profile'}] - + return [ + { + "id": "default", + "title": "Default Profile", + "description": "Default security profile", + } + ] + def get_repository_status(self) -> Dict: """Get status of all repositories""" return { @@ -456,24 +493,27 @@ def get_repository_status(self) -> Dict: "type": repo.type, "enabled": repo.enabled, "last_sync": repo.last_sync.isoformat() if repo.last_sync else None, - "os_families": repo.os_families + "os_families": repo.os_families, } for repo in self.repositories.values() ], "sync_running": self.sync_running, - "last_global_sync": self.last_global_sync.isoformat() if self.last_global_sync else None + "last_global_sync": ( + self.last_global_sync.isoformat() if self.last_global_sync else None + ), } - + def enable_repository(self, repo_id: str, enabled: bool = True): """Enable or disable a repository""" if repo_id in self.repositories: self.repositories[repo_id].enabled = enabled logger.info(f"Repository {repo_id} {'enabled' if enabled else 'disabled'}") - + def add_custom_repository(self, config: RepositoryConfig): """Add a custom repository""" self.repositories[config.id] = config logger.info(f"Added custom repository: {config.name}") + # Global repository manager instance -scap_repository_manager = SCAPRepositoryManager() \ No newline at end of file +scap_repository_manager = SCAPRepositoryManager() diff --git a/backend/app/services/scap_scanner.py b/backend/app/services/scap_scanner.py index bc736520..3cc8f1a6 100644 --- a/backend/app/services/scap_scanner.py +++ b/backend/app/services/scap_scanner.py @@ -2,6 +2,7 @@ OpenSCAP Scanner Service Handles SCAP content processing and scanning operations """ + import os import subprocess import tempfile @@ -26,95 +27,102 @@ class SCAPContentError(Exception): """Exception raised for SCAP content processing errors""" + pass class ScanExecutionError(Exception): """Exception raised for scan execution errors""" + pass class SCAPScanner: """Main SCAP scanning service""" - - def __init__(self, content_dir: Optional[str] = None, - results_dir: Optional[str] = None): + + def __init__(self, content_dir: Optional[str] = None, results_dir: Optional[str] = None): settings = get_settings() - + # Use provided paths or fall back to configuration content_path = content_dir or settings.scap_content_dir results_path = results_dir or settings.scan_results_dir - + self.content_dir = Path(content_path) self.results_dir = Path(results_path) - + # Create directories if they don't exist try: self.content_dir.mkdir(parents=True, exist_ok=True) self.results_dir.mkdir(parents=True, exist_ok=True) - logger.info(f"SCAP Scanner initialized - Content: {self.content_dir}, Results: {self.results_dir}") + logger.info( + f"SCAP Scanner initialized - Content: {self.content_dir}, Results: {self.results_dir}" + ) except Exception as e: logger.error(f"Failed to create SCAP directories: {e}") raise SCAPContentError(f"Directory creation failed: {str(e)}") - + def validate_scap_content(self, file_path: str) -> Dict: """Validate SCAP content file and extract metadata""" try: logger.info(f"Validating SCAP content: {file_path}") - + # First check if file exists and is readable if not os.path.exists(file_path): raise SCAPContentError(f"File not found: {file_path}") - + # Use oscap to validate the file - result = subprocess.run([ - 'oscap', 'info', file_path - ], capture_output=True, text=True, timeout=30) - + result = subprocess.run( + ["oscap", "info", file_path], capture_output=True, text=True, timeout=30 + ) + if result.returncode != 0: raise SCAPContentError(f"Invalid SCAP content: {result.stderr}") - + # Parse the output to extract information info = self._parse_oscap_info(result.stdout) logger.info(f"SCAP content validated successfully: {info.get('title', 'Unknown')}") - + return info - + except subprocess.TimeoutExpired: raise SCAPContentError("Timeout validating SCAP content") except Exception as e: logger.error(f"Error validating SCAP content: {e}") raise SCAPContentError(f"Validation failed: {str(e)}") - + def extract_profiles(self, file_path: str) -> List[Dict]: """Extract available profiles from SCAP content""" try: logger.info(f"Extracting profiles from: {file_path}") - - result = subprocess.run([ - 'oscap', 'info', '--profiles', file_path - ], capture_output=True, text=True, timeout=30) - + + result = subprocess.run( + ["oscap", "info", "--profiles", file_path], + capture_output=True, + text=True, + timeout=30, + ) + if result.returncode != 0: raise SCAPContentError(f"Failed to extract profiles: {result.stderr}") - + profiles = self._parse_profiles(result.stdout) logger.info(f"Extracted {len(profiles)} profiles") - + return profiles - + except subprocess.TimeoutExpired: raise SCAPContentError("Timeout extracting profiles") except Exception as e: logger.error(f"Error extracting profiles: {e}") raise SCAPContentError(f"Profile extraction failed: {str(e)}") - - def test_ssh_connection(self, hostname: str, port: int, username: str, - auth_method: str, credential: str) -> Dict: + + def test_ssh_connection( + self, hostname: str, port: int, username: str, auth_method: str, credential: str + ) -> Dict: """Test SSH connection to remote host""" try: logger.info(f"Testing SSH connection to {username}@{hostname}:{port}") - + ssh = paramiko.SSHClient() # Security Fix: Use strict host key checking instead of AutoAddPolicy # AutoAddPolicy automatically accepts unknown host keys, vulnerable to MITM attacks @@ -122,31 +130,37 @@ def test_ssh_connection(self, hostname: str, port: int, username: str, # Load system and user host keys for validation try: ssh.load_system_host_keys() # Load from /etc/ssh/ssh_known_hosts - ssh.load_host_keys(os.path.expanduser('~/.ssh/known_hosts')) # Load user known_hosts + ssh.load_host_keys( + os.path.expanduser("~/.ssh/known_hosts") + ) # Load user known_hosts except FileNotFoundError: - logger.warning("No known_hosts files found - SSH connections may fail without proper host key management") - + logger.warning( + "No known_hosts files found - SSH connections may fail without proper host key management" + ) + # Connect based on auth method if auth_method == "password": - ssh.connect(hostname, port=port, username=username, - password=credential, timeout=10) + ssh.connect(hostname, port=port, username=username, password=credential, timeout=10) elif auth_method in ["ssh-key", "ssh_key"]: # Handle SSH key authentication using new utility try: # Validate SSH key first validation_result = validate_ssh_key(credential) if not validation_result.is_valid: - logger.error(f"Invalid SSH key for {hostname}: {validation_result.error_message}") + logger.error( + f"Invalid SSH key for {hostname}: {validation_result.error_message}" + ) raise SSHException(f"Invalid SSH key: {validation_result.error_message}") - + # Log any warnings if validation_result.warnings: - logger.warning(f"SSH key warnings for {hostname}: {'; '.join(validation_result.warnings)}") - + logger.warning( + f"SSH key warnings for {hostname}: {'; '.join(validation_result.warnings)}" + ) + # Parse key using unified parser key = parse_ssh_key(credential) - ssh.connect(hostname, port=port, username=username, - pkey=key, timeout=10) + ssh.connect(hostname, port=port, username=username, pkey=key, timeout=10) except SSHKeyError as e: logger.error(f"SSH key parsing failed for {hostname}: {e}") raise SSHException(f"SSH key error: {str(e)}") @@ -155,220 +169,254 @@ def test_ssh_connection(self, hostname: str, port: int, username: str, raise else: raise SCAPContentError(f"Unsupported auth method: {auth_method}") - + # Test basic command execution stdin, stdout, stderr = ssh.exec_command('echo "OpenWatch SSH Test"') output = stdout.read().decode() error = stderr.read().decode() - + # Check if oscap is available on remote host - stdin, stdout, stderr = ssh.exec_command('oscap --version') + stdin, stdout, stderr = ssh.exec_command("oscap --version") oscap_output = stdout.read().decode() oscap_error = stderr.read().decode() - + oscap_available = stdout.channel.recv_exit_status() == 0 - + ssh.close() - + result = { "success": True, "message": "SSH connection successful", "oscap_available": oscap_available, "oscap_version": oscap_output.strip() if oscap_available else None, - "test_output": output.strip() + "test_output": output.strip(), } - + if not oscap_available: result["warning"] = "OpenSCAP not found on remote host" - + logger.info(f"SSH test successful: {hostname}") return result - + except SSHException as e: logger.error(f"SSH connection failed: {e}") return { "success": False, "message": f"SSH connection failed: {str(e)}", - "oscap_available": False + "oscap_available": False, } except Exception as e: logger.error(f"SSH test error: {e}") return { "success": False, "message": f"Connection test failed: {str(e)}", - "oscap_available": False + "oscap_available": False, } - - def execute_local_scan(self, content_path: str, profile_id: str, - scan_id: str, rule_id: str = None) -> Dict: + + def execute_local_scan( + self, content_path: str, profile_id: str, scan_id: str, rule_id: str = None + ) -> Dict: """Execute SCAP scan on local system""" try: logger.info(f"Starting local scan: {scan_id}") - + # Create result directory for this scan scan_dir = self.results_dir / scan_id scan_dir.mkdir(exist_ok=True) - + # Define output files xml_result = scan_dir / "results.xml" html_report = scan_dir / "report.html" arf_result = scan_dir / "results.arf.xml" - + # Execute oscap scan cmd = [ - 'oscap', 'xccdf', 'eval', - '--profile', profile_id, - '--results', str(xml_result), - '--report', str(html_report), - '--results-arf', str(arf_result) + "oscap", + "xccdf", + "eval", + "--profile", + profile_id, + "--results", + str(xml_result), + "--report", + str(html_report), + "--results-arf", + str(arf_result), ] - + # Add rule-specific scanning if rule_id is provided if rule_id: - cmd.extend(['--rule', rule_id]) + cmd.extend(["--rule", rule_id]) logger.info(f"Scanning specific rule: {rule_id}") - + cmd.append(content_path) - + logger.info(f"Executing: {' '.join(cmd)}") - + result = subprocess.run( - cmd, capture_output=True, text=True, - timeout=1800 # 30 minutes timeout + cmd, capture_output=True, text=True, timeout=1800 # 30 minutes timeout ) - + # Parse results with content file for remediation extraction scan_results = self._parse_scan_results(str(xml_result), content_path) - scan_results.update({ - "scan_id": scan_id, - "scan_type": "local", - "exit_code": result.returncode, - "stdout": result.stdout, - "stderr": result.stderr, - "xml_result": str(xml_result), - "html_report": str(html_report), - "arf_result": str(arf_result) - }) - + scan_results.update( + { + "scan_id": scan_id, + "scan_type": "local", + "exit_code": result.returncode, + "stdout": result.stdout, + "stderr": result.stderr, + "xml_result": str(xml_result), + "html_report": str(html_report), + "arf_result": str(arf_result), + } + ) + logger.info(f"Local scan completed: {scan_id}") return scan_results - + except subprocess.TimeoutExpired: logger.error(f"Scan timeout: {scan_id}") raise ScanExecutionError("Scan execution timeout") except Exception as e: logger.error(f"Local scan failed: {e}") raise ScanExecutionError(f"Scan execution failed: {str(e)}") - - def execute_remote_scan(self, hostname: str, port: int, username: str, - auth_method: str, credential: str, content_path: str, - profile_id: str, scan_id: str, rule_id: str = None) -> Dict: + + def execute_remote_scan( + self, + hostname: str, + port: int, + username: str, + auth_method: str, + credential: str, + content_path: str, + profile_id: str, + scan_id: str, + rule_id: str = None, + ) -> Dict: """Execute SCAP scan on remote system via SSH""" try: logger.info(f"Starting remote scan: {scan_id} on {hostname}") - + # Create result directory for this scan scan_dir = self.results_dir / scan_id scan_dir.mkdir(exist_ok=True) - + # Define output files xml_result = scan_dir / "results.xml" html_report = scan_dir / "report.html" arf_result = scan_dir / "results.arf.xml" - + # Prepare SSH connection parameters ssh_options = [ - '-o', 'StrictHostKeyChecking=no', - '-o', 'UserKnownHostsFile=/dev/null', - '-p', str(port) + "-o", + "StrictHostKeyChecking=no", + "-o", + "UserKnownHostsFile=/dev/null", + "-p", + str(port), ] - + # For all authentication methods, use paramiko for SSH execution # oscap-ssh is not available in the standard OpenSCAP package return self._execute_remote_scan_with_paramiko( - hostname, port, username, auth_method, credential, content_path, - profile_id, scan_id, xml_result, html_report, arf_result, rule_id + hostname, + port, + username, + auth_method, + credential, + content_path, + profile_id, + scan_id, + xml_result, + html_report, + arf_result, + rule_id, ) - + except subprocess.TimeoutExpired: logger.error(f"Remote scan timeout: {scan_id}") raise ScanExecutionError("Remote scan execution timeout") except Exception as e: logger.error(f"Remote scan failed: {e}") raise ScanExecutionError(f"Remote scan execution failed: {str(e)}") - - def get_system_info(self, hostname: str = None, port: int = 22, - username: str = None, auth_method: str = None, - credential: str = None) -> Dict: + + def get_system_info( + self, + hostname: str = None, + port: int = 22, + username: str = None, + auth_method: str = None, + credential: str = None, + ) -> Dict: """Get system information from local or remote host""" try: if hostname: # Remote system info - return self._get_remote_system_info(hostname, port, username, - auth_method, credential) + return self._get_remote_system_info( + hostname, port, username, auth_method, credential + ) else: # Local system info return self._get_local_system_info() - + except Exception as e: logger.error(f"Error getting system info: {e}") return {"error": str(e)} - + def _parse_oscap_info(self, info_output: str) -> Dict: """Parse oscap info command output""" info = {} - lines = info_output.split('\n') - + lines = info_output.split("\n") + for line in lines: line = line.strip() - if ':' in line: - key, value = line.split(':', 1) - key = key.strip().lower().replace(' ', '_') + if ":" in line: + key, value = line.split(":", 1) + key = key.strip().lower().replace(" ", "_") value = value.strip() info[key] = value - + return info - + def _parse_profiles(self, profiles_output: str) -> List[Dict]: """Parse profiles from oscap info --profiles output""" profiles = [] - lines = profiles_output.split('\n') - + lines = profiles_output.split("\n") + current_profile = None for line in lines: line = line.strip() - if line.startswith('Profile ID:'): + if line.startswith("Profile ID:"): if current_profile: profiles.append(current_profile) current_profile = { - 'id': line.split(':', 1)[1].strip(), - 'title': '', - 'description': '' + "id": line.split(":", 1)[1].strip(), + "title": "", + "description": "", } - elif line.startswith('Title:') and current_profile: - current_profile['title'] = line.split(':', 1)[1].strip() - elif line.startswith('Description:') and current_profile: - current_profile['description'] = line.split(':', 1)[1].strip() - + elif line.startswith("Title:") and current_profile: + current_profile["title"] = line.split(":", 1)[1].strip() + elif line.startswith("Description:") and current_profile: + current_profile["description"] = line.split(":", 1)[1].strip() + if current_profile: profiles.append(current_profile) - + return profiles - + def _parse_scan_results(self, xml_file: str, content_file: str = None) -> Dict: """Parse SCAP scan results from XML file with enhanced remediation extraction""" try: if not os.path.exists(xml_file): return {"error": "Results file not found"} - + tree = etree.parse(xml_file) root = tree.getroot() - + # Extract basic statistics - namespaces = { - 'xccdf': 'http://checklists.nist.gov/xccdf/1.2' - } - + namespaces = {"xccdf": "http://checklists.nist.gov/xccdf/1.2"} + results = { "timestamp": datetime.now().isoformat(), "rules_total": 0, @@ -380,9 +428,9 @@ def _parse_scan_results(self, xml_file: str, content_file: str = None) -> Dict: "rules_notchecked": 0, "score": 0.0, "failed_rules": [], - "rule_details": [] # Enhanced rule details with remediation + "rule_details": [], # Enhanced rule details with remediation } - + # Load SCAP content for remediation extraction if available content_tree = None if content_file and os.path.exists(content_file): @@ -391,21 +439,23 @@ def _parse_scan_results(self, xml_file: str, content_file: str = None) -> Dict: logger.info(f"Loaded SCAP content for remediation extraction: {content_file}") except Exception as e: logger.warning(f"Could not load SCAP content file: {e}") - + # Count rule results and extract detailed information - rule_results = root.xpath('//xccdf:rule-result', namespaces=namespaces) + rule_results = root.xpath("//xccdf:rule-result", namespaces=namespaces) results["rules_total"] = len(rule_results) - + for rule_result in rule_results: - result_elem = rule_result.find('xccdf:result', namespaces) + result_elem = rule_result.find("xccdf:result", namespaces) if result_elem is not None: result_value = result_elem.text - rule_id = rule_result.get('idref', '') - severity = rule_result.get('severity', 'unknown') - + rule_id = rule_result.get("idref", "") + severity = rule_result.get("severity", "unknown") + # Extract remediation information from SCAP content - remediation_info = self._extract_rule_remediation(rule_id, content_tree, namespaces) - + remediation_info = self._extract_rule_remediation( + rule_id, content_tree, namespaces + ) + # Create detailed rule entry rule_detail = { "rule_id": rule_id, @@ -415,124 +465,123 @@ def _parse_scan_results(self, xml_file: str, content_file: str = None) -> Dict: "description": remediation_info.get("description", ""), "rationale": remediation_info.get("rationale", ""), "remediation": remediation_info.get("remediation", {}), - "references": remediation_info.get("references", []) + "references": remediation_info.get("references", []), } - + results["rule_details"].append(rule_detail) - + # Count by result type - if result_value == 'pass': + if result_value == "pass": results["rules_passed"] += 1 - elif result_value == 'fail': + elif result_value == "fail": results["rules_failed"] += 1 # Extract failed rule info (backward compatibility) - results["failed_rules"].append({ - "rule_id": rule_id, - "severity": severity - }) - elif result_value == 'error': + results["failed_rules"].append({"rule_id": rule_id, "severity": severity}) + elif result_value == "error": results["rules_error"] += 1 - elif result_value == 'unknown': + elif result_value == "unknown": results["rules_unknown"] += 1 - elif result_value == 'notapplicable': + elif result_value == "notapplicable": results["rules_notapplicable"] += 1 - elif result_value == 'notchecked': + elif result_value == "notchecked": results["rules_notchecked"] += 1 - + # Calculate score if results["rules_total"] > 0: - results["score"] = (results["rules_passed"] / - (results["rules_passed"] + results["rules_failed"])) * 100 - + results["score"] = ( + results["rules_passed"] / (results["rules_passed"] + results["rules_failed"]) + ) * 100 + return results - + except Exception as e: logger.error(f"Error parsing scan results: {e}") return {"error": f"Failed to parse results: {str(e)}"} - + def _extract_rule_remediation(self, rule_id: str, content_tree, namespaces: Dict) -> Dict: """Extract detailed rule information and remediation from SCAP content""" remediation_info = { "title": "", "description": "", - "rationale": "", + "rationale": "", "remediation": {}, - "references": [] + "references": [], } - + if not content_tree: return remediation_info - + try: # Find the rule definition in the SCAP content rule_xpath = f'.//xccdf:Rule[@id="{rule_id}"]' rules = content_tree.xpath(rule_xpath, namespaces=namespaces) - + if not rules: logger.debug(f"Rule not found in SCAP content: {rule_id}") return remediation_info - + rule = rules[0] - + # Extract title - title_elem = rule.find('xccdf:title', namespaces) + title_elem = rule.find("xccdf:title", namespaces) if title_elem is not None: remediation_info["title"] = self._extract_text_content(title_elem) - + # Extract description - desc_elem = rule.find('xccdf:description', namespaces) + desc_elem = rule.find("xccdf:description", namespaces) if desc_elem is not None: remediation_info["description"] = self._extract_text_content(desc_elem) - + # Extract rationale - rationale_elem = rule.find('xccdf:rationale', namespaces) + rationale_elem = rule.find("xccdf:rationale", namespaces) if rationale_elem is not None: remediation_info["rationale"] = self._extract_text_content(rationale_elem) - + # Extract remediation information remediation_info["remediation"] = self._extract_remediation_details(rule, namespaces) - + # Extract references remediation_info["references"] = self._extract_references(rule, namespaces) - + logger.debug(f"Extracted remediation info for rule: {rule_id}") return remediation_info - + except Exception as e: logger.error(f"Error extracting remediation for rule {rule_id}: {e}") return remediation_info - + def _extract_text_content(self, element) -> str: """Extract clean text content from XML element, handling HTML tags""" if element is None: return "" - + # Get text content and clean up HTML tags - text = etree.tostring(element, method='text', encoding='unicode').strip() - + text = etree.tostring(element, method="text", encoding="unicode").strip() + # Clean up extra whitespace import re - text = re.sub(r'\s+', ' ', text).strip() - + + text = re.sub(r"\s+", " ", text).strip() + return text - + def _extract_remediation_details(self, rule_element, namespaces: Dict) -> Dict: """Extract remediation details from rule element with enhanced Fix Text and OpenSCAP remediation parsing""" remediation = { "type": "manual", - "complexity": "unknown", + "complexity": "unknown", "disruption": "unknown", "description": "", "fix_text": "", "detailed_description": "", "steps": [], "commands": [], - "configuration": [] + "configuration": [], } - + try: # First Priority: Look for SCAP compliance checker "Fix Text" elements - fixtext_elements = rule_element.findall('.//xccdf:fixtext', namespaces) + fixtext_elements = rule_element.findall(".//xccdf:fixtext", namespaces) if fixtext_elements: logger.debug("Found SCAP compliance checker Fix Text elements") for fixtext in fixtext_elements: @@ -541,16 +590,16 @@ def _extract_remediation_details(self, rule_element, namespaces: Dict) -> Dict: remediation["fix_text"] = fix_content remediation["description"] = fix_content remediation["type"] = "manual" - + # Extract detailed steps from Fix Text parsed_steps = self._parse_remediation_text(fix_content) remediation.update(parsed_steps) - + logger.debug(f"Extracted Fix Text: {fix_content[:100]}...") break # Use first available fix text - + # Second Priority: Look for OpenSCAP Evaluation Report "remediation" elements - remediation_elements = rule_element.findall('.//xccdf:remediation', namespaces) + remediation_elements = rule_element.findall(".//xccdf:remediation", namespaces) if remediation_elements and not remediation["description"]: logger.debug("Found OpenSCAP Evaluation Report remediation elements") for remediation_elem in remediation_elements: @@ -558,28 +607,30 @@ def _extract_remediation_details(self, rule_element, namespaces: Dict) -> Dict: if remediation_content: remediation["description"] = remediation_content remediation["type"] = "manual" - + # Extract steps from remediation content parsed_steps = self._parse_remediation_text(remediation_content) remediation.update(parsed_steps) - - logger.debug(f"Extracted OpenSCAP remediation: {remediation_content[:100]}...") + + logger.debug( + f"Extracted OpenSCAP remediation: {remediation_content[:100]}..." + ) break - + # Third Priority: Look for fix elements with different strategies - fix_elements = rule_element.findall('.//xccdf:fix', namespaces) + fix_elements = rule_element.findall(".//xccdf:fix", namespaces) for fix_elem in fix_elements: - strategy = fix_elem.get('strategy', 'unknown') - complexity = fix_elem.get('complexity', 'unknown') - disruption = fix_elem.get('disruption', 'unknown') - + strategy = fix_elem.get("strategy", "unknown") + complexity = fix_elem.get("complexity", "unknown") + disruption = fix_elem.get("disruption", "unknown") + remediation["complexity"] = complexity remediation["disruption"] = disruption - + fix_content = self._extract_text_content(fix_elem) if fix_content and not remediation["description"]: logger.debug("Found xccdf:fix element") - if strategy in ['configure', 'patch']: + if strategy in ["configure", "patch"]: remediation["type"] = "automatic" # Extract configuration commands parsed_config = self._parse_configuration_commands(fix_content) @@ -588,77 +639,95 @@ def _extract_remediation_details(self, rule_element, namespaces: Dict) -> Dict: remediation["type"] = "manual" parsed_steps = self._parse_remediation_text(fix_content) remediation.update(parsed_steps) - + remediation["description"] = fix_content - + # Fourth Priority: Look for detailed description elements description_selectors = [ - './/xccdf:description', - './/description', - './/long_name', - './/detail' + ".//xccdf:description", + ".//description", + ".//long_name", + ".//detail", ] - + for selector in description_selectors: desc_elements = rule_element.findall(selector, namespaces) for desc_elem in desc_elements: detailed_desc = self._extract_text_content(desc_elem) - if detailed_desc and len(detailed_desc) > len(remediation.get("detailed_description", "")): + if detailed_desc and len(detailed_desc) > len( + remediation.get("detailed_description", "") + ): remediation["detailed_description"] = detailed_desc logger.debug(f"Found detailed description: {detailed_desc[:100]}...") - + # Fifth Priority: Look for check-content for additional context - check_elements = rule_element.findall('.//xccdf:check-content', namespaces) + check_elements = rule_element.findall(".//xccdf:check-content", namespaces) for check_elem in check_elements: check_content = self._extract_text_content(check_elem) if check_content and not remediation["description"]: # Use check content as description if no other description available remediation["description"] = f"Ensure: {check_content}" logger.debug("Using check-content as fallback description") - + # Enhanced parsing for specific compliance frameworks self._extract_framework_specific_remediation(rule_element, namespaces, remediation) - + return remediation - + except Exception as e: logger.error(f"Error extracting remediation details: {e}") return remediation - + def _parse_remediation_text(self, text: str) -> Dict: """Parse remediation text to extract structured steps and commands""" steps = [] commands = [] configuration = [] - + if not text: return {"steps": steps, "commands": commands, "configuration": configuration} - + try: # Split text into lines for processing - lines = [line.strip() for line in text.split('\n') if line.strip()] - + lines = [line.strip() for line in text.split("\n") if line.strip()] + current_step = "" for line in lines: # Detect commands (lines that look like shell commands) if self._is_command_line(line): - commands.append({ - "command": line, - "type": "shell", - "description": current_step or "Execute command" - }) + commands.append( + { + "command": line, + "type": "shell", + "description": current_step or "Execute command", + } + ) current_step = "" # Detect configuration entries elif self._is_configuration_line(line): - configuration.append({ - "setting": line, - "type": "config", - "description": current_step or "Configuration setting" - }) + configuration.append( + { + "setting": line, + "type": "config", + "description": current_step or "Configuration setting", + } + ) current_step = "" # Detect step descriptions - elif line.endswith(':') or any(keyword in line.lower() for keyword in - ['step', 'install', 'configure', 'edit', 'modify', 'ensure', 'set', 'enable', 'disable']): + elif line.endswith(":") or any( + keyword in line.lower() + for keyword in [ + "step", + "install", + "configure", + "edit", + "modify", + "ensure", + "set", + "enable", + "disable", + ] + ): if current_step: steps.append(current_step) current_step = line @@ -668,115 +737,122 @@ def _parse_remediation_text(self, text: str) -> Dict: current_step += " " + line else: current_step = line - + # Add any remaining step if current_step: steps.append(current_step) - + return {"steps": steps, "commands": commands, "configuration": configuration} - + except Exception as e: logger.error(f"Error parsing remediation text: {e}") return {"steps": steps, "commands": commands, "configuration": configuration} - + def _is_command_line(self, line: str) -> bool: """Check if a line looks like a shell command""" command_indicators = [ - 'sudo ', '# ', '$ ', 'yum ', 'apt-get ', 'systemctl ', 'chmod ', 'chown ', - 'grep ', 'sed ', 'awk ', 'echo ', 'cat ', 'vi ', 'nano ', 'service ', - 'mount ', 'umount ', 'iptables ', 'firewall-cmd ', 'sysctl ' + "sudo ", + "# ", + "$ ", + "yum ", + "apt-get ", + "systemctl ", + "chmod ", + "chown ", + "grep ", + "sed ", + "awk ", + "echo ", + "cat ", + "vi ", + "nano ", + "service ", + "mount ", + "umount ", + "iptables ", + "firewall-cmd ", + "sysctl ", ] - + line_lower = line.lower() return any(line_lower.startswith(indicator) for indicator in command_indicators) - + def _is_configuration_line(self, line: str) -> bool: """Check if a line looks like a configuration setting""" - config_patterns = [ - '=', ':', 'yes', 'no', 'true', 'false', 'enabled', 'disabled' - ] - + config_patterns = ["=", ":", "yes", "no", "true", "false", "enabled", "disabled"] + # Lines that contain assignment or common config values - return any(pattern in line.lower() for pattern in config_patterns) and \ - not self._is_command_line(line) and \ - len(line.split()) <= 5 # Config lines are usually short - + return ( + any(pattern in line.lower() for pattern in config_patterns) + and not self._is_command_line(line) + and len(line.split()) <= 5 + ) # Config lines are usually short + def _parse_configuration_commands(self, text: str) -> List[Dict]: """Parse configuration-style commands from fix text""" commands = [] - + try: - lines = [line.strip() for line in text.split('\n') if line.strip()] - + lines = [line.strip() for line in text.split("\n") if line.strip()] + for line in lines: if self._is_command_line(line): - commands.append({ - "command": line, - "type": "shell", - "description": "Automated remediation command" - }) - elif '=' in line or ':' in line: - commands.append({ - "command": line, - "type": "config", - "description": "Configuration setting" - }) - + commands.append( + { + "command": line, + "type": "shell", + "description": "Automated remediation command", + } + ) + elif "=" in line or ":" in line: + commands.append( + {"command": line, "type": "config", "description": "Configuration setting"} + ) + return commands - + except Exception as e: logger.error(f"Error parsing configuration commands: {e}") return commands - + def _extract_references(self, rule_element, namespaces: Dict) -> List[Dict]: """Extract reference information from rule element""" references = [] - + try: # Look for reference elements - ref_elements = rule_element.findall('.//xccdf:reference', namespaces) + ref_elements = rule_element.findall(".//xccdf:reference", namespaces) for ref_elem in ref_elements: - href = ref_elem.get('href', '') + href = ref_elem.get("href", "") text = self._extract_text_content(ref_elem) - + if href or text: - references.append({ - "href": href, - "text": text, - "type": "external" - }) - + references.append({"href": href, "text": text, "type": "external"}) + # Look for ident elements (like CCE, CVE references) - ident_elements = rule_element.findall('.//xccdf:ident', namespaces) + ident_elements = rule_element.findall(".//xccdf:ident", namespaces) for ident_elem in ident_elements: - system = ident_elem.get('system', '') - ident_text = ident_elem.text or '' - + system = ident_elem.get("system", "") + ident_text = ident_elem.text or "" + if ident_text: ref_type = "CCE" if "cce" in system.lower() else "identifier" - references.append({ - "href": system, - "text": ident_text, - "type": ref_type - }) - + references.append({"href": system, "text": ident_text, "type": ref_type}) + return references - + except Exception as e: logger.error(f"Error extracting references: {e}") return references - - def _extract_framework_specific_remediation(self, rule_element, namespaces: Dict, remediation: Dict): + + def _extract_framework_specific_remediation( + self, rule_element, namespaces: Dict, remediation: Dict + ): """Extract remediation from framework-specific elements (DISA STIG, CIS, etc.)""" try: # Look for DISA STIG specific elements - stig_elements = [ - './/stig:fix_text', - './/stig:fixtext', - './/fixtext', - './/fix_text' - ] - + stig_elements = [".//stig:fix_text", ".//stig:fixtext", ".//fixtext", ".//fix_text"] + for selector in stig_elements: try: elements = rule_element.findall(selector, namespaces) @@ -789,15 +865,15 @@ def _extract_framework_specific_remediation(self, rule_element, namespaces: Dict return except: continue - + # Look for CIS Benchmark specific elements cis_elements = [ - './/cis:remediation', - './/benchmark:remediation', - './/remediation_procedure', - './/audit_procedure' + ".//cis:remediation", + ".//benchmark:remediation", + ".//remediation_procedure", + ".//audit_procedure", ] - + for selector in cis_elements: try: elements = rule_element.findall(selector, namespaces) @@ -809,15 +885,15 @@ def _extract_framework_specific_remediation(self, rule_element, namespaces: Dict return except: continue - + # Look for NIST specific elements nist_elements = [ - './/nist:implementation_guidance', - './/implementation_guidance', - './/guidance', - './/supplemental_guidance' + ".//nist:implementation_guidance", + ".//implementation_guidance", + ".//guidance", + ".//supplemental_guidance", ] - + for selector in nist_elements: try: elements = rule_element.findall(selector, namespaces) @@ -828,107 +904,125 @@ def _extract_framework_specific_remediation(self, rule_element, namespaces: Dict logger.debug(f"Found NIST guidance: {nist_content[:100]}...") except: continue - + # Look for generic remediation patterns in text content self._extract_generic_remediation_patterns(rule_element, remediation) - + except Exception as e: logger.error(f"Error extracting framework-specific remediation: {e}") - + def _extract_generic_remediation_patterns(self, rule_element, remediation: Dict): """Extract remediation from common text patterns""" try: # Get all text content from the rule element all_text = self._extract_text_content(rule_element) - + if not all_text: return - + # Look for common remediation keywords and sections remediation_keywords = [ - "fix text:", "remediation:", "to remediate:", "fix procedure:", - "corrective action:", "resolution:", "mitigation:", "solution:", - "to resolve:", "recommended action:", "implementation:" + "fix text:", + "remediation:", + "to remediate:", + "fix procedure:", + "corrective action:", + "resolution:", + "mitigation:", + "solution:", + "to resolve:", + "recommended action:", + "implementation:", ] - - lines = all_text.split('\n') + + lines = all_text.split("\n") current_section = "" remediation_found = False - + for i, line in enumerate(lines): line_lower = line.lower().strip() - + # Check if this line contains a remediation keyword for keyword in remediation_keywords: if keyword in line_lower: # Extract the remediation content that follows remediation_content = "" - + # Get content from the same line (after the keyword) - if ':' in line: - remediation_content = line.split(':', 1)[1].strip() - + if ":" in line: + remediation_content = line.split(":", 1)[1].strip() + # Get content from following lines until we hit another section for j in range(i + 1, min(i + 10, len(lines))): # Look ahead up to 10 lines next_line = lines[j].strip() if not next_line: continue - + # Stop if we hit another section header - if any(stop_word in next_line.lower() for stop_word in - ['vulnerability discussion:', 'check text:', 'references:', 'severity:']): + if any( + stop_word in next_line.lower() + for stop_word in [ + "vulnerability discussion:", + "check text:", + "references:", + "severity:", + ] + ): break - + remediation_content += " " + next_line - + if remediation_content and len(remediation_content.strip()) > 20: if not remediation["description"]: remediation["description"] = remediation_content.strip() elif not remediation["fix_text"]: remediation["fix_text"] = remediation_content.strip() - - logger.debug(f"Found generic remediation pattern: {remediation_content[:100]}...") + + logger.debug( + f"Found generic remediation pattern: {remediation_content[:100]}..." + ) remediation_found = True break - + if remediation_found: break - + except Exception as e: logger.error(f"Error extracting generic remediation patterns: {e}") - + def _get_local_system_info(self) -> Dict: """Get local system information""" try: # Get OS info - with open('/etc/os-release', 'r') as f: + with open("/etc/os-release", "r") as f: os_info = {} for line in f: - if '=' in line: - key, value = line.strip().split('=', 1) + if "=" in line: + key, value = line.strip().split("=", 1) os_info[key] = value.strip('"') - + # Get system stats import psutil - + return { "hostname": os.uname().nodename, - "os_name": os_info.get('NAME', 'Unknown'), - "os_version": os_info.get('VERSION', 'Unknown'), + "os_name": os_info.get("NAME", "Unknown"), + "os_version": os_info.get("VERSION", "Unknown"), "kernel": os.uname().release, "architecture": os.uname().machine, "cpu_count": psutil.cpu_count(), "memory_total": psutil.virtual_memory().total, - "disk_usage": dict(psutil.disk_usage('/')), - "uptime": datetime.now().isoformat() + "disk_usage": dict(psutil.disk_usage("/")), + "uptime": datetime.now().isoformat(), } - + except Exception as e: logger.error(f"Error getting local system info: {e}") return {"error": str(e)} - - def _get_remote_system_info(self, hostname: str, port: int, username: str, - auth_method: str, credential: str) -> Dict: + + def _get_remote_system_info( + self, hostname: str, port: int, username: str, auth_method: str, credential: str + ) -> Dict: """Get remote system information via SSH""" try: ssh = paramiko.SSHClient() @@ -937,27 +1031,27 @@ def _get_remote_system_info(self, hostname: str, port: int, username: str, # Load system and user host keys for validation try: ssh.load_system_host_keys() - ssh.load_host_keys(os.path.expanduser('~/.ssh/known_hosts')) + ssh.load_host_keys(os.path.expanduser("~/.ssh/known_hosts")) except FileNotFoundError: - logger.warning("No known_hosts files found - SSH connections may fail without proper host key management") - + logger.warning( + "No known_hosts files found - SSH connections may fail without proper host key management" + ) + # Connect if auth_method == "password": - ssh.connect(hostname, port=port, username=username, - password=credential, timeout=10) + ssh.connect(hostname, port=port, username=username, password=credential, timeout=10) else: # Handle SSH key using new utility try: key = parse_ssh_key(credential) - ssh.connect(hostname, port=port, username=username, - pkey=key, timeout=10) + ssh.connect(hostname, port=port, username=username, pkey=key, timeout=10) except SSHKeyError as e: logger.error(f"SSH key parsing failed for remote system info: {e}") raise Exception(f"SSH key error: {str(e)}") except Exception as e: logger.error(f"SSH connection failed for remote system info: {e}") raise - + # Execute commands to get system info commands = { "hostname": "hostname", @@ -966,74 +1060,87 @@ def _get_remote_system_info(self, hostname: str, port: int, username: str, "architecture": "uname -m", "uptime": "uptime", "memory": "free -m | grep '^Mem:' | awk '{print $2}'", - "cpu_info": "nproc" + "cpu_info": "nproc", } - + results = {} for key, cmd in commands.items(): stdin, stdout, stderr = ssh.exec_command(cmd) output = stdout.read().decode().strip() results[key] = output - + ssh.close() - + # Parse OS release info os_info = {} - for line in results.get("os_release", "").split('\n'): - if '=' in line: - key, value = line.split('=', 1) + for line in results.get("os_release", "").split("\n"): + if "=" in line: + key, value = line.split("=", 1) os_info[key] = value.strip('"') - + return { "hostname": results.get("hostname", hostname), - "os_name": os_info.get('NAME', 'Unknown'), - "os_version": os_info.get('VERSION', 'Unknown'), + "os_name": os_info.get("NAME", "Unknown"), + "os_version": os_info.get("VERSION", "Unknown"), "kernel": results.get("kernel", "Unknown"), "architecture": results.get("architecture", "Unknown"), "cpu_count": int(results.get("cpu_info", "0")), "memory_mb": int(results.get("memory", "0")), - "uptime": results.get("uptime", "Unknown") + "uptime": results.get("uptime", "Unknown"), } - + except Exception as e: logger.error(f"Error getting remote system info: {e}") return {"error": str(e)} - - def _execute_remote_scan_with_paramiko(self, hostname: str, port: int, username: str, - auth_method: str, credential: str, content_path: str, - profile_id: str, scan_id: str, xml_result: Path, - html_report: Path, arf_result: Path, rule_id: str = None) -> Dict: + + def _execute_remote_scan_with_paramiko( + self, + hostname: str, + port: int, + username: str, + auth_method: str, + credential: str, + content_path: str, + profile_id: str, + scan_id: str, + xml_result: Path, + html_report: Path, + arf_result: Path, + rule_id: str = None, + ) -> Dict: """Execute remote SCAP scan using paramiko for all authentication methods""" try: logger.info(f"Executing remote scan via paramiko: {scan_id} on {hostname}") - + ssh = paramiko.SSHClient() # Security Fix: Use strict host key checking instead of AutoAddPolicy ssh.set_missing_host_key_policy(paramiko.RejectPolicy()) # Load system and user host keys for validation try: ssh.load_system_host_keys() - ssh.load_host_keys(os.path.expanduser('~/.ssh/known_hosts')) + ssh.load_host_keys(os.path.expanduser("~/.ssh/known_hosts")) except FileNotFoundError: - logger.warning("No known_hosts files found - SSH connections may fail without proper host key management") - + logger.warning( + "No known_hosts files found - SSH connections may fail without proper host key management" + ) + # Connect based on authentication method if auth_method == "password": - ssh.connect(hostname, port=port, username=username, - password=credential, timeout=30) + ssh.connect(hostname, port=port, username=username, password=credential, timeout=30) elif auth_method in ["ssh-key", "ssh_key"]: # Handle SSH key authentication using new utility try: # Validate SSH key first validation_result = validate_ssh_key(credential) if not validation_result.is_valid: - logger.error(f"Invalid SSH key for {hostname}: {validation_result.error_message}") + logger.error( + f"Invalid SSH key for {hostname}: {validation_result.error_message}" + ) raise SSHException(f"Invalid SSH key: {validation_result.error_message}") - + # Parse key using unified parser key = parse_ssh_key(credential) - ssh.connect(hostname, port=port, username=username, - pkey=key, timeout=30) + ssh.connect(hostname, port=port, username=username, pkey=key, timeout=30) except SSHKeyError as e: logger.error(f"SSH key parsing failed for {hostname}: {e}") raise SSHException(f"SSH key error: {str(e)}") @@ -1042,81 +1149,85 @@ def _execute_remote_scan_with_paramiko(self, hostname: str, port: int, username: raise else: raise ScanExecutionError(f"Unsupported auth method: {auth_method}") - + # Create remote directory for results remote_results_dir = f"/tmp/openwatch_scan_{scan_id}" ssh.exec_command(f"mkdir -p {remote_results_dir}") - + # Define remote file paths remote_xml = f"{remote_results_dir}/results.xml" remote_html = f"{remote_results_dir}/report.html" remote_arf = f"{remote_results_dir}/results.arf.xml" - + # Transfer SCAP content file to remote host sftp = ssh.open_sftp() remote_content_path = f"{remote_results_dir}/content.xml" - + try: sftp.put(content_path, remote_content_path) logger.info(f"Transferred SCAP content to remote host: {remote_content_path}") except Exception as e: sftp.close() - raise ScanExecutionError(f"Failed to transfer SCAP content to remote host: {str(e)}") - + raise ScanExecutionError( + f"Failed to transfer SCAP content to remote host: {str(e)}" + ) + sftp.close() - + # Build and execute oscap command on remote host with transferred file - oscap_cmd = (f"oscap xccdf eval " - f"--profile {profile_id} " - f"--results {remote_xml} " - f"--report {remote_html} " - f"--results-arf {remote_arf} ") - + oscap_cmd = ( + f"oscap xccdf eval " + f"--profile {profile_id} " + f"--results {remote_xml} " + f"--report {remote_html} " + f"--results-arf {remote_arf} " + ) + # Add rule-specific scanning if rule_id is provided if rule_id: oscap_cmd += f"--rule {rule_id} " logger.info(f"Remote scanning specific rule via paramiko: {rule_id}") - + oscap_cmd += f"{remote_content_path}" - + logger.info(f"Executing remote command: {oscap_cmd}") - + stdin, stdout, stderr = ssh.exec_command(oscap_cmd, timeout=1800) # 30 minutes - + # Wait for command completion and get exit code exit_code = stdout.channel.recv_exit_status() stdout_data = stdout.read().decode() stderr_data = stderr.read().decode() - + logger.info(f"Remote oscap command completed with exit code: {exit_code}") - + # Copy result files back to local system sftp = ssh.open_sftp() - + try: sftp.get(remote_xml, str(xml_result)) logger.info(f"Downloaded results file: {xml_result}") except FileNotFoundError: logger.warning("Results XML file not found on remote host") - + try: sftp.get(remote_html, str(html_report)) logger.info(f"Downloaded report file: {html_report}") except FileNotFoundError: logger.warning("HTML report file not found on remote host") - + try: sftp.get(remote_arf, str(arf_result)) logger.info(f"Downloaded ARF file: {arf_result}") except FileNotFoundError: logger.warning("ARF results file not found on remote host") - + sftp.close() - + # Clean up remote files ssh.exec_command(f"rm -rf {remote_results_dir}") ssh.close() - + # Parse results if XML file exists if xml_result.exists(): scan_results = self._parse_scan_results(str(xml_result), content_path) @@ -1129,24 +1240,26 @@ def _execute_remote_scan_with_paramiko(self, hostname: str, port: int, username: "rules_failed": 0, "rules_error": 0, "score": 0.0, - "failed_rules": [] + "failed_rules": [], + } + + scan_results.update( + { + "scan_id": scan_id, + "scan_type": "remote_paramiko", + "target_host": hostname, + "exit_code": exit_code, + "stdout": stdout_data, + "stderr": stderr_data, + "xml_result": str(xml_result) if xml_result.exists() else None, + "html_report": str(html_report) if html_report.exists() else None, + "arf_result": str(arf_result) if arf_result.exists() else None, } - - scan_results.update({ - "scan_id": scan_id, - "scan_type": "remote_paramiko", - "target_host": hostname, - "exit_code": exit_code, - "stdout": stdout_data, - "stderr": stderr_data, - "xml_result": str(xml_result) if xml_result.exists() else None, - "html_report": str(html_report) if html_report.exists() else None, - "arf_result": str(arf_result) if arf_result.exists() else None - }) - + ) + logger.info(f"Remote paramiko scan completed: {scan_id}") return scan_results - + except Exception as e: logger.error(f"Remote paramiko scan failed: {e}") raise ScanExecutionError(f"Remote scan execution failed: {str(e)}") @@ -1154,4 +1267,4 @@ def _execute_remote_scan_with_paramiko(self, hostname: str, port: int, username: try: ssh.close() except: - pass \ No newline at end of file + pass diff --git a/backend/app/services/secure_automated_fixes.py b/backend/app/services/secure_automated_fixes.py index fefbebea..7003bb13 100644 --- a/backend/app/services/secure_automated_fixes.py +++ b/backend/app/services/secure_automated_fixes.py @@ -25,10 +25,10 @@ from ..database import get_async_db from .command_sandbox import ( - CommandSandboxService, - ExecutionRequest, + CommandSandboxService, + ExecutionRequest, ExecutionStatus, - CommandSecurityLevel + CommandSecurityLevel, ) from .error_classification import AutomatedFix, ErrorSeverity @@ -37,7 +37,7 @@ class SecureAutomatedFix: """Enhanced AutomatedFix with security controls""" - + def __init__(self, legacy_fix: AutomatedFix): self.fix_id = legacy_fix.fix_id self.description = legacy_fix.description @@ -46,41 +46,46 @@ def __init__(self, legacy_fix: AutomatedFix): self.command = legacy_fix.command self.is_safe = legacy_fix.is_safe self.rollback_command = legacy_fix.rollback_command - + # New security properties self.security_level = self._determine_security_level() self.requires_approval = self._requires_approval() self.secure_command_id = self._map_to_secure_command() self.parameters = self._extract_parameters() - + def _determine_security_level(self) -> CommandSecurityLevel: """Determine security level based on command characteristics""" if self.requires_sudo: return CommandSecurityLevel.PRIVILEGED - elif any(dangerous in (self.command or "").lower() - for dangerous in ["rm ", "delete", "modify", "install", "update"]): + elif any( + dangerous in (self.command or "").lower() + for dangerous in ["rm ", "delete", "modify", "install", "update"] + ): return CommandSecurityLevel.MODERATE else: return CommandSecurityLevel.SAFE - + def _requires_approval(self) -> bool: """Determine if fix requires manual approval""" - return (self.requires_sudo or - self.security_level in [CommandSecurityLevel.PRIVILEGED, CommandSecurityLevel.CRITICAL] or - not self.is_safe) - + return ( + self.requires_sudo + or self.security_level + in [CommandSecurityLevel.PRIVILEGED, CommandSecurityLevel.CRITICAL] + or not self.is_safe + ) + def _map_to_secure_command(self) -> Optional[str]: """Map legacy fix to secure command template""" if not self.command: return None - + command_lower = self.command.lower() - + # Map common patterns to secure commands if "systemctl status" in command_lower: return "check_service_status" elif "netstat" in command_lower and "grep" in command_lower: - return "check_network_port" + return "check_network_port" elif "apt-get install" in command_lower and "openscap" in command_lower: return "install_openscap_ubuntu" elif "yum install" in command_lower and "openscap" in command_lower: @@ -90,38 +95,39 @@ def _map_to_secure_command(self) -> Optional[str]: else: logger.warning(f"No secure mapping found for command: {self.command}") return None - + def _extract_parameters(self) -> Dict[str, Any]: """Extract parameters from legacy command""" parameters = {} - + if not self.command: return parameters - + # Extract common parameter patterns import re - + # Service name extraction - service_match = re.search(r'systemctl\s+status\s+([a-zA-Z0-9\-_.]+)', self.command) + service_match = re.search(r"systemctl\s+status\s+([a-zA-Z0-9\-_.]+)", self.command) if service_match: parameters["service_name"] = service_match.group(1) - + # Port extraction - port_match = re.search(r'grep\s+(\d+)', self.command) + port_match = re.search(r"grep\s+(\d+)", self.command) if port_match: parameters["port"] = port_match.group(1) - + return parameters class FixExecutionAudit: """Audit trail for fix executions""" - + def __init__(self): self.audit_entries = [] - - async def log_fix_request(self, fix_id: str, requested_by: str, - target_host: str, justification: str): + + async def log_fix_request( + self, fix_id: str, requested_by: str, target_host: str, justification: str + ): """Log fix execution request""" entry = { "event_type": "fix_requested", @@ -130,13 +136,12 @@ async def log_fix_request(self, fix_id: str, requested_by: str, "target_host": target_host, "justification": justification, "timestamp": datetime.utcnow(), - "event_id": str(uuid.uuid4()) + "event_id": str(uuid.uuid4()), } - + await self._persist_audit_entry(entry) - - async def log_fix_approval(self, request_id: str, approved_by: str, - approval_reason: str): + + async def log_fix_approval(self, request_id: str, approved_by: str, approval_reason: str): """Log fix approval decision""" entry = { "event_type": "fix_approved", @@ -144,11 +149,11 @@ async def log_fix_approval(self, request_id: str, approved_by: str, "approved_by": approved_by, "approval_reason": approval_reason, "timestamp": datetime.utcnow(), - "event_id": str(uuid.uuid4()) + "event_id": str(uuid.uuid4()), } - + await self._persist_audit_entry(entry) - + async def log_fix_execution(self, request_id: str, execution_result: ExecutionRequest): """Log fix execution results""" entry = { @@ -156,18 +161,21 @@ async def log_fix_execution(self, request_id: str, execution_result: ExecutionRe "request_id": request_id, "command_id": execution_result.command_id, "exit_code": execution_result.exit_code, - "execution_duration": (execution_result.completed_at - execution_result.executed_at).total_seconds() if execution_result.completed_at and execution_result.executed_at else None, + "execution_duration": ( + (execution_result.completed_at - execution_result.executed_at).total_seconds() + if execution_result.completed_at and execution_result.executed_at + else None + ), "success": execution_result.status == ExecutionStatus.COMPLETED, "output_length": len(execution_result.output or ""), "error_output_length": len(execution_result.error_output or ""), "timestamp": datetime.utcnow(), - "event_id": str(uuid.uuid4()) + "event_id": str(uuid.uuid4()), } - + await self._persist_audit_entry(entry) - - async def log_fix_rollback(self, request_id: str, rollback_by: str, - rollback_success: bool): + + async def log_fix_rollback(self, request_id: str, rollback_by: str, rollback_success: bool): """Log fix rollback operation""" entry = { "event_type": "fix_rolled_back", @@ -175,17 +183,18 @@ async def log_fix_rollback(self, request_id: str, rollback_by: str, "rollback_by": rollback_by, "success": rollback_success, "timestamp": datetime.utcnow(), - "event_id": str(uuid.uuid4()) + "event_id": str(uuid.uuid4()), } - + await self._persist_audit_entry(entry) - + async def _persist_audit_entry(self, entry: Dict[str, Any]): """Persist audit entry to database""" try: async with get_async_db() as session: # Use the existing audit_logs table structure - audit_sql = text(""" + audit_sql = text( + """ INSERT INTO audit_logs ( event_type, user_id, resource_type, resource_id, action, old_values, new_values, ip_address, @@ -195,47 +204,56 @@ async def _persist_audit_entry(self, entry: Dict[str, Any]): :action, :old_values, :new_values, :ip_address, :user_agent, :timestamp ) - """) - - await session.execute(audit_sql, { - "event_type": entry["event_type"], - "user_id": entry.get("requested_by") or entry.get("approved_by") or entry.get("rollback_by"), - "resource_type": "automated_fix", - "resource_id": entry.get("request_id") or entry.get("fix_id"), - "action": entry["event_type"], - "old_values": None, - "new_values": json.dumps({k: v for k, v in entry.items() if k not in ["event_type", "timestamp"]}), - "ip_address": "system", - "user_agent": "openwatch-secure-fix-executor", - "timestamp": entry["timestamp"] - }) - + """ + ) + + await session.execute( + audit_sql, + { + "event_type": entry["event_type"], + "user_id": entry.get("requested_by") + or entry.get("approved_by") + or entry.get("rollback_by"), + "resource_type": "automated_fix", + "resource_id": entry.get("request_id") or entry.get("fix_id"), + "action": entry["event_type"], + "old_values": None, + "new_values": json.dumps( + {k: v for k, v in entry.items() if k not in ["event_type", "timestamp"]} + ), + "ip_address": "system", + "user_agent": "openwatch-secure-fix-executor", + "timestamp": entry["timestamp"], + }, + ) + await session.commit() - + except Exception as e: logger.error(f"Failed to persist audit entry: {e}") class SecureAutomatedFixExecutor: """Main service for secure automated fix execution""" - + def __init__(self): self.sandbox_service = CommandSandboxService() self.audit_service = FixExecutionAudit() self.pending_approvals = {} - - async def evaluate_fix_options(self, legacy_fixes: List[AutomatedFix], - target_host: str) -> List[Dict[str, Any]]: + + async def evaluate_fix_options( + self, legacy_fixes: List[AutomatedFix], target_host: str + ) -> List[Dict[str, Any]]: """Evaluate legacy fixes and convert to secure options""" secure_options = [] - + for legacy_fix in legacy_fixes: secure_fix = SecureAutomatedFix(legacy_fix) - + # Only include fixes that can be mapped to secure commands if secure_fix.secure_command_id: secure_command = self.sandbox_service.get_command_info(secure_fix.secure_command_id) - + if secure_command: option = { "fix_id": secure_fix.fix_id, @@ -246,7 +264,7 @@ async def evaluate_fix_options(self, legacy_fixes: List[AutomatedFix], "secure_command_id": secure_fix.secure_command_id, "parameters": secure_fix.parameters, "rollback_available": bool(secure_command.rollback_template), - "is_safe": secure_fix.security_level == CommandSecurityLevel.SAFE + "is_safe": secure_fix.security_level == CommandSecurityLevel.SAFE, } secure_options.append(option) else: @@ -254,119 +272,120 @@ async def evaluate_fix_options(self, legacy_fixes: List[AutomatedFix], else: # Create a warning for unmappable fixes logger.warning(f"Legacy fix cannot be securely executed: {legacy_fix.fix_id}") - secure_options.append({ - "fix_id": legacy_fix.fix_id, - "description": f"⚠️ SECURITY BLOCKED: {legacy_fix.description}", - "security_level": "blocked", - "requires_approval": True, - "estimated_time": 0, - "secure_command_id": None, - "parameters": {}, - "rollback_available": False, - "is_safe": False, - "blocked_reason": "Command cannot be safely executed in current security model" - }) - + secure_options.append( + { + "fix_id": legacy_fix.fix_id, + "description": f"⚠️ SECURITY BLOCKED: {legacy_fix.description}", + "security_level": "blocked", + "requires_approval": True, + "estimated_time": 0, + "secure_command_id": None, + "parameters": {}, + "rollback_available": False, + "is_safe": False, + "blocked_reason": "Command cannot be safely executed in current security model", + } + ) + return secure_options - - async def request_fix_execution(self, fix_id: str, secure_command_id: str, - parameters: Dict[str, Any], target_host: str, - requested_by: str, justification: str) -> Dict[str, Any]: + + async def request_fix_execution( + self, + fix_id: str, + secure_command_id: str, + parameters: Dict[str, Any], + target_host: str, + requested_by: str, + justification: str, + ) -> Dict[str, Any]: """Request execution of a secure automated fix""" - + try: # Validate the secure command exists if not self.sandbox_service.get_command_info(secure_command_id): raise ValueError(f"Secure command not found: {secure_command_id}") - + # Request execution through sandbox service request = await self.sandbox_service.request_command_execution( command_id=secure_command_id, parameters=parameters, target_host=target_host, requested_by=requested_by, - justification=justification + justification=justification, ) - + # Log audit trail await self.audit_service.log_fix_request( fix_id=fix_id, requested_by=requested_by, target_host=target_host, - justification=justification + justification=justification, ) - + # Store pending approval if needed if request.status == ExecutionStatus.PENDING_APPROVAL: self.pending_approvals[request.request_id] = { "fix_id": fix_id, "request": request, - "requested_at": datetime.utcnow() + "requested_at": datetime.utcnow(), } - + return { "request_id": request.request_id, "status": request.status.value, "requires_approval": request.status == ExecutionStatus.PENDING_APPROVAL, - "message": "Fix execution requested successfully" + "message": "Fix execution requested successfully", } - + except Exception as e: logger.error(f"Failed to request fix execution for {fix_id}: {e}") return { "request_id": None, "status": "failed", "requires_approval": False, - "message": f"Failed to request fix execution: {str(e)}" + "message": f"Failed to request fix execution: {str(e)}", } - - async def approve_fix_request(self, request_id: str, approved_by: str, - approval_reason: str) -> Dict[str, Any]: + + async def approve_fix_request( + self, request_id: str, approved_by: str, approval_reason: str + ) -> Dict[str, Any]: """Approve a pending fix execution request""" - + try: # Approve through sandbox service success = await self.sandbox_service.approve_request(request_id, approved_by) - + if success: # Log approval await self.audit_service.log_fix_approval( - request_id=request_id, - approved_by=approved_by, - approval_reason=approval_reason + request_id=request_id, approved_by=approved_by, approval_reason=approval_reason ) - + # Remove from pending approvals if request_id in self.pending_approvals: del self.pending_approvals[request_id] - - return { - "success": True, - "message": "Fix execution approved successfully" - } + + return {"success": True, "message": "Fix execution approved successfully"} else: return { "success": False, - "message": "Failed to approve fix execution - request not found or already processed" + "message": "Failed to approve fix execution - request not found or already processed", } - + except Exception as e: logger.error(f"Failed to approve fix request {request_id}: {e}") - return { - "success": False, - "message": f"Failed to approve fix execution: {str(e)}" - } - + return {"success": False, "message": f"Failed to approve fix execution: {str(e)}"} + async def execute_approved_fix(self, request_id: str) -> Dict[str, Any]: """Execute an approved fix in secure sandbox""" - + try: # Execute through sandbox service result = await self.sandbox_service.execute_secure_command(request_id) - + # Log execution results await self.audit_service.log_fix_execution(request_id, result) - + # Prepare response response = { "request_id": request_id, @@ -375,54 +394,55 @@ async def execute_approved_fix(self, request_id: str) -> Dict[str, Any]: "exit_code": result.exit_code, "output": result.output, "error_output": result.error_output, - "execution_time": (result.completed_at - result.executed_at).total_seconds() if result.completed_at and result.executed_at else None, - "rollback_available": result.rollback_available + "execution_time": ( + (result.completed_at - result.executed_at).total_seconds() + if result.completed_at and result.executed_at + else None + ), + "rollback_available": result.rollback_available, } - + if result.status == ExecutionStatus.COMPLETED: response["message"] = "Fix executed successfully" else: response["message"] = f"Fix execution failed: {result.error_output}" - + return response - + except Exception as e: logger.error(f"Failed to execute fix {request_id}: {e}") return { "request_id": request_id, "status": "failed", "success": False, - "message": f"Failed to execute fix: {str(e)}" + "message": f"Failed to execute fix: {str(e)}", } - + async def rollback_fix(self, request_id: str, rollback_by: str) -> Dict[str, Any]: """Rollback a previously executed fix""" - + try: success = await self.sandbox_service.rollback_execution(request_id, rollback_by) - + # Log rollback attempt await self.audit_service.log_fix_rollback(request_id, rollback_by, success) - + return { "success": success, - "message": "Fix rolled back successfully" if success else "Fix rollback failed" + "message": "Fix rolled back successfully" if success else "Fix rollback failed", } - + except Exception as e: logger.error(f"Failed to rollback fix {request_id}: {e}") - return { - "success": False, - "message": f"Failed to rollback fix: {str(e)}" - } - + return {"success": False, "message": f"Failed to rollback fix: {str(e)}"} + async def get_fix_status(self, request_id: str) -> Optional[Dict[str, Any]]: """Get status of a fix execution request""" - + request = self.sandbox_service.get_execution_request(request_id) if not request: return None - + return { "request_id": request.request_id, "command_id": request.command_id, @@ -433,14 +453,14 @@ async def get_fix_status(self, request_id: str) -> Optional[Dict[str, Any]]: "executed_at": request.executed_at.isoformat() if request.executed_at else None, "completed_at": request.completed_at.isoformat() if request.completed_at else None, "exit_code": request.exit_code, - "rollback_available": request.rollback_available + "rollback_available": request.rollback_available, } - + async def list_pending_approvals(self) -> List[Dict[str, Any]]: """List all fixes pending approval""" - + pending = self.sandbox_service.list_pending_approvals() - + return [ { "request_id": req.request_id, @@ -448,16 +468,16 @@ async def list_pending_approvals(self) -> List[Dict[str, Any]]: "target_host": req.target_host, "requested_by": req.requested_by, "justification": req.justification, - "requested_at": self.pending_approvals.get(req.request_id, {}).get("requested_at") + "requested_at": self.pending_approvals.get(req.request_id, {}).get("requested_at"), } for req in pending ] - + async def get_secure_command_catalog(self) -> List[Dict[str, Any]]: """Get catalog of available secure commands""" - + commands = self.sandbox_service.list_available_commands() - + return [ { "command_id": cmd.command_id, @@ -466,23 +486,24 @@ async def get_secure_command_catalog(self) -> List[Dict[str, Any]]: "requires_approval": cmd.requires_approval, "allowed_parameters": cmd.allowed_parameters, "max_execution_time": cmd.max_execution_time, - "rollback_available": bool(cmd.rollback_template) + "rollback_available": bool(cmd.rollback_template), } for cmd in commands ] - + async def cleanup_old_requests(self, max_age_days: int = 30): """Clean up old execution requests""" cutoff_date = datetime.utcnow() - timedelta(days=max_age_days) - + # Clean up pending approvals that are too old expired_requests = [ - req_id for req_id, data in self.pending_approvals.items() + req_id + for req_id, data in self.pending_approvals.items() if data["requested_at"] < cutoff_date ] - + for req_id in expired_requests: del self.pending_approvals[req_id] logger.info(f"Cleaned up expired approval request: {req_id}") - - logger.info(f"Cleaned up {len(expired_requests)} expired approval requests") \ No newline at end of file + + logger.info(f"Cleaned up {len(expired_requests)} expired approval requests") diff --git a/backend/app/services/security_audit_logger.py b/backend/app/services/security_audit_logger.py index 029de58f..8fe990c3 100644 --- a/backend/app/services/security_audit_logger.py +++ b/backend/app/services/security_audit_logger.py @@ -2,6 +2,7 @@ OpenWatch Security Audit Logger Handles secure logging of sensitive error information for audit purposes """ + import logging import json import hashlib @@ -11,61 +12,56 @@ from logging.handlers import RotatingFileHandler from ..models.error_models import SecurityAuditLog, ErrorSeverity + class SecurityAuditLogger: """Secure audit logger for error classification events""" - + def __init__(self, log_directory: str = "/app/logs/security"): self.log_directory = Path(log_directory) self.log_directory.mkdir(parents=True, exist_ok=True) - + # Set up security audit logger - self.logger = logging.getLogger('openwatch.security.audit') + self.logger = logging.getLogger("openwatch.security.audit") self.logger.setLevel(logging.INFO) - + # Prevent duplicate handlers if not self.logger.handlers: self._setup_handlers() - + def _setup_handlers(self): """Set up rotating file handler for security audit logs""" - + # Security audit log file security_log_file = self.log_directory / "security_audit.log" - + # Rotating file handler - 100MB max, keep 10 files security_handler = RotatingFileHandler( - security_log_file, - maxBytes=100 * 1024 * 1024, # 100MB - backupCount=10, - encoding='utf-8' + security_log_file, maxBytes=100 * 1024 * 1024, backupCount=10, encoding="utf-8" # 100MB ) - + # JSON formatter for structured logging security_formatter = SecurityJSONFormatter() security_handler.setFormatter(security_formatter) - + self.logger.addHandler(security_handler) - + # Error classification log file (separate from general security events) error_log_file = self.log_directory / "error_classification.log" - + error_handler = RotatingFileHandler( - error_log_file, - maxBytes=50 * 1024 * 1024, # 50MB - backupCount=5, - encoding='utf-8' + error_log_file, maxBytes=50 * 1024 * 1024, backupCount=5, encoding="utf-8" # 50MB ) error_handler.setFormatter(security_formatter) - + # Create separate logger for error classification - self.error_logger = logging.getLogger('openwatch.security.error_classification') + self.error_logger = logging.getLogger("openwatch.security.error_classification") self.error_logger.setLevel(logging.INFO) self.error_logger.addHandler(error_handler) - + # Prevent propagation to avoid duplicate logs self.logger.propagate = False self.error_logger.propagate = False - + def log_error_classification_event( self, error_code: str, @@ -76,10 +72,10 @@ def log_error_classification_event( session_id: Optional[str] = None, request_path: Optional[str] = None, user_agent: Optional[str] = None, - severity: ErrorSeverity = ErrorSeverity.ERROR + severity: ErrorSeverity = ErrorSeverity.ERROR, ): """Log error classification event with full technical details""" - + # Create audit log entry audit_entry = SecurityAuditLog( event_type="error_classification", @@ -91,164 +87,184 @@ def log_error_classification_event( sanitized_response=sanitized_response, severity=severity, request_path=request_path, - user_agent=self._sanitize_user_agent(user_agent) if user_agent else None + user_agent=self._sanitize_user_agent(user_agent) if user_agent else None, ) - + # Log to error classification log self.error_logger.info( "Error Classification Event", extra={ - 'audit_entry': audit_entry.dict(), - 'event_type': 'error_classification', - 'error_code': error_code, - 'severity': severity.value, - 'user_id_hash': self._hash_value(user_id) if user_id else None, - 'source_ip_hash': self._hash_ip(source_ip) if source_ip else None - } + "audit_entry": audit_entry.dict(), + "event_type": "error_classification", + "error_code": error_code, + "severity": severity.value, + "user_id_hash": self._hash_value(user_id) if user_id else None, + "source_ip_hash": self._hash_ip(source_ip) if source_ip else None, + }, ) - + def log_rate_limit_event( - self, - source_ip: str, - error_count: int, - action_taken: str, - user_id: Optional[str] = None + self, source_ip: str, error_count: int, action_taken: str, user_id: Optional[str] = None ): """Log rate limiting security event""" - + self.logger.warning( "Rate Limit Security Event", extra={ - 'event_type': 'rate_limit_violation', - 'source_ip_hash': self._hash_ip(source_ip), - 'error_count': error_count, - 'action_taken': action_taken, - 'user_id_hash': self._hash_value(user_id) if user_id else None, - 'timestamp': datetime.utcnow().isoformat() - } + "event_type": "rate_limit_violation", + "source_ip_hash": self._hash_ip(source_ip), + "error_count": error_count, + "action_taken": action_taken, + "user_id_hash": self._hash_value(user_id) if user_id else None, + "timestamp": datetime.utcnow().isoformat(), + }, ) - + def log_reconnaissance_attempt( self, source_ip: str, suspicious_patterns: List[str], user_id: Optional[str] = None, - session_id: Optional[str] = None + session_id: Optional[str] = None, ): """Log potential reconnaissance attempt""" - + self.logger.critical( "Potential Reconnaissance Attempt", extra={ - 'event_type': 'reconnaissance_attempt', - 'source_ip_hash': self._hash_ip(source_ip), - 'suspicious_patterns': suspicious_patterns, - 'user_id_hash': self._hash_value(user_id) if user_id else None, - 'session_id': session_id, - 'timestamp': datetime.utcnow().isoformat(), - 'severity': 'critical' - } + "event_type": "reconnaissance_attempt", + "source_ip_hash": self._hash_ip(source_ip), + "suspicious_patterns": suspicious_patterns, + "user_id_hash": self._hash_value(user_id) if user_id else None, + "session_id": session_id, + "timestamp": datetime.utcnow().isoformat(), + "severity": "critical", + }, ) - + def _hash_ip(self, ip_address: str) -> str: """Hash IP address for privacy while maintaining uniqueness""" if not ip_address: return None - + # Use SHA-256 with a salt for consistent hashing salt = "openwatch_security_salt_2024" return hashlib.sha256(f"{salt}{ip_address}".encode()).hexdigest()[:16] - + def _hash_value(self, value: str) -> str: """Hash any sensitive value for logging""" if not value: return None - + salt = "openwatch_audit_salt_2024" return hashlib.sha256(f"{salt}{value}".encode()).hexdigest()[:16] - + def _sanitize_user_agent(self, user_agent: str) -> str: """Sanitize user agent string to remove potentially sensitive information""" if not user_agent: return None - + # Keep only the browser/client type, remove version details import re - + # Common patterns to keep patterns = [ - r'Chrome/[\d\.]+', - r'Firefox/[\d\.]+', - r'Safari/[\d\.]+', - r'Edge/[\d\.]+', - r'curl/[\d\.]+', - r'wget/[\d\.]+', - r'Python-urllib/[\d\.]+', - r'requests/[\d\.]+', + r"Chrome/[\d\.]+", + r"Firefox/[\d\.]+", + r"Safari/[\d\.]+", + r"Edge/[\d\.]+", + r"curl/[\d\.]+", + r"wget/[\d\.]+", + r"Python-urllib/[\d\.]+", + r"requests/[\d\.]+", ] - + sanitized = "unknown" for pattern in patterns: match = re.search(pattern, user_agent, re.IGNORECASE) if match: sanitized = match.group(0) break - + return sanitized class SecurityJSONFormatter(logging.Formatter): """JSON formatter for security audit logs""" - + def format(self, record): """Format log record as JSON""" - + # Create base log entry log_entry = { - 'timestamp': datetime.utcnow().isoformat(), - 'level': record.levelname, - 'logger': record.name, - 'message': record.getMessage(), + "timestamp": datetime.utcnow().isoformat(), + "level": record.levelname, + "logger": record.name, + "message": record.getMessage(), } - + # Add extra fields if present - if hasattr(record, 'audit_entry'): - log_entry['audit_data'] = record.audit_entry - - if hasattr(record, 'event_type'): - log_entry['event_type'] = record.event_type - - if hasattr(record, 'error_code'): - log_entry['error_code'] = record.error_code - - if hasattr(record, 'severity'): - log_entry['severity'] = record.severity - - if hasattr(record, 'user_id_hash'): - log_entry['user_id_hash'] = record.user_id_hash - - if hasattr(record, 'source_ip_hash'): - log_entry['source_ip_hash'] = record.source_ip_hash - + if hasattr(record, "audit_entry"): + log_entry["audit_data"] = record.audit_entry + + if hasattr(record, "event_type"): + log_entry["event_type"] = record.event_type + + if hasattr(record, "error_code"): + log_entry["error_code"] = record.error_code + + if hasattr(record, "severity"): + log_entry["severity"] = record.severity + + if hasattr(record, "user_id_hash"): + log_entry["user_id_hash"] = record.user_id_hash + + if hasattr(record, "source_ip_hash"): + log_entry["source_ip_hash"] = record.source_ip_hash + # Add any other extra fields for key, value in record.__dict__.items(): - if key not in ['name', 'msg', 'args', 'levelname', 'levelno', 'pathname', - 'filename', 'module', 'exc_info', 'exc_text', 'stack_info', - 'lineno', 'funcName', 'created', 'msecs', 'relativeCreated', - 'thread', 'threadName', 'processName', 'process', 'message', - 'audit_entry', 'event_type', 'error_code', 'severity', - 'user_id_hash', 'source_ip_hash']: + if key not in [ + "name", + "msg", + "args", + "levelname", + "levelno", + "pathname", + "filename", + "module", + "exc_info", + "exc_text", + "stack_info", + "lineno", + "funcName", + "created", + "msecs", + "relativeCreated", + "thread", + "threadName", + "processName", + "process", + "message", + "audit_entry", + "event_type", + "error_code", + "severity", + "user_id_hash", + "source_ip_hash", + ]: log_entry[key] = value - + return json.dumps(log_entry, default=str, ensure_ascii=False) # Global instance for dependency injection _security_audit_logger = None + def get_security_audit_logger() -> SecurityAuditLogger: """Get or create the global security audit logger""" global _security_audit_logger if _security_audit_logger is None: _security_audit_logger = SecurityAuditLogger() - return _security_audit_logger \ No newline at end of file + return _security_audit_logger diff --git a/backend/app/services/security_config.py b/backend/app/services/security_config.py index 99ec7cc8..e430316e 100644 --- a/backend/app/services/security_config.py +++ b/backend/app/services/security_config.py @@ -22,15 +22,17 @@ class ConfigScope(str, Enum): """Configuration scope levels""" - SYSTEM = "system" # System-wide default - ORGANIZATION = "org" # Organization level - HOST_GROUP = "group" # Host group level - HOST = "host" # Individual host level + + SYSTEM = "system" # System-wide default + ORGANIZATION = "org" # Organization level + HOST_GROUP = "group" # Host group level + HOST = "host" # Individual host level @dataclass class SecurityConfigTemplate: """Predefined security configuration templates""" + name: str description: str policy_level: SecurityPolicyLevel @@ -44,14 +46,14 @@ class SecurityConfigTemplate: class SecurityConfigManager: """ Manages security configuration policies with hierarchical inheritance. - + Configuration inheritance order: 1. Host-specific config - 2. Host group config + 2. Host group config 3. Organization config 4. System default config """ - + # Predefined security templates TEMPLATES = { "fips_strict": SecurityConfigTemplate( @@ -62,7 +64,7 @@ class SecurityConfigManager: minimum_rsa_bits=3072, allow_dsa_keys=False, minimum_password_length=16, - require_complex_passwords=True + require_complex_passwords=True, ), "enterprise": SecurityConfigTemplate( name="Enterprise", @@ -72,7 +74,7 @@ class SecurityConfigManager: minimum_rsa_bits=2048, allow_dsa_keys=False, minimum_password_length=12, - require_complex_passwords=True + require_complex_passwords=True, ), "development": SecurityConfigTemplate( name="Development", @@ -82,39 +84,40 @@ class SecurityConfigManager: minimum_rsa_bits=2048, allow_dsa_keys=True, minimum_password_length=8, - require_complex_passwords=False - ) + require_complex_passwords=False, + ), } - + def __init__(self, db: Session): self.db = db self._ensure_default_config() - - def get_effective_config(self, target_id: Optional[str] = None, - target_type: Optional[str] = None) -> SecurityPolicyConfig: + + def get_effective_config( + self, target_id: Optional[str] = None, target_type: Optional[str] = None + ) -> SecurityPolicyConfig: """ Get effective security configuration for a target using inheritance. - + Args: target_id: ID of the target (host, group, etc.) target_type: Type of target ('host', 'group', 'org') - + Returns: SecurityPolicyConfig with resolved settings """ try: config_data = {} - + # Start with system default system_config = self._get_config_by_scope(ConfigScope.SYSTEM) if system_config: config_data.update(system_config) - + # Layer on organization config if available org_config = self._get_config_by_scope(ConfigScope.ORGANIZATION) if org_config: config_data.update(org_config) - + # Layer on group config if target is host or in a group if target_type == "host" and target_id: # Get host's group and apply group config @@ -127,49 +130,56 @@ def get_effective_config(self, target_id: Optional[str] = None, group_config = self._get_config_by_scope(ConfigScope.HOST_GROUP, target_id) if group_config: config_data.update(group_config) - + # Finally, apply target-specific config if target_type == "host" and target_id: host_config = self._get_config_by_scope(ConfigScope.HOST, target_id) if host_config: config_data.update(host_config) - + # Convert to SecurityPolicyConfig return self._dict_to_policy_config(config_data) - + except Exception as e: logger.error(f"Failed to get effective config for {target_type}:{target_id}: {e}") # Return default strict config on error return SecurityPolicyConfig() - - def set_config(self, scope: ConfigScope, config: SecurityPolicyConfig, - target_id: Optional[str] = None, created_by: str = "system") -> bool: + + def set_config( + self, + scope: ConfigScope, + config: SecurityPolicyConfig, + target_id: Optional[str] = None, + created_by: str = "system", + ) -> bool: """ Set security configuration for a specific scope. - + Args: scope: Configuration scope config: Security policy configuration target_id: Target ID (required for host/group scope) created_by: User who created the config - + Returns: bool: Success status """ try: config_data = self._policy_config_to_dict(config) - + # Validate scope-target relationship if scope in [ConfigScope.HOST, ConfigScope.HOST_GROUP] and not target_id: raise ValueError(f"target_id is required for {scope.value} scope") - + if scope in [ConfigScope.SYSTEM, ConfigScope.ORGANIZATION] and target_id: raise ValueError(f"target_id must be null for {scope.value} scope") - + current_time = datetime.utcnow() - + # Upsert configuration - self.db.execute(text(""" + self.db.execute( + text( + """ INSERT INTO security_config (scope, target_id, config_data, created_by, created_at, updated_at) VALUES (:scope, :target_id, :config_data, :created_by, :created_at, :updated_at) @@ -177,45 +187,53 @@ def set_config(self, scope: ConfigScope, config: SecurityPolicyConfig, DO UPDATE SET config_data = :config_data, updated_at = :updated_at - """), { - "scope": scope.value, - "target_id": target_id, - "config_data": json.dumps(config_data), - "created_by": created_by, - "created_at": current_time, - "updated_at": current_time - }) - + """ + ), + { + "scope": scope.value, + "target_id": target_id, + "config_data": json.dumps(config_data), + "created_by": created_by, + "created_at": current_time, + "updated_at": current_time, + }, + ) + self.db.commit() - + logger.info(f"Updated security config for scope={scope.value}, target={target_id}") return True - + except Exception as e: logger.error(f"Failed to set security config: {e}") self.db.rollback() return False - - def apply_template(self, template_name: str, scope: ConfigScope, - target_id: Optional[str] = None, created_by: str = "system") -> bool: + + def apply_template( + self, + template_name: str, + scope: ConfigScope, + target_id: Optional[str] = None, + created_by: str = "system", + ) -> bool: """ Apply a predefined security template. - + Args: template_name: Name of the template to apply scope: Configuration scope target_id: Target ID (if applicable) created_by: User applying the template - + Returns: bool: Success status """ if template_name not in self.TEMPLATES: logger.error(f"Unknown security template: {template_name}") return False - + template = self.TEMPLATES[template_name] - + # Convert template to SecurityPolicyConfig config = SecurityPolicyConfig( policy_level=template.policy_level, @@ -223,74 +241,85 @@ def apply_template(self, template_name: str, scope: ConfigScope, minimum_rsa_bits=template.minimum_rsa_bits, allow_dsa_keys=template.allow_dsa_keys, minimum_password_length=template.minimum_password_length, - require_complex_passwords=template.require_complex_passwords + require_complex_passwords=template.require_complex_passwords, ) - + success = self.set_config(scope, config, target_id, created_by) - + if success: logger.info(f"Applied template '{template_name}' to {scope.value}:{target_id}") - + return success - - def get_config_summary(self, target_id: Optional[str] = None, - target_type: Optional[str] = None) -> Dict[str, Any]: + + def get_config_summary( + self, target_id: Optional[str] = None, target_type: Optional[str] = None + ) -> Dict[str, Any]: """ Get comprehensive configuration summary including inheritance chain. - + Args: target_id: Target ID target_type: Target type - + Returns: Dict with configuration summary and inheritance info """ try: effective_config = self.get_effective_config(target_id, target_type) - + # Get inheritance chain inheritance_chain = [] - + # System level system_config = self._get_config_by_scope(ConfigScope.SYSTEM) if system_config: - inheritance_chain.append({ - "level": "system", - "source": "System Default", - "settings_count": len(system_config) - }) - + inheritance_chain.append( + { + "level": "system", + "source": "System Default", + "settings_count": len(system_config), + } + ) + # Organization level org_config = self._get_config_by_scope(ConfigScope.ORGANIZATION) if org_config: - inheritance_chain.append({ - "level": "organization", - "source": "Organization Policy", - "settings_count": len(org_config) - }) - + inheritance_chain.append( + { + "level": "organization", + "source": "Organization Policy", + "settings_count": len(org_config), + } + ) + # Group level if target_type in ["host", "group"] and target_id: - group_id = target_id if target_type == "group" else self._get_host_group_id(target_id) + group_id = ( + target_id if target_type == "group" else self._get_host_group_id(target_id) + ) if group_id: group_config = self._get_config_by_scope(ConfigScope.HOST_GROUP, group_id) if group_config: - inheritance_chain.append({ - "level": "group", - "source": f"Group {group_id}", - "settings_count": len(group_config) - }) - + inheritance_chain.append( + { + "level": "group", + "source": f"Group {group_id}", + "settings_count": len(group_config), + } + ) + # Host level if target_type == "host" and target_id: host_config = self._get_config_by_scope(ConfigScope.HOST, target_id) if host_config: - inheritance_chain.append({ - "level": "host", - "source": f"Host {target_id}", - "settings_count": len(host_config) - }) - + inheritance_chain.append( + { + "level": "host", + "source": f"Host {target_id}", + "settings_count": len(host_config), + } + ) + return { "effective_config": { "policy_level": effective_config.policy_level.value, @@ -299,16 +328,16 @@ def get_config_summary(self, target_id: Optional[str] = None, "allow_dsa_keys": effective_config.allow_dsa_keys, "minimum_password_length": effective_config.minimum_password_length, "require_complex_passwords": effective_config.require_complex_passwords, - "allowed_key_types": [kt.value for kt in effective_config.allowed_key_types] + "allowed_key_types": [kt.value for kt in effective_config.allowed_key_types], }, "inheritance_chain": inheritance_chain, - "compliance_level": self._assess_compliance_level(effective_config) + "compliance_level": self._assess_compliance_level(effective_config), } - + except Exception as e: logger.error(f"Failed to get config summary: {e}") return {"error": str(e)} - + def list_templates(self) -> List[Dict[str, Any]]: """List all available security configuration templates""" return [ @@ -317,109 +346,125 @@ def list_templates(self) -> List[Dict[str, Any]]: "description": template.description, "policy_level": template.policy_level.value, "enforce_fips": template.enforce_fips, - "recommended_for": self._get_template_recommendation(name) + "recommended_for": self._get_template_recommendation(name), } for name, template in self.TEMPLATES.items() ] - + def validate_config(self, config: SecurityPolicyConfig) -> Tuple[bool, List[str]]: """ Validate security configuration for consistency and best practices. - + Args: config: Configuration to validate - + Returns: Tuple of (is_valid, validation_messages) """ messages = [] is_valid = True - + # FIPS compliance validation if config.enforce_fips: if config.allow_dsa_keys: messages.append("DSA keys are not FIPS compliant and should be disabled") is_valid = False - + if config.minimum_rsa_bits < 2048: messages.append("FIPS requires minimum RSA key size of 2048 bits") is_valid = False - + # Security best practices if config.policy_level == SecurityPolicyLevel.STRICT: if config.minimum_rsa_bits < 3072: messages.append("Strict policy should require RSA keys of 3072+ bits") - + if config.minimum_password_length < 12: messages.append("Strict policy should require passwords of 12+ characters") - + # Consistency checks if not config.require_complex_passwords and config.minimum_password_length < 16: messages.append("Simple passwords should be at least 16 characters long") - + if not messages: messages.append("Configuration meets security requirements") - + return is_valid, messages - + def _ensure_default_config(self): """Ensure system has a default security configuration""" try: - result = self.db.execute(text(""" + result = self.db.execute( + text( + """ SELECT id FROM security_config WHERE scope = 'system' AND target_id IS NULL - """)) - + """ + ) + ) + if not result.fetchone(): # Create default strict configuration default_config = SecurityPolicyConfig() self.set_config(ConfigScope.SYSTEM, default_config, created_by="system") logger.info("Created default system security configuration") - + except Exception as e: logger.error(f"Failed to ensure default config: {e}") - - def _get_config_by_scope(self, scope: ConfigScope, target_id: Optional[str] = None) -> Optional[Dict]: + + def _get_config_by_scope( + self, scope: ConfigScope, target_id: Optional[str] = None + ) -> Optional[Dict]: """Get configuration data for a specific scope""" try: - result = self.db.execute(text(""" + result = self.db.execute( + text( + """ SELECT config_data FROM security_config WHERE scope = :scope AND target_id IS :target_id - """), {"scope": scope.value, "target_id": target_id}) - + """ + ), + {"scope": scope.value, "target_id": target_id}, + ) + row = result.fetchone() if row and row.config_data: return json.loads(row.config_data) - + return None - + except Exception as e: logger.error(f"Failed to get config for scope {scope.value}: {e}") return None - + def _get_host_group_id(self, host_id: str) -> Optional[str]: """Get the primary group ID for a host""" try: - result = self.db.execute(text(""" + result = self.db.execute( + text( + """ SELECT hgm.host_group_id FROM host_group_memberships hgm WHERE hgm.host_id = :host_id LIMIT 1 - """), {"host_id": host_id}) - + """ + ), + {"host_id": host_id}, + ) + row = result.fetchone() return str(row.host_group_id) if row else None - + except Exception as e: logger.error(f"Failed to get host group for {host_id}: {e}") return None - + def _dict_to_policy_config(self, config_data: Dict) -> SecurityPolicyConfig: """Convert dictionary to SecurityPolicyConfig""" try: # Set defaults policy_config = SecurityPolicyConfig() - + # Override with provided values if "policy_level" in config_data: policy_config.policy_level = SecurityPolicyLevel(config_data["policy_level"]) @@ -433,13 +478,13 @@ def _dict_to_policy_config(self, config_data: Dict) -> SecurityPolicyConfig: policy_config.minimum_password_length = config_data["minimum_password_length"] if "require_complex_passwords" in config_data: policy_config.require_complex_passwords = config_data["require_complex_passwords"] - + return policy_config - + except Exception as e: logger.error(f"Failed to convert dict to policy config: {e}") return SecurityPolicyConfig() # Return default on error - + def _policy_config_to_dict(self, config: SecurityPolicyConfig) -> Dict: """Convert SecurityPolicyConfig to dictionary""" return { @@ -450,9 +495,9 @@ def _policy_config_to_dict(self, config: SecurityPolicyConfig) -> Dict: "allow_dsa_keys": config.allow_dsa_keys, "minimum_password_length": config.minimum_password_length, "require_complex_passwords": config.require_complex_passwords, - "allowed_key_types": [kt.value for kt in config.allowed_key_types] + "allowed_key_types": [kt.value for kt in config.allowed_key_types], } - + def _assess_compliance_level(self, config: SecurityPolicyConfig) -> str: """Assess overall compliance level of configuration""" if config.enforce_fips and config.policy_level == SecurityPolicyLevel.STRICT: @@ -461,13 +506,13 @@ def _assess_compliance_level(self, config: SecurityPolicyConfig) -> str: return "medium" else: return "low" - + def _get_template_recommendation(self, template_name: str) -> str: """Get recommendation for when to use a template""" recommendations = { "fips_strict": "Government, healthcare, financial services requiring FIPS 140-2", "enterprise": "Corporate environments with security requirements", - "development": "Development and testing environments" + "development": "Development and testing environments", } return recommendations.get(template_name, "General use") @@ -475,4 +520,4 @@ def _get_template_recommendation(self, template_name: str) -> str: # Factory function def get_security_config_manager(db: Session) -> SecurityConfigManager: """Factory function to create SecurityConfigManager""" - return SecurityConfigManager(db) \ No newline at end of file + return SecurityConfigManager(db) diff --git a/backend/app/services/semantic_scap_engine.py b/backend/app/services/semantic_scap_engine.py index 2323fe9a..a3ec9a35 100644 --- a/backend/app/services/semantic_scap_engine.py +++ b/backend/app/services/semantic_scap_engine.py @@ -27,36 +27,38 @@ @dataclass class SemanticRule: """Rich semantic representation of a compliance rule""" - name: str # Semantic name (e.g., 'ssh_disable_root_login') - scap_rule_id: str # Original SCAP rule ID - title: str # Human-readable title - compliance_intent: str # What this rule is trying to achieve - business_impact: str # Business impact category - risk_level: str # high, medium, low - frameworks: List[str] # Which frameworks this rule applies to - remediation_complexity: str # simple, moderate, complex - estimated_fix_time: int # Estimated time in minutes - dependencies: List[str] # Other rules that should be fixed first + + name: str # Semantic name (e.g., 'ssh_disable_root_login') + scap_rule_id: str # Original SCAP rule ID + title: str # Human-readable title + compliance_intent: str # What this rule is trying to achieve + business_impact: str # Business impact category + risk_level: str # high, medium, low + frameworks: List[str] # Which frameworks this rule applies to + remediation_complexity: str # simple, moderate, complex + estimated_fix_time: int # Estimated time in minutes + dependencies: List[str] # Other rules that should be fixed first cross_framework_mappings: Dict[str, str] # Framework-specific rule IDs - remediation_available: bool # Whether AEGIS can remediate this - + remediation_available: bool # Whether AEGIS can remediate this + def to_dict(self) -> Dict[str, Any]: """Convert to dictionary for serialization""" return asdict(self) -@dataclass +@dataclass class IntelligentScanResult: """Enhanced scan result with semantic intelligence""" + scan_id: str host_id: str - original_results: Dict[str, Any] # Original SCAP results (preserved) + original_results: Dict[str, Any] # Original SCAP results (preserved) semantic_rules: List[SemanticRule] # Semantic analysis framework_compliance_matrix: Dict[str, float] # Cross-framework compliance scores remediation_strategy: Dict[str, Any] # Intelligent remediation recommendations - compliance_trends: Dict[str, Any] # Predicted compliance trends + compliance_trends: Dict[str, Any] # Predicted compliance trends processing_metadata: Dict[str, Any] # Processing information - + def to_dict(self) -> Dict[str, Any]: """Convert to dictionary for API responses""" return { @@ -67,75 +69,70 @@ def to_dict(self) -> Dict[str, Any]: "framework_compliance_matrix": self.framework_compliance_matrix, "remediation_strategy": self.remediation_strategy, "compliance_trends": self.compliance_trends, - "processing_metadata": self.processing_metadata + "processing_metadata": self.processing_metadata, } class SemanticSCAPEngine: """ Transform static SCAP processing into intelligent semantic analysis - + This engine provides the intelligence layer between OpenWatch scanning and AEGIS remediation, enabling universal compliance understanding. """ - + def __init__(self): self.settings = get_settings() - self.aegis_base_url = getattr(self.settings, 'aegis_api_url', 'http://localhost:8001') + self.aegis_base_url = getattr(self.settings, "aegis_api_url", "http://localhost:8001") self._rule_mappings_cache: Dict[str, SemanticRule] = {} self._framework_cache: Dict[str, Any] = {} self._cache_ttl = 3600 # 1 hour cache TTL - + async def process_scan_with_intelligence( - self, - scan_results: Dict[str, Any], - scan_id: str, - host_info: Dict[str, Any] + self, scan_results: Dict[str, Any], scan_id: str, host_info: Dict[str, Any] ) -> IntelligentScanResult: """ Transform raw SCAP results into intelligent compliance insights - + Args: scan_results: Raw SCAP scan results scan_id: Scan identifier host_info: Host information including OS details - + Returns: IntelligentScanResult with semantic analysis """ logger.info(f"Processing scan with semantic intelligence: {scan_id}") start_time = datetime.utcnow() - + try: # 1. Extract semantic understanding from failed rules semantic_rules = await self._extract_semantic_understanding( scan_results.get("failed_rules", []), scan_results.get("rule_details", []), - host_info - ) - - # 2. Map to universal compliance frameworks - framework_mappings = await self._map_to_universal_frameworks( - semantic_rules, host_info + host_info, ) - + + # 2. Map to universal compliance frameworks + framework_mappings = await self._map_to_universal_frameworks(semantic_rules, host_info) + # 3. Analyze cross-framework compliance impact compliance_matrix = await self._analyze_compliance_matrix( semantic_rules, scan_results, framework_mappings ) - + # 4. Generate intelligent remediation strategy remediation_strategy = await self._create_intelligent_remediation_strategy( semantic_rules, host_info, compliance_matrix ) - + # 5. Predict compliance trends (simplified for initial implementation) compliance_trends = await self._predict_compliance_trends( semantic_rules, scan_id, host_info.get("host_id") ) - + processing_time = (datetime.utcnow() - start_time).total_seconds() - + result = IntelligentScanResult( scan_id=scan_id, host_id=host_info.get("host_id", "unknown"), @@ -148,22 +145,28 @@ async def process_scan_with_intelligence( "processing_time_seconds": processing_time, "semantic_rules_count": len(semantic_rules), "frameworks_analyzed": list(compliance_matrix.keys()), - "remediation_available_count": sum(1 for r in semantic_rules if r.remediation_available), - "processed_at": start_time.isoformat() - } + "remediation_available_count": sum( + 1 for r in semantic_rules if r.remediation_available + ), + "processed_at": start_time.isoformat(), + }, ) - + # Store semantic analysis results await self._store_semantic_analysis(result) - - logger.info(f"Semantic analysis complete for scan {scan_id}: " - f"{len(semantic_rules)} rules analyzed, " - f"{len(compliance_matrix)} frameworks evaluated") - + + logger.info( + f"Semantic analysis complete for scan {scan_id}: " + f"{len(semantic_rules)} rules analyzed, " + f"{len(compliance_matrix)} frameworks evaluated" + ) + return result - + except Exception as e: - logger.error(f"Error in semantic SCAP processing for scan {scan_id}: {e}", exc_info=True) + logger.error( + f"Error in semantic SCAP processing for scan {scan_id}: {e}", exc_info=True + ) # Return minimal result to maintain functionality return IntelligentScanResult( scan_id=scan_id, @@ -176,93 +179,81 @@ async def process_scan_with_intelligence( processing_metadata={ "error": str(e), "processing_failed": True, - "fallback_mode": True - } + "fallback_mode": True, + }, ) - + async def _extract_semantic_understanding( - self, - failed_rules: List[Dict], - rule_details: List[Dict], - host_info: Dict + self, failed_rules: List[Dict], rule_details: List[Dict], host_info: Dict ) -> List[SemanticRule]: """Extract semantic meaning from SCAP rule IDs""" - + semantic_rules = [] - + # Create lookup for detailed rule information - rule_details_lookup = { - detail.get("rule_id"): detail for detail in rule_details - } - + rule_details_lookup = {detail.get("rule_id"): detail for detail in rule_details} + for failed_rule in failed_rules: scap_rule_id = failed_rule.get("rule_id", "") - + try: # Get detailed information if available rule_detail = rule_details_lookup.get(scap_rule_id, {}) - + # Extract semantic information using rule pattern matching semantic_rule = await self._map_scap_rule_to_semantic( - scap_rule_id, - rule_detail, - failed_rule.get("severity", "medium"), - host_info + scap_rule_id, rule_detail, failed_rule.get("severity", "medium"), host_info ) - + if semantic_rule: semantic_rules.append(semantic_rule) - + except Exception as e: logger.warning(f"Failed to process rule {scap_rule_id}: {e}") # Create minimal semantic rule to avoid breaking functionality - semantic_rules.append(SemanticRule( - name=self._generate_fallback_rule_name(scap_rule_id), - scap_rule_id=scap_rule_id, - title=rule_detail.get("title", "Unknown Rule"), - compliance_intent="Security compliance rule", - business_impact="security", - risk_level=failed_rule.get("severity", "medium"), - frameworks=["stig"], # Default to STIG - remediation_complexity="unknown", - estimated_fix_time=10, - dependencies=[], - cross_framework_mappings={}, - remediation_available=False - )) - + semantic_rules.append( + SemanticRule( + name=self._generate_fallback_rule_name(scap_rule_id), + scap_rule_id=scap_rule_id, + title=rule_detail.get("title", "Unknown Rule"), + compliance_intent="Security compliance rule", + business_impact="security", + risk_level=failed_rule.get("severity", "medium"), + frameworks=["stig"], # Default to STIG + remediation_complexity="unknown", + estimated_fix_time=10, + dependencies=[], + cross_framework_mappings={}, + remediation_available=False, + ) + ) + logger.info(f"Extracted semantic understanding for {len(semantic_rules)} rules") return semantic_rules - + async def _map_scap_rule_to_semantic( - self, - scap_rule_id: str, - rule_detail: Dict, - severity: str, - host_info: Dict + self, scap_rule_id: str, rule_detail: Dict, severity: str, host_info: Dict ) -> Optional[SemanticRule]: """Map a SCAP rule ID to semantic understanding""" - + # Try to get mapping from AEGIS first - semantic_mapping = await self._query_aegis_for_semantic_mapping( - scap_rule_id, host_info - ) - + semantic_mapping = await self._query_aegis_for_semantic_mapping(scap_rule_id, host_info) + if semantic_mapping: return semantic_mapping - + # Fallback to pattern-based mapping semantic_name = self._extract_semantic_name_from_scap_rule(scap_rule_id) - + # Extract compliance intent from rule title/description compliance_intent = self._extract_compliance_intent(rule_detail) - + # Determine business impact from rule characteristics business_impact = self._determine_business_impact(rule_detail, semantic_name) - + # Estimate remediation complexity remediation_complexity = self._estimate_remediation_complexity(rule_detail) - + return SemanticRule( name=semantic_name, scap_rule_id=scap_rule_id, @@ -275,59 +266,59 @@ async def _map_scap_rule_to_semantic( estimated_fix_time=self._estimate_fix_time(remediation_complexity), dependencies=[], cross_framework_mappings={}, - remediation_available=False # Will be updated later + remediation_available=False, # Will be updated later ) - + def _extract_semantic_name_from_scap_rule(self, scap_rule_id: str) -> str: """Extract semantic name from SCAP rule ID using pattern matching""" - + # Common SCAP rule ID patterns and their semantic mappings patterns = { - r'ssh.*root.*login': 'ssh_disable_root_login', - r'ssh.*permit.*root': 'ssh_disable_root_login', - r'password.*min.*length': 'password_minimum_length', - r'password.*length': 'password_minimum_length', - r'password.*digit': 'password_minimum_digits', - r'password.*upper': 'password_minimum_uppercase', - r'password.*lower': 'password_minimum_lowercase', - r'password.*special': 'password_minimum_special_chars', - r'auditd.*enable': 'auditd_service_enabled', - r'audit.*log': 'audit_logging_configured', - r'firewall.*enable': 'firewall_enabled', - r'selinux.*enforc': 'selinux_enforcing_mode', - r'kernel.*modules': 'kernel_module_restrictions', - r'file.*permissions': 'file_permissions_configured', - r'umask': 'umask_configured', - r'cron.*permissions': 'cron_access_restricted' + r"ssh.*root.*login": "ssh_disable_root_login", + r"ssh.*permit.*root": "ssh_disable_root_login", + r"password.*min.*length": "password_minimum_length", + r"password.*length": "password_minimum_length", + r"password.*digit": "password_minimum_digits", + r"password.*upper": "password_minimum_uppercase", + r"password.*lower": "password_minimum_lowercase", + r"password.*special": "password_minimum_special_chars", + r"auditd.*enable": "auditd_service_enabled", + r"audit.*log": "audit_logging_configured", + r"firewall.*enable": "firewall_enabled", + r"selinux.*enforc": "selinux_enforcing_mode", + r"kernel.*modules": "kernel_module_restrictions", + r"file.*permissions": "file_permissions_configured", + r"umask": "umask_configured", + r"cron.*permissions": "cron_access_restricted", } - + # Convert rule ID to lowercase for pattern matching rule_id_lower = scap_rule_id.lower() - + for pattern, semantic_name in patterns.items(): if re.search(pattern, rule_id_lower): return semantic_name - + # Generate a fallback name return self._generate_fallback_rule_name(scap_rule_id) - + def _generate_fallback_rule_name(self, scap_rule_id: str) -> str: """Generate a fallback semantic name from SCAP rule ID""" # Extract meaningful parts from the rule ID # Remove common prefixes and suffixes - clean_id = re.sub(r'xccdf_[^_]+_rule_', '', scap_rule_id) - clean_id = re.sub(r'_rule$', '', clean_id) - clean_id = re.sub(r'[^a-zA-Z0-9_]', '_', clean_id) - clean_id = re.sub(r'_+', '_', clean_id) - clean_id = clean_id.strip('_').lower() - + clean_id = re.sub(r"xccdf_[^_]+_rule_", "", scap_rule_id) + clean_id = re.sub(r"_rule$", "", clean_id) + clean_id = re.sub(r"[^a-zA-Z0-9_]", "_", clean_id) + clean_id = re.sub(r"_+", "_", clean_id) + clean_id = clean_id.strip("_").lower() + return clean_id or "unknown_rule" - + def _extract_compliance_intent(self, rule_detail: Dict) -> str: """Extract compliance intent from rule details""" title = rule_detail.get("title", "").lower() description = rule_detail.get("description", "").lower() - + intent_patterns = { "authentication": ["password", "login", "auth", "credential"], "access_control": ["permission", "access", "privilege", "authorization"], @@ -335,42 +326,42 @@ def _extract_compliance_intent(self, rule_detail: Dict) -> str: "network_security": ["ssh", "network", "port", "firewall", "protocol"], "system_hardening": ["kernel", "module", "service", "daemon"], "data_protection": ["encrypt", "hash", "secure", "protect"], - "compliance_monitoring": ["compliance", "policy", "standard", "requirement"] + "compliance_monitoring": ["compliance", "policy", "standard", "requirement"], } - + text = f"{title} {description}" - + for intent, keywords in intent_patterns.items(): if any(keyword in text for keyword in keywords): return intent - + return "security_compliance" - + def _determine_business_impact(self, rule_detail: Dict, semantic_name: str) -> str: """Determine business impact category""" high_impact = ["authentication", "access_control", "network_security"] medium_impact = ["audit_logging", "system_hardening"] - + compliance_intent = self._extract_compliance_intent(rule_detail) - + if compliance_intent in high_impact: return "high" elif compliance_intent in medium_impact: - return "medium" + return "medium" else: return "low" - + def _determine_applicable_frameworks(self, rule_detail: Dict) -> List[str]: """Determine which compliance frameworks this rule applies to""" # For now, assume most rules apply to common frameworks # This will be enhanced with actual framework mapping return ["stig", "cis", "nist"] - + def _estimate_remediation_complexity(self, rule_detail: Dict) -> str: """Estimate remediation complexity""" remediation = rule_detail.get("remediation", {}) fix_text = remediation.get("fix_text", "").lower() - + if "edit" in fix_text or "configure" in fix_text: return "simple" elif "install" in fix_text or "restart" in fix_text: @@ -379,44 +370,35 @@ def _estimate_remediation_complexity(self, rule_detail: Dict) -> str: return "complex" else: return "simple" # Default to simple - + def _estimate_fix_time(self, complexity: str) -> int: """Estimate fix time in minutes based on complexity""" - time_mapping = { - "simple": 5, - "moderate": 15, - "complex": 30 - } + time_mapping = {"simple": 5, "moderate": 15, "complex": 30} return time_mapping.get(complexity, 10) - + async def _query_aegis_for_semantic_mapping( - self, - scap_rule_id: str, - host_info: Dict + self, scap_rule_id: str, host_info: Dict ) -> Optional[SemanticRule]: """Query AEGIS for semantic rule mapping""" - + try: # Build distribution key for AEGIS query distribution_key = self._build_distribution_key(host_info) - + # Query AEGIS for rule mapping async with httpx.AsyncClient() as client: response = await client.get( f"{self.aegis_base_url}/api/v1/rules/scap-mapping", - params={ - "scap_rule_id": scap_rule_id, - "distribution": distribution_key - }, - timeout=5.0 + params={"scap_rule_id": scap_rule_id, "distribution": distribution_key}, + timeout=5.0, ) - + if response.status_code == 200: mapping_data = response.json() - + if mapping_data.get("semantic_rule"): rule_data = mapping_data["semantic_rule"] - + return SemanticRule( name=rule_data["name"], scap_rule_id=scap_rule_id, @@ -425,103 +407,100 @@ async def _query_aegis_for_semantic_mapping( business_impact=rule_data.get("business_impact", "medium"), risk_level=rule_data.get("severity", "medium"), frameworks=rule_data.get("frameworks", []), - remediation_complexity=rule_data.get("remediation_complexity", "simple"), + remediation_complexity=rule_data.get( + "remediation_complexity", "simple" + ), estimated_fix_time=rule_data.get("estimated_fix_time", 10), dependencies=rule_data.get("dependencies", []), cross_framework_mappings=rule_data.get("cross_framework_mappings", {}), - remediation_available=True + remediation_available=True, ) - + except Exception as e: logger.debug(f"Could not query AEGIS for semantic mapping: {e}") - + return None - + def _build_distribution_key(self, host_info: Dict) -> str: """Build distribution key for AEGIS queries""" dist_name = host_info.get("distribution_name", "") dist_version = host_info.get("distribution_version", "") - + if dist_name and dist_version: return f"{dist_name}{dist_version}" - + # Fallback to legacy OS version os_version = host_info.get("os_version", "") if "rhel" in os_version.lower() or "red hat" in os_version.lower(): - version = re.search(r'\d+', os_version) + version = re.search(r"\d+", os_version) if version: return f"rhel{version.group()}" - + return "rhel9" # Default fallback - + async def _map_to_universal_frameworks( - self, - semantic_rules: List[SemanticRule], - host_info: Dict + self, semantic_rules: List[SemanticRule], host_info: Dict ) -> Dict[str, List[SemanticRule]]: """Map semantic rules to universal compliance frameworks""" - + framework_mappings = {} - + # Query AEGIS for framework information try: async with httpx.AsyncClient() as client: - response = await client.get( - f"{self.aegis_base_url}/api/v1/frameworks", - timeout=5.0 - ) - + response = await client.get(f"{self.aegis_base_url}/api/v1/frameworks", timeout=5.0) + if response.status_code == 200: frameworks_data = response.json() - + for framework_info in frameworks_data: framework_name = framework_info["name"] applicable_rules = [] - + for rule in semantic_rules: if framework_name in rule.frameworks: applicable_rules.append(rule) - + if applicable_rules: framework_mappings[framework_name] = applicable_rules - + except Exception as e: logger.debug(f"Could not query AEGIS frameworks: {e}") - + # Fallback to basic framework mapping for rule in semantic_rules: for framework in rule.frameworks: if framework not in framework_mappings: framework_mappings[framework] = [] framework_mappings[framework].append(rule) - + return framework_mappings - + async def _analyze_compliance_matrix( self, semantic_rules: List[SemanticRule], original_scan_results: Dict, - framework_mappings: Dict[str, List[SemanticRule]] + framework_mappings: Dict[str, List[SemanticRule]], ) -> Dict[str, float]: """Analyze cross-framework compliance scores""" - + compliance_matrix = {} - + # Get total rules from original scan total_rules = original_scan_results.get("rules_total", 0) passed_rules = original_scan_results.get("rules_passed", 0) - + if total_rules == 0: return compliance_matrix - + # Calculate baseline compliance score baseline_score = (passed_rules / total_rules) * 100 - + for framework_name, framework_rules in framework_mappings.items(): # For now, use baseline score with slight variations # This will be enhanced with actual framework-specific analysis framework_failed_count = len(framework_rules) - + if framework_failed_count == 0: compliance_matrix[framework_name] = baseline_score else: @@ -529,43 +508,49 @@ async def _analyze_compliance_matrix( impact_factor = min(framework_failed_count * 2, 20) # Cap at 20% impact estimated_score = max(baseline_score - impact_factor, 0) compliance_matrix[framework_name] = round(estimated_score, 1) - + return compliance_matrix - + async def _create_intelligent_remediation_strategy( self, semantic_rules: List[SemanticRule], host_info: Dict, - compliance_matrix: Dict[str, float] + compliance_matrix: Dict[str, float], ) -> Dict[str, Any]: """Create intelligent remediation strategy""" - + if not semantic_rules: return {} - + # Categorize rules by impact and complexity high_impact_rules = [r for r in semantic_rules if r.business_impact == "high"] - quick_wins = [r for r in semantic_rules if r.remediation_complexity == "simple" and r.estimated_fix_time <= 10] + quick_wins = [ + r + for r in semantic_rules + if r.remediation_complexity == "simple" and r.estimated_fix_time <= 10 + ] complex_rules = [r for r in semantic_rules if r.remediation_complexity == "complex"] - + # Calculate total estimated time total_time = sum(rule.estimated_fix_time for rule in semantic_rules) - + # Determine priority order priority_rules = [] - + # 1. High impact, simple fixes first - priority_rules.extend([r for r in high_impact_rules if r.remediation_complexity == "simple"]) - + priority_rules.extend( + [r for r in high_impact_rules if r.remediation_complexity == "simple"] + ) + # 2. Quick wins priority_rules.extend([r for r in quick_wins if r not in priority_rules]) - + # 3. Remaining high impact rules priority_rules.extend([r for r in high_impact_rules if r not in priority_rules]) - + # 4. Everything else priority_rules.extend([r for r in semantic_rules if r not in priority_rules]) - + strategy = { "total_rules": len(semantic_rules), "estimated_total_time_minutes": total_time, @@ -574,84 +559,99 @@ async def _create_intelligent_remediation_strategy( "priority_order": [r.name for r in priority_rules], "complexity_breakdown": { "simple": len([r for r in semantic_rules if r.remediation_complexity == "simple"]), - "moderate": len([r for r in semantic_rules if r.remediation_complexity == "moderate"]), - "complex": len([r for r in semantic_rules if r.remediation_complexity == "complex"]) + "moderate": len( + [r for r in semantic_rules if r.remediation_complexity == "moderate"] + ), + "complex": len( + [r for r in semantic_rules if r.remediation_complexity == "complex"] + ), }, - "framework_impact_prediction": self._predict_framework_impact(semantic_rules, compliance_matrix), - "remediation_recommendations": self._generate_remediation_recommendations(semantic_rules) + "framework_impact_prediction": self._predict_framework_impact( + semantic_rules, compliance_matrix + ), + "remediation_recommendations": self._generate_remediation_recommendations( + semantic_rules + ), } - + return strategy - + def _predict_framework_impact( - self, - semantic_rules: List[SemanticRule], - current_compliance: Dict[str, float] + self, semantic_rules: List[SemanticRule], current_compliance: Dict[str, float] ) -> Dict[str, Dict[str, float]]: """Predict compliance improvement from fixing rules""" - + impact_prediction = {} - + for framework_name, current_score in current_compliance.items(): framework_rules = [r for r in semantic_rules if framework_name in r.frameworks] - + if framework_rules: # Estimate improvement (simplified calculation) potential_improvement = min(len(framework_rules) * 3, 25) # Cap at 25% predicted_score = min(current_score + potential_improvement, 100) - + impact_prediction[framework_name] = { "current_score": current_score, "predicted_score": predicted_score, "improvement": predicted_score - current_score, - "affected_rules": len(framework_rules) + "affected_rules": len(framework_rules), } - + return impact_prediction - - def _generate_remediation_recommendations(self, semantic_rules: List[SemanticRule]) -> List[str]: + + def _generate_remediation_recommendations( + self, semantic_rules: List[SemanticRule] + ) -> List[str]: """Generate human-readable remediation recommendations""" - + recommendations = [] - + high_impact_count = len([r for r in semantic_rules if r.business_impact == "high"]) quick_wins_count = len([r for r in semantic_rules if r.estimated_fix_time <= 10]) - + if high_impact_count > 0: - recommendations.append(f"Prioritize {high_impact_count} high-impact security rules first") - + recommendations.append( + f"Prioritize {high_impact_count} high-impact security rules first" + ) + if quick_wins_count > 0: - recommendations.append(f"Consider addressing {quick_wins_count} quick-win rules for immediate improvement") - + recommendations.append( + f"Consider addressing {quick_wins_count} quick-win rules for immediate improvement" + ) + total_time = sum(rule.estimated_fix_time for rule in semantic_rules) if total_time <= 30: recommendations.append("All issues can be resolved in under 30 minutes") elif total_time <= 60: recommendations.append("Estimated remediation time: 30-60 minutes") else: - recommendations.append(f"Estimated remediation time: {total_time} minutes - consider batching") - + recommendations.append( + f"Estimated remediation time: {total_time} minutes - consider batching" + ) + return recommendations - + async def _predict_compliance_trends( - self, - semantic_rules: List[SemanticRule], - scan_id: str, - host_id: Optional[str] + self, semantic_rules: List[SemanticRule], scan_id: str, host_id: Optional[str] ) -> Dict[str, Any]: """Predict compliance trends (simplified initial implementation)""" - + # For initial implementation, provide basic trend analysis trends = { "risk_level_distribution": { "high": len([r for r in semantic_rules if r.risk_level == "high"]), "medium": len([r for r in semantic_rules if r.risk_level == "medium"]), - "low": len([r for r in semantic_rules if r.risk_level == "low"]) + "low": len([r for r in semantic_rules if r.risk_level == "low"]), }, "remediation_complexity_trend": { "simple": len([r for r in semantic_rules if r.remediation_complexity == "simple"]), - "moderate": len([r for r in semantic_rules if r.remediation_complexity == "moderate"]), - "complex": len([r for r in semantic_rules if r.remediation_complexity == "complex"]) + "moderate": len( + [r for r in semantic_rules if r.remediation_complexity == "moderate"] + ), + "complex": len( + [r for r in semantic_rules if r.remediation_complexity == "complex"] + ), }, "framework_coverage": { framework: len([r for r in semantic_rules if framework in r.frameworks]) @@ -660,20 +660,22 @@ async def _predict_compliance_trends( "predictions": { "next_scan_recommendation": "Schedule follow-up scan after remediation", "compliance_drift_risk": "low" if len(semantic_rules) < 10 else "medium", - "maintenance_frequency": "monthly" if len(semantic_rules) < 5 else "bi-weekly" - } + "maintenance_frequency": "monthly" if len(semantic_rules) < 5 else "bi-weekly", + }, } - + return trends - + async def _store_semantic_analysis(self, result: IntelligentScanResult): """Store semantic analysis results for future reference""" - + try: db = next(get_db()) try: # Store in semantic_scan_analysis table - db.execute(text(""" + db.execute( + text( + """ INSERT INTO semantic_scan_analysis (scan_id, host_id, semantic_rules_count, frameworks_analyzed, remediation_available_count, processing_metadata, analysis_data, created_at) @@ -686,45 +688,57 @@ async def _store_semantic_analysis(self, result: IntelligentScanResult): processing_metadata = EXCLUDED.processing_metadata, analysis_data = EXCLUDED.analysis_data, updated_at = :created_at - """), { - "scan_id": result.scan_id, - "host_id": result.host_id, - "semantic_rules_count": len(result.semantic_rules), - "frameworks_analyzed": json.dumps(list(result.framework_compliance_matrix.keys())), - "remediation_available_count": result.processing_metadata.get("remediation_available_count", 0), - "processing_metadata": json.dumps(result.processing_metadata), - "analysis_data": json.dumps(result.to_dict()), - "created_at": datetime.utcnow() - }) + """ + ), + { + "scan_id": result.scan_id, + "host_id": result.host_id, + "semantic_rules_count": len(result.semantic_rules), + "frameworks_analyzed": json.dumps( + list(result.framework_compliance_matrix.keys()) + ), + "remediation_available_count": result.processing_metadata.get( + "remediation_available_count", 0 + ), + "processing_metadata": json.dumps(result.processing_metadata), + "analysis_data": json.dumps(result.to_dict()), + "created_at": datetime.utcnow(), + }, + ) db.commit() - + logger.debug(f"Stored semantic analysis for scan {result.scan_id}") - + finally: db.close() - + except Exception as e: logger.warning(f"Failed to store semantic analysis: {e}") - + async def get_semantic_analysis(self, scan_id: str) -> Optional[IntelligentScanResult]: """Retrieve stored semantic analysis""" - + try: db = next(get_db()) try: - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT analysis_data FROM semantic_scan_analysis WHERE scan_id = :scan_id - """), {"scan_id": scan_id}).fetchone() - + """ + ), + {"scan_id": scan_id}, + ).fetchone() + if result and result.analysis_data: data = json.loads(result.analysis_data) - + # Reconstruct SemanticRule objects semantic_rules = [ SemanticRule(**rule_data) for rule_data in data.get("semantic_rules", []) ] - + return IntelligentScanResult( scan_id=data["scan_id"], host_id=data["host_id"], @@ -733,24 +747,25 @@ async def get_semantic_analysis(self, scan_id: str) -> Optional[IntelligentScanR framework_compliance_matrix=data["framework_compliance_matrix"], remediation_strategy=data["remediation_strategy"], compliance_trends=data["compliance_trends"], - processing_metadata=data["processing_metadata"] + processing_metadata=data["processing_metadata"], ) - + finally: db.close() - + except Exception as e: logger.warning(f"Failed to retrieve semantic analysis: {e}") - + return None # Singleton instance _semantic_scap_engine = None + def get_semantic_scap_engine() -> SemanticSCAPEngine: """Get the global semantic SCAP engine instance""" global _semantic_scap_engine if _semantic_scap_engine is None: _semantic_scap_engine = SemanticSCAPEngine() - return _semantic_scap_engine \ No newline at end of file + return _semantic_scap_engine diff --git a/backend/app/services/session_migration_service.py b/backend/app/services/session_migration_service.py index 323da217..3c06a872 100644 --- a/backend/app/services/session_migration_service.py +++ b/backend/app/services/session_migration_service.py @@ -20,56 +20,52 @@ class SessionMigrationService: """Service for managing zero-downtime session migration""" - + def __init__(self): self.migration_window_hours = 24 # 24-hour overlap for smooth transition self.legacy_secret_key = None # Legacy HS256 secret if needed - + def set_legacy_secret_key(self, legacy_key: str): """Set legacy HS256 secret key for backward compatibility""" self.legacy_secret_key = legacy_key logger.info("Legacy secret key configured for session migration") - + def validate_legacy_token(self, token: str) -> Optional[Dict[str, Any]]: """ Validate legacy HS256 tokens during migration period - + Args: token: JWT token string - + Returns: Token payload if valid, None if invalid """ if not self.legacy_secret_key: return None - + try: # Try to decode with legacy HS256 algorithm payload = jwt.decode( token, self.legacy_secret_key, algorithms=["HS256"], - options={ - "verify_signature": True, - "verify_exp": True, - "verify_iat": True - } + options={"verify_signature": True, "verify_exp": True, "verify_iat": True}, ) - + # Check if token is within migration window - iat = payload.get('iat') + iat = payload.get("iat") if iat: token_age = datetime.utcnow().timestamp() - iat max_age = self.migration_window_hours * 3600 - + if token_age <= max_age: logger.info(f"Legacy token accepted for user: {payload.get('sub')}") return payload else: logger.warning(f"Legacy token expired for user: {payload.get('sub')}") - + return None - + except jwt.ExpiredSignatureError: logger.debug("Legacy token has expired") return None @@ -79,14 +75,14 @@ def validate_legacy_token(self, token: str) -> Optional[Dict[str, Any]]: except Exception as e: logger.error(f"Unexpected error validating legacy token: {e}") return None - + def migrate_user_session(self, legacy_payload: Dict[str, Any]) -> Dict[str, str]: """ Migrate legacy session to new RS256 tokens - + Args: legacy_payload: Validated legacy token payload - + Returns: Dictionary with new access and refresh tokens """ @@ -98,41 +94,41 @@ def migrate_user_session(self, legacy_payload: Dict[str, Any]) -> Dict[str, str] "username": legacy_payload.get("username"), "email": legacy_payload.get("email"), "role": legacy_payload.get("role"), - "mfa_enabled": legacy_payload.get("mfa_enabled", False) + "mfa_enabled": legacy_payload.get("mfa_enabled", False), } - + # Generate new RS256 tokens new_access_token = jwt_manager.create_access_token(user_data) new_refresh_token = jwt_manager.create_refresh_token(user_data) - + logger.info(f"Session migrated for user: {user_data['username']}") - + return { "access_token": new_access_token, "refresh_token": new_refresh_token, "token_type": "bearer", "expires_in": settings.access_token_expire_minutes * 60, - "migrated": True + "migrated": True, } - + except Exception as e: logger.error(f"Failed to migrate session: {e}") raise - + def validate_token_with_migration(self, token: str) -> Dict[str, Any]: """ Validate token with automatic migration support - + Args: token: JWT token string - + Returns: Token payload (migrated if necessary) """ try: # First try with current RS256 validation return jwt_manager.validate_access_token(token) - + except jwt.InvalidTokenError: # If RS256 fails, try legacy HS256 validation legacy_payload = self.validate_legacy_token(token) @@ -140,55 +136,68 @@ def validate_token_with_migration(self, token: str) -> Dict[str, Any]: # Mark payload as requiring migration legacy_payload["_requires_migration"] = True return legacy_payload - + # If both fail, re-raise the original exception raise - + def get_migration_statistics(self) -> Dict[str, Any]: """Get session migration statistics""" return { "migration_window_hours": self.migration_window_hours, "legacy_secret_configured": bool(self.legacy_secret_key), "rs256_active": True, - "migration_status": "active" if self.legacy_secret_key else "completed" + "migration_status": "active" if self.legacy_secret_key else "completed", } - + def check_session_compatibility(self, db: Session) -> Dict[str, Any]: """ Check database for session compatibility requirements - + Args: db: Database session - + Returns: Compatibility status and recommendations """ try: # Check user table schema for MFA fields - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT column_name FROM information_schema.columns WHERE table_name = 'users' AND column_name IN ('mfa_enabled', 'mfa_secret') - """)) - + """ + ) + ) + mfa_columns = [row[0] for row in result.fetchall()] - + # Check for legacy password hashes - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT COUNT(*) FROM users WHERE hashed_password LIKE '$2b$%' -- bcrypt format - """)) - + """ + ) + ) + legacy_password_count = result.scalar() - + # Check for active sessions (rough estimate) recent_login_threshold = datetime.utcnow() - timedelta(hours=24) - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT COUNT(*) FROM users WHERE last_login > :threshold - """), {"threshold": recent_login_threshold}) - + """ + ), + {"threshold": recent_login_threshold}, + ) + active_session_estimate = result.scalar() - + return { "mfa_schema_ready": len(mfa_columns) >= 2, "legacy_passwords_remaining": legacy_password_count, @@ -196,32 +205,32 @@ def check_session_compatibility(self, db: Session) -> Dict[str, Any]: "migration_recommendations": [ "Enable legacy token support during peak hours", "Monitor authentication failures for migration issues", - "Gradually phase out legacy support over 24-48 hours" - ] + "Gradually phase out legacy support over 24-48 hours", + ], } - + except Exception as e: logger.error(f"Failed to check session compatibility: {e}") return { "error": str(e), "migration_recommendations": [ "Manual verification of database schema required", - "Test authentication endpoints before full deployment" - ] + "Test authentication endpoints before full deployment", + ], } - + def create_migration_plan(self, db: Session) -> Dict[str, Any]: """ Create a comprehensive migration plan - + Args: db: Database session - + Returns: Migration plan with phases and timelines """ compatibility = self.check_session_compatibility(db) - + plan = { "migration_phases": [ { @@ -232,8 +241,8 @@ def create_migration_plan(self, db: Session) -> Dict[str, Any]: "Verify database schema compatibility", "Configure legacy secret key", "Enable dual-token validation", - "Set up enhanced monitoring" - ] + "Set up enhanced monitoring", + ], }, { "phase": 2, @@ -243,8 +252,8 @@ def create_migration_plan(self, db: Session) -> Dict[str, Any]: "Deploy RS256 token generation", "Maintain HS256 validation compatibility", "Monitor authentication success rates", - "Begin gradual token migration" - ] + "Begin gradual token migration", + ], }, { "phase": 3, @@ -254,8 +263,8 @@ def create_migration_plan(self, db: Session) -> Dict[str, Any]: "Automatic token refresh with RS256", "Legacy token acceptance window", "Monitor for authentication issues", - "User communication if needed" - ] + "User communication if needed", + ], }, { "phase": 4, @@ -265,9 +274,9 @@ def create_migration_plan(self, db: Session) -> Dict[str, Any]: "Disable legacy token validation", "Remove legacy secret configuration", "Verify all users migrated", - "Final monitoring and validation" - ] - } + "Final monitoring and validation", + ], + }, ], "risk_assessment": { "overall_risk": "LOW", @@ -275,26 +284,27 @@ def create_migration_plan(self, db: Session) -> Dict[str, Any]: "rollback_plan": "Immediate revert to HS256 if issues detected", "monitoring_points": [ "Authentication failure rates", - "Token validation performance", + "Token validation performance", "User session duration", - "Error log patterns" - ] + "Error log patterns", + ], }, "success_criteria": [ "Zero forced user logouts", "Authentication failure rate < 1%", "Token validation performance within SLA", - "Complete migration within 48 hours" + "Complete migration within 48 hours", ], - "compatibility_check": compatibility + "compatibility_check": compatibility, } - + return plan # Global session migration service instance _session_migration_service = None + def get_session_migration_service() -> SessionMigrationService: """Get global session migration service instance""" global _session_migration_service @@ -307,10 +317,10 @@ def get_session_migration_service() -> SessionMigrationService: def validate_token_with_migration_support(token: str) -> Dict[str, Any]: """ Enhanced token validation with migration support - + Args: token: JWT token string - + Returns: Token payload with migration information """ @@ -321,4 +331,4 @@ def validate_token_with_migration_support(token: str) -> Dict[str, Any]: def create_migration_plan_for_deployment(db: Session) -> Dict[str, Any]: """Create migration plan for current deployment""" migration_service = get_session_migration_service() - return migration_service.create_migration_plan(db) \ No newline at end of file + return migration_service.create_migration_plan(db) diff --git a/backend/app/services/ssh_key_service.py b/backend/app/services/ssh_key_service.py index 091d5e7e..9e77dea1 100644 --- a/backend/app/services/ssh_key_service.py +++ b/backend/app/services/ssh_key_service.py @@ -13,14 +13,16 @@ logger = logging.getLogger(__name__) -def extract_ssh_key_metadata(key_content: str, passphrase: Optional[str] = None) -> Dict[str, Optional[str]]: +def extract_ssh_key_metadata( + key_content: str, passphrase: Optional[str] = None +) -> Dict[str, Optional[str]]: """ Extract SSH key metadata for storage and display. - + Args: key_content: SSH private key content as string passphrase: Optional passphrase for encrypted keys - + Returns: Dictionary containing: - fingerprint: SHA256 fingerprint (format: SHA256:base64hash) @@ -32,145 +34,150 @@ def extract_ssh_key_metadata(key_content: str, passphrase: Optional[str] = None) try: # Validate and get comprehensive key information result = validate_ssh_key(key_content, passphrase) - + if not result.is_valid: return { - 'fingerprint': None, - 'key_type': None, - 'key_bits': None, - 'key_comment': None, - 'error': result.error_message or 'Invalid SSH key' + "fingerprint": None, + "key_type": None, + "key_bits": None, + "key_comment": None, + "error": result.error_message or "Invalid SSH key", } - + # Get fingerprint fingerprint_hex = get_key_fingerprint(key_content, passphrase) - + # Format fingerprint as SHA256:base64 (like GitHub/OpenSSH format) fingerprint = None if fingerprint_hex: # Convert hex to base64 for cleaner display import base64 + hex_bytes = bytes.fromhex(fingerprint_hex) - b64_fingerprint = base64.b64encode(hex_bytes).decode('ascii') + b64_fingerprint = base64.b64encode(hex_bytes).decode("ascii") fingerprint = f"SHA256:{b64_fingerprint}" - + # Extract comment from key content if present key_comment = extract_key_comment(key_content) - + return { - 'fingerprint': fingerprint, - 'key_type': result.key_type.value if result.key_type else None, - 'key_bits': str(result.key_size) if result.key_size else None, - 'key_comment': key_comment, - 'error': None + "fingerprint": fingerprint, + "key_type": result.key_type.value if result.key_type else None, + "key_bits": str(result.key_size) if result.key_size else None, + "key_comment": key_comment, + "error": None, } - + except Exception as e: logger.error(f"Failed to extract SSH key metadata: {str(e)}") return { - 'fingerprint': None, - 'key_type': None, - 'key_bits': None, - 'key_comment': None, - 'error': str(e) + "fingerprint": None, + "key_type": None, + "key_bits": None, + "key_comment": None, + "error": str(e), } def extract_key_comment(key_content: str) -> Optional[str]: """ Extract comment/label from SSH key content. - + Args: key_content: SSH key content (private or public) - + Returns: Key comment if found, None otherwise """ try: # Look for public key format in private key content # Some private key files include the public key with comment - lines = key_content.split('\n') + lines = key_content.split("\n") for line in lines: - if line.strip().startswith(('ssh-rsa', 'ssh-ed25519', 'ecdsa-sha2-', 'ssh-dss')): + if line.strip().startswith(("ssh-rsa", "ssh-ed25519", "ecdsa-sha2-", "ssh-dss")): parts = line.strip().split() if len(parts) >= 3: # Third part is usually the comment return parts[2] - + # Look for comment patterns in OpenSSH format comment_patterns = [ r'Comment:\s*"([^"]+)"', # Comment: "description" - r'Comment:\s*([^\s]+)', # Comment: description + r"Comment:\s*([^\s]+)", # Comment: description ] - + for pattern in comment_patterns: match = re.search(pattern, key_content, re.IGNORECASE) if match: return match.group(1) - + except Exception: pass - + return None -def format_key_display_info(fingerprint: Optional[str], key_type: Optional[str], - key_bits: Optional[str], key_comment: Optional[str], - created_date: Optional[str] = None) -> str: +def format_key_display_info( + fingerprint: Optional[str], + key_type: Optional[str], + key_bits: Optional[str], + key_comment: Optional[str], + created_date: Optional[str] = None, +) -> str: """ Format SSH key information for display (similar to GitHub format). - + Args: fingerprint: SHA256 fingerprint key_type: Key type (rsa, ed25519, ecdsa, dsa) key_bits: Key size in bits key_comment: Key comment/label created_date: When the key was added (optional) - + Returns: Formatted string for display """ if not fingerprint: return "No SSH key configured" - + # Build display string parts = [] - + # Add fingerprint (truncated for display) if len(fingerprint) > 20: short_fingerprint = fingerprint[:12] + "..." + fingerprint[-8:] else: short_fingerprint = fingerprint parts.append(short_fingerprint) - + # Add key type and size if key_type: type_display = key_type.upper() if key_bits: type_display += f" {key_bits}-bit" parts.append(f"({type_display})") - + result = " ".join(parts) - + # Add creation date if provided if created_date: result += f"\nAdded on {created_date}" - + # Add comment if available if key_comment: result += f"\nComment: {key_comment}" - + return result def get_key_security_indicator(key_type: Optional[str], key_bits: Optional[str]) -> Tuple[str, str]: """ Get security level indicator for SSH key. - + Args: key_type: Key type (rsa, ed25519, ecdsa, dsa) key_bits: Key size in bits - + Returns: Tuple of (security_level, color) for UI display - security_level: "secure", "acceptable", "deprecated", "rejected" @@ -178,10 +185,10 @@ def get_key_security_indicator(key_type: Optional[str], key_bits: Optional[str]) """ if not key_type: return ("unknown", "info") - + key_type = key_type.lower() bits = int(key_bits) if key_bits and key_bits.isdigit() else 0 - + if key_type == "ed25519": return ("secure", "success") elif key_type == "rsa": @@ -198,5 +205,5 @@ def get_key_security_indicator(key_type: Optional[str], key_bits: Optional[str]) return ("acceptable", "warning") elif key_type == "dsa": return ("deprecated", "error") - - return ("unknown", "info") \ No newline at end of file + + return ("unknown", "info") diff --git a/backend/app/services/ssh_utils.py b/backend/app/services/ssh_utils.py index 68740aaa..fc6ffbbe 100644 --- a/backend/app/services/ssh_utils.py +++ b/backend/app/services/ssh_utils.py @@ -20,6 +20,7 @@ class SSHKeyType(Enum): """Supported SSH key types""" + RSA = "rsa" ED25519 = "ed25519" ECDSA = "ecdsa" @@ -28,6 +29,7 @@ class SSHKeyType(Enum): class SSHKeySecurityLevel(Enum): """Security assessment levels for SSH keys""" + SECURE = "secure" ACCEPTABLE = "acceptable" DEPRECATED = "deprecated" @@ -36,7 +38,7 @@ class SSHKeySecurityLevel(Enum): class SSHKeyValidationResult: """Result of SSH key validation""" - + def __init__( self, is_valid: bool, @@ -45,7 +47,7 @@ def __init__( key_size: Optional[int] = None, error_message: Optional[str] = None, warnings: Optional[list] = None, - recommendations: Optional[list] = None + recommendations: Optional[list] = None, ): self.is_valid = is_valid self.key_type = key_type @@ -58,40 +60,45 @@ def __init__( class SSHKeyError(Exception): """Custom exception for SSH key operations""" + pass def detect_key_type(key_content: str) -> Optional[SSHKeyType]: """ Detect SSH key type based on PEM headers or content analysis. - + Args: key_content: SSH private key content as string - + Returns: SSHKeyType if detected, None if unrecognized """ # Handle both string and bytes/memoryview input (for database compatibility) if isinstance(key_content, (bytes, memoryview)): try: - key_content = key_content.decode('utf-8') if isinstance(key_content, bytes) else key_content.tobytes().decode('utf-8') + key_content = ( + key_content.decode("utf-8") + if isinstance(key_content, bytes) + else key_content.tobytes().decode("utf-8") + ) except (UnicodeDecodeError, AttributeError): return None - + if not isinstance(key_content, str): return None - + key_content = key_content.strip() - + # RSA key patterns if "BEGIN RSA PRIVATE KEY" in key_content: return SSHKeyType.RSA - + # Modern RSA keys in PKCS#8 format if "BEGIN PRIVATE KEY" in key_content: try: # Try to parse as PKCS#8 and determine algorithm - key_bytes = key_content.encode('utf-8') + key_bytes = key_content.encode("utf-8") private_key = serialization.load_pem_private_key(key_bytes, password=None) if isinstance(private_key, rsa.RSAPrivateKey): return SSHKeyType.RSA @@ -103,114 +110,115 @@ def detect_key_type(key_content: str) -> Optional[SSHKeyType]: return SSHKeyType.DSA except Exception: pass - + # OpenSSH format keys (can be RSA, Ed25519, ECDSA, etc.) if "BEGIN OPENSSH PRIVATE KEY" in key_content: # Try to detect key type by attempting to parse with each key class - key_bytes = key_content.encode('utf-8') - + key_bytes = key_content.encode("utf-8") + # Check for Ed25519 identifier first (most specific) if "ssh-ed25519" in key_content or "Ed25519" in key_content: return SSHKeyType.ED25519 - + # Try parsing as each type to detect the actual key type from io import StringIO + key_io = StringIO(key_content) - + try: key_io.seek(0) Ed25519Key.from_private_key(key_io, password=None) return SSHKeyType.ED25519 except Exception: pass - + try: key_io.seek(0) RSAKey.from_private_key(key_io, password=None) return SSHKeyType.RSA except Exception: pass - + try: key_io.seek(0) ECDSAKey.from_private_key(key_io, password=None) return SSHKeyType.ECDSA except Exception: pass - + try: key_io.seek(0) DSSKey.from_private_key(key_io, password=None) return SSHKeyType.DSA except Exception: pass - + # ECDSA keys if "BEGIN EC PRIVATE KEY" in key_content: return SSHKeyType.ECDSA - - # DSA keys + + # DSA keys if "BEGIN DSA PRIVATE KEY" in key_content: return SSHKeyType.DSA - + return None def parse_ssh_key(key_content: str, passphrase: Optional[str] = None) -> paramiko.PKey: """ Parse SSH private key content using appropriate Paramiko key class. - + Args: key_content: SSH private key content as string passphrase: Optional passphrase for encrypted keys - + Returns: Paramiko PKey object - + Raises: SSHKeyError: If key cannot be parsed """ # Handle both string and bytes/memoryview input (for database compatibility) if isinstance(key_content, (bytes, memoryview)): try: - key_content = key_content.decode('utf-8') if isinstance(key_content, bytes) else key_content.tobytes().decode('utf-8') + key_content = ( + key_content.decode("utf-8") + if isinstance(key_content, bytes) + else key_content.tobytes().decode("utf-8") + ) except (UnicodeDecodeError, AttributeError): raise SSHKeyError("Invalid key format - could not decode key content") - + if not isinstance(key_content, str): raise SSHKeyError("Invalid key format - key content must be string") - + key_content = key_content.strip() - + # Try each key type in order of preference - key_classes = [ - (Ed25519Key, "Ed25519"), - (ECDSAKey, "ECDSA"), - (RSAKey, "RSA"), - (DSSKey, "DSA") - ] - + key_classes = [(Ed25519Key, "Ed25519"), (ECDSAKey, "ECDSA"), (RSAKey, "RSA"), (DSSKey, "DSA")] + last_error = None for key_class, key_name in key_classes: try: from io import StringIO + key_io = StringIO(key_content) return key_class.from_private_key(key_io, password=passphrase) except Exception as e: last_error = e logger.debug(f"Failed to parse as {key_name} key: {e}") continue - + raise SSHKeyError(f"Unable to parse SSH key: {last_error}") def get_key_size(pkey: paramiko.PKey) -> Optional[int]: """ Get the size/length of an SSH key. - + Args: pkey: Paramiko PKey object - + Returns: Key size in bits, or None if cannot determine """ @@ -226,70 +234,100 @@ def get_key_size(pkey: paramiko.PKey) -> Optional[int]: return pkey.get_bits() except Exception: pass - + return None -def assess_key_security(key_type: SSHKeyType, key_size: Optional[int]) -> Tuple[SSHKeySecurityLevel, list, list]: +def assess_key_security( + key_type: SSHKeyType, key_size: Optional[int] +) -> Tuple[SSHKeySecurityLevel, list, list]: """ Assess security level of SSH key based on type and size. - + Args: key_type: Type of SSH key key_size: Key size in bits - + Returns: Tuple of (security_level, warnings, recommendations) """ warnings = [] recommendations = [] - + if key_type == SSHKeyType.ED25519: return SSHKeySecurityLevel.SECURE, warnings, recommendations - + elif key_type == SSHKeyType.RSA: if key_size is None: warnings.append("Unable to determine RSA key size") - return SSHKeySecurityLevel.ACCEPTABLE, warnings, ["Verify key size meets security requirements"] - + return ( + SSHKeySecurityLevel.ACCEPTABLE, + warnings, + ["Verify key size meets security requirements"], + ) + if key_size < 2048: - return SSHKeySecurityLevel.REJECTED, ["RSA key size too small"], ["Use RSA-4096 or Ed25519 keys"] + return ( + SSHKeySecurityLevel.REJECTED, + ["RSA key size too small"], + ["Use RSA-4096 or Ed25519 keys"], + ) elif key_size < 3072: - return SSHKeySecurityLevel.DEPRECATED, ["RSA-2048 keys are deprecated"], ["Upgrade to RSA-4096 or Ed25519"] + return ( + SSHKeySecurityLevel.DEPRECATED, + ["RSA-2048 keys are deprecated"], + ["Upgrade to RSA-4096 or Ed25519"], + ) elif key_size < 4096: warnings.append("RSA-3072 keys are acceptable but RSA-4096 is recommended") - return SSHKeySecurityLevel.ACCEPTABLE, warnings, ["Consider upgrading to RSA-4096 or Ed25519"] + return ( + SSHKeySecurityLevel.ACCEPTABLE, + warnings, + ["Consider upgrading to RSA-4096 or Ed25519"], + ) else: return SSHKeySecurityLevel.SECURE, warnings, recommendations - + elif key_type == SSHKeyType.ECDSA: if key_size and key_size >= 256: return SSHKeySecurityLevel.SECURE, warnings, recommendations else: warnings.append("ECDSA key size may be insufficient") - return SSHKeySecurityLevel.ACCEPTABLE, warnings, ["Verify ECDSA curve meets requirements"] - + return ( + SSHKeySecurityLevel.ACCEPTABLE, + warnings, + ["Verify ECDSA curve meets requirements"], + ) + elif key_type == SSHKeyType.DSA: - return SSHKeySecurityLevel.REJECTED, ["DSA keys are deprecated and insecure"], ["Use Ed25519 or RSA-4096 keys"] - + return ( + SSHKeySecurityLevel.REJECTED, + ["DSA keys are deprecated and insecure"], + ["Use Ed25519 or RSA-4096 keys"], + ) + return SSHKeySecurityLevel.ACCEPTABLE, ["Unknown key type security assessment"], [] def validate_ssh_key(key_content: str, passphrase: Optional[str] = None) -> SSHKeyValidationResult: """ Validate SSH private key with comprehensive security assessment. - + Args: key_content: SSH private key content as string passphrase: Optional passphrase for encrypted keys - + Returns: SSHKeyValidationResult with validation details """ # Handle both string and bytes/memoryview input (for database compatibility) if isinstance(key_content, (bytes, memoryview)): try: - key_content = key_content.decode('utf-8') if isinstance(key_content, bytes) else key_content.tobytes().decode('utf-8') + key_content = ( + key_content.decode("utf-8") + if isinstance(key_content, bytes) + else key_content.tobytes().decode("utf-8") + ) except (UnicodeDecodeError, AttributeError): return SSHKeyValidationResult( is_valid=False, @@ -298,9 +336,9 @@ def validate_ssh_key(key_content: str, passphrase: Optional[str] = None) -> SSHK key_size=None, security_level=None, warnings=[], - recommendations=[] + recommendations=[], ) - + if not isinstance(key_content, str): return SSHKeyValidationResult( is_valid=False, @@ -309,47 +347,38 @@ def validate_ssh_key(key_content: str, passphrase: Optional[str] = None) -> SSHK key_size=None, security_level=None, warnings=[], - recommendations=[] + recommendations=[], ) if not key_content or not key_content.strip(): - return SSHKeyValidationResult( - is_valid=False, - error_message="SSH key content is empty" - ) - + return SSHKeyValidationResult(is_valid=False, error_message="SSH key content is empty") + # Detect key type key_type = detect_key_type(key_content) if not key_type: return SSHKeyValidationResult( is_valid=False, - error_message="Unable to detect SSH key type. Supported types: RSA, Ed25519, ECDSA, DSA" + error_message="Unable to detect SSH key type. Supported types: RSA, Ed25519, ECDSA, DSA", ) - + # Parse the key try: pkey = parse_ssh_key(key_content, passphrase) except SSHKeyError as e: - return SSHKeyValidationResult( - is_valid=False, - key_type=key_type, - error_message=str(e) - ) + return SSHKeyValidationResult(is_valid=False, key_type=key_type, error_message=str(e)) except Exception as e: return SSHKeyValidationResult( - is_valid=False, - key_type=key_type, - error_message=f"Failed to parse SSH key: {e}" + is_valid=False, key_type=key_type, error_message=f"Failed to parse SSH key: {e}" ) - + # Get key size key_size = get_key_size(pkey) - + # Assess security security_level, warnings, recommendations = assess_key_security(key_type, key_size) - + # Check if key should be rejected is_valid = security_level != SSHKeySecurityLevel.REJECTED - + return SSHKeyValidationResult( is_valid=is_valid, key_type=key_type, @@ -357,28 +386,32 @@ def validate_ssh_key(key_content: str, passphrase: Optional[str] = None) -> SSHK key_size=key_size, error_message=None if is_valid else "SSH key rejected due to security policy", warnings=warnings, - recommendations=recommendations + recommendations=recommendations, ) def get_key_fingerprint(key_content: str, passphrase: Optional[str] = None) -> Optional[str]: """ Get SSH key fingerprint for identification. - + Args: key_content: SSH private key content passphrase: Optional passphrase for encrypted keys - + Returns: SHA256 fingerprint string or None if unable to generate """ # Handle both string and bytes/memoryview input (for database compatibility) if isinstance(key_content, (bytes, memoryview)): try: - key_content = key_content.decode('utf-8') if isinstance(key_content, bytes) else key_content.tobytes().decode('utf-8') + key_content = ( + key_content.decode("utf-8") + if isinstance(key_content, bytes) + else key_content.tobytes().decode("utf-8") + ) except (UnicodeDecodeError, AttributeError): return None - + if not isinstance(key_content, str): return None try: @@ -391,36 +424,36 @@ def get_key_fingerprint(key_content: str, passphrase: Optional[str] = None) -> O def format_validation_message(result: SSHKeyValidationResult) -> str: """ Format validation result into a user-friendly message. - + Args: result: SSH key validation result - + Returns: Formatted message string """ if not result.is_valid: return f"Invalid SSH key: {result.error_message}" - + message = f"Valid {result.key_type.value.upper()} key" if result.key_size: message += f" ({result.key_size} bits)" - + if result.security_level: message += f" - Security: {result.security_level.value}" - + if result.warnings: message += f"\nWarnings: {'; '.join(result.warnings)}" - + if result.recommendations: message += f"\nRecommendations: {'; '.join(result.recommendations)}" - + return message def recommend_key_type() -> str: """ Return current best practice SSH key recommendation. - + Returns: Recommendation text """ @@ -431,4 +464,4 @@ def recommend_key_type() -> str: "3. ECDSA P-256 or higher (good security, smaller than RSA)\n" "4. RSA-3072 (minimum acceptable for new keys)\n\n" "Avoid: DSA keys (deprecated), RSA keys < 3072 bits" - ) \ No newline at end of file + ) diff --git a/backend/app/services/system_info_sanitization.py b/backend/app/services/system_info_sanitization.py index d1c2e8c1..e7159c51 100644 --- a/backend/app/services/system_info_sanitization.py +++ b/backend/app/services/system_info_sanitization.py @@ -17,9 +17,16 @@ from enum import Enum from ..models.system_models import ( - SystemInfoLevel, ComplianceSystemInfo, OperationalSystemInfo, AdminSystemInfo, - SystemInfoSanitizationContext, SystemInfoFilter, SystemInfoMetadata, - SanitizedSystemValidation, SystemReconnaissancePattern, SystemInfoAuditEvent + SystemInfoLevel, + ComplianceSystemInfo, + OperationalSystemInfo, + AdminSystemInfo, + SystemInfoSanitizationContext, + SystemInfoFilter, + SystemInfoMetadata, + SanitizedSystemValidation, + SystemReconnaissancePattern, + SystemInfoAuditEvent, ) from .error_sanitization import get_error_sanitization_service @@ -28,7 +35,8 @@ class ReconnaissanceDetectionLevel(str, Enum): """Levels of reconnaissance detection sensitivity""" - STRICT = "strict" # Block all technical details + + STRICT = "strict" # Block all technical details MODERATE = "moderate" # Allow some operational info PERMISSIVE = "permissive" # Allow more technical details @@ -36,7 +44,7 @@ class ReconnaissanceDetectionLevel(str, Enum): class SystemInfoSanitizationService: """ Service to sanitize system information and prevent reconnaissance attacks. - + Key Features: 1. System Information Filtering - Remove detailed OS/package information 2. Network Configuration Sanitization - Eliminate internal topology details @@ -44,125 +52,123 @@ class SystemInfoSanitizationService: 4. Compliance Information Protection - Safe exposure of only necessary data 5. Integration with Error Sanitization - Build on existing infrastructure """ - + # Reconnaissance patterns that indicate system fingerprinting attempts RECONNAISSANCE_PATTERNS = [ SystemReconnaissancePattern( pattern_id="os_version_detailed", description="Detailed OS version information", regex_pattern=r'VERSION_ID\s*=\s*["\'][^"\']+["\']', - severity="high" + severity="high", ), SystemReconnaissancePattern( pattern_id="kernel_version_full", description="Full kernel version with build info", - regex_pattern=r'Linux\s+[\w\-\.]+\s+[\d\.\-\w]+\s+#\d+', - severity="high" + regex_pattern=r"Linux\s+[\w\-\.]+\s+[\d\.\-\w]+\s+#\d+", + severity="high", ), SystemReconnaissancePattern( pattern_id="package_enumeration", description="Package version enumeration", - regex_pattern=r'(rpm|dpkg|yum|apt)\s+(list|query|show)', - severity="medium" + regex_pattern=r"(rpm|dpkg|yum|apt)\s+(list|query|show)", + severity="medium", ), SystemReconnaissancePattern( pattern_id="network_interfaces", description="Network interface configuration", - regex_pattern=r'(eth\d+|wlan\d+|enp\d+s\d+):\s+.*inet\s+[\d\.]+', - severity="high" + regex_pattern=r"(eth\d+|wlan\d+|enp\d+s\d+):\s+.*inet\s+[\d\.]+", + severity="high", ), SystemReconnaissancePattern( pattern_id="running_services", description="Running service enumeration", - regex_pattern=r'systemctl\s+(status|list-units|show)', - severity="medium" + regex_pattern=r"systemctl\s+(status|list-units|show)", + severity="medium", ), SystemReconnaissancePattern( pattern_id="system_architecture", description="Detailed system architecture info", - regex_pattern=r'Architecture:\s+x86_64|aarch64|armv7l', - severity="low" + regex_pattern=r"Architecture:\s+x86_64|aarch64|armv7l", + severity="low", ), SystemReconnaissancePattern( pattern_id="hostname_disclosure", description="Internal hostname disclosure", - regex_pattern=r'hostname:\s+[\w\-\.]+\.internal|\.local|\.corp', - severity="medium" - ) + regex_pattern=r"hostname:\s+[\w\-\.]+\.internal|\.local|\.corp", + severity="medium", + ), ] - + # Safe system information patterns (allowed for compliance) COMPLIANCE_SAFE_PATTERNS = [ - r'Linux', # Generic OS family - r'Windows', # Generic OS family - r'Unix', # Generic OS family - r'compliance', # Compliance-related terms - r'security', # Security-related terms - r'available', # Resource availability (generic) - r'enabled', # Service status (generic) - r'disabled', # Service status (generic) + r"Linux", # Generic OS family + r"Windows", # Generic OS family + r"Unix", # Generic OS family + r"compliance", # Compliance-related terms + r"security", # Security-related terms + r"available", # Resource availability (generic) + r"enabled", # Service status (generic) + r"disabled", # Service status (generic) ] - + # System information fields to always sanitize SENSITIVE_SYSTEM_FIELDS = [ - 'system_details', # Full uname output - 'detailed_os_info', # /etc/os-release content - 'kernel_version', # Specific kernel version - 'installed_packages', # Package list - 'network_configuration', # Network topology - 'running_services', # Service enumeration - 'hostname', # Internal hostnames - 'ip_address', # Internal IP addresses - 'mac_address', # MAC addresses - 'cpu_info', # CPU model/version - 'memory_info', # Detailed memory info - 'disk_info', # Disk configuration - 'mount_points', # Filesystem mounts - 'environment_vars', # Environment variables - 'process_list', # Running processes - 'open_ports', # Network ports - 'firewall_rules', # Security configuration - 'users_list', # System users - 'groups_list', # System groups - 'cron_jobs', # Scheduled tasks - 'ssh_config', # SSH configuration - 'certificates', # SSL/TLS certificates - 'keys_info', # Cryptographic keys + "system_details", # Full uname output + "detailed_os_info", # /etc/os-release content + "kernel_version", # Specific kernel version + "installed_packages", # Package list + "network_configuration", # Network topology + "running_services", # Service enumeration + "hostname", # Internal hostnames + "ip_address", # Internal IP addresses + "mac_address", # MAC addresses + "cpu_info", # CPU model/version + "memory_info", # Detailed memory info + "disk_info", # Disk configuration + "mount_points", # Filesystem mounts + "environment_vars", # Environment variables + "process_list", # Running processes + "open_ports", # Network ports + "firewall_rules", # Security configuration + "users_list", # System users + "groups_list", # System groups + "cron_jobs", # Scheduled tasks + "ssh_config", # SSH configuration + "certificates", # SSL/TLS certificates + "keys_info", # Cryptographic keys ] - + def __init__(self): self.detection_level = ReconnaissanceDetectionLevel.STRICT self.audit_events: List[SystemInfoAuditEvent] = [] self._error_sanitization_service = get_error_sanitization_service() - + def sanitize_system_information( - self, - raw_system_info: Dict[str, Any], - context: SystemInfoSanitizationContext + self, raw_system_info: Dict[str, Any], context: SystemInfoSanitizationContext ) -> Tuple[Dict[str, Any], SystemInfoMetadata]: """ Main sanitization method - removes sensitive system information while preserving compliance-necessary data. - + Args: raw_system_info: Raw system information collected context: Sanitization context with user/access info - + Returns: Tuple of (sanitized_info, metadata) """ try: # Determine appropriate access level access_level = self._determine_access_level(context) - + # Create filter based on access level info_filter = self._create_system_filter(access_level) - + # Detect reconnaissance patterns reconnaissance_detected, triggered_patterns = self._detect_reconnaissance_patterns( raw_system_info ) - + # Apply sanitization based on access level if access_level == SystemInfoLevel.ADMIN and not reconnaissance_detected: sanitized_info = self._sanitize_for_admin(raw_system_info, info_filter) @@ -173,7 +179,7 @@ def sanitize_system_information( else: # Default to basic (most restrictive) sanitized_info = self._sanitize_for_basic(raw_system_info, info_filter) - + # Create metadata metadata = SystemInfoMetadata( collection_timestamp=datetime.utcnow(), @@ -181,87 +187,83 @@ def sanitize_system_information( sanitization_applied=True, sanitization_level=access_level, admin_access_used=(access_level == SystemInfoLevel.ADMIN), - reconnaissance_filtered=reconnaissance_detected + reconnaissance_filtered=reconnaissance_detected, ) - + # Audit the access self._audit_system_info_access( context, access_level, reconnaissance_detected, triggered_patterns ) - + # Log security event logger.info( f"System info sanitized: level={access_level.value}, " f"user={context.user_id}, reconnaissance={reconnaissance_detected}" ) - + return sanitized_info, metadata - + except Exception as e: logger.error(f"System information sanitization failed: {e}") # Return minimal safe info on error return self._create_minimal_safe_info(), self._create_error_metadata() - + def create_compliance_system_info( - self, - raw_info: Dict[str, Any], - context: SystemInfoSanitizationContext + self, raw_info: Dict[str, Any], context: SystemInfoSanitizationContext ) -> ComplianceSystemInfo: """Create compliance-safe system information object""" - + sanitized_info, metadata = self.sanitize_system_information(raw_info, context) - + # Extract safe OS family - os_family = self._extract_safe_os_family(sanitized_info.get('system_details', '')) - + os_family = self._extract_safe_os_family(sanitized_info.get("system_details", "")) + # Extract compliance-relevant information only compliance_info = { - 'scan_capability': sanitized_info.get('scan_capability', 'unknown'), - 'compliance_tools_available': sanitized_info.get('compliance_tools', False), - 'security_features_enabled': sanitized_info.get('security_features', {}), - 'last_validation': metadata.collection_timestamp.isoformat() + "scan_capability": sanitized_info.get("scan_capability", "unknown"), + "compliance_tools_available": sanitized_info.get("compliance_tools", False), + "security_features_enabled": sanitized_info.get("security_features", {}), + "last_validation": metadata.collection_timestamp.isoformat(), } - + return ComplianceSystemInfo( os_family=os_family, compliance_relevant_info=compliance_info, last_updated=metadata.collection_timestamp, - info_level=SystemInfoLevel.COMPLIANCE + info_level=SystemInfoLevel.COMPLIANCE, ) - + def create_sanitized_validation_result( self, raw_system_info: Dict[str, Any], can_proceed: bool, - context: SystemInfoSanitizationContext + context: SystemInfoSanitizationContext, ) -> SanitizedSystemValidation: """Create sanitized validation result for API responses""" - + compliance_info = self.create_compliance_system_info(raw_system_info, context) - + # Determine system compatibility based on safe criteria system_compatible = self._assess_system_compatibility(raw_system_info) - + metadata = SystemInfoMetadata( collection_timestamp=datetime.utcnow(), sanitization_applied=True, sanitization_level=context.access_level, admin_access_used=context.is_admin, - reconnaissance_filtered=True + reconnaissance_filtered=True, ) - + return SanitizedSystemValidation( can_proceed=can_proceed, system_compatible=system_compatible, compliance_info=compliance_info, validation_timestamp=datetime.utcnow(), - metadata=metadata + metadata=metadata, ) - + def integrate_with_error_sanitization( - self, - error_data: Dict[str, Any], - context: SystemInfoSanitizationContext + self, error_data: Dict[str, Any], context: SystemInfoSanitizationContext ) -> Dict[str, Any]: """ Integrate system information sanitization with existing error sanitization. @@ -270,64 +272,63 @@ def integrate_with_error_sanitization( try: # Create a copy to avoid modifying original sanitized_data = error_data.copy() - + # First apply system-specific sanitization if system_info exists - if 'system_info' in error_data: + if "system_info" in error_data: sanitized_system_info, metadata = self.sanitize_system_information( - error_data['system_info'], context + error_data["system_info"], context ) - sanitized_data['system_info'] = sanitized_system_info - + sanitized_data["system_info"] = sanitized_system_info + # Apply existing error sanitization patterns sanitized_error = self._error_sanitization_service.sanitize_error( - sanitized_data, - user_id=context.user_id, - source_ip=context.source_ip + sanitized_data, user_id=context.user_id, source_ip=context.source_ip ) - + # Convert to dict and ensure system_info is preserved result = sanitized_error.dict() - if 'system_info' in sanitized_data: - result['system_info'] = sanitized_data['system_info'] - + if "system_info" in sanitized_data: + result["system_info"] = sanitized_data["system_info"] + return result - + except Exception as e: logger.error(f"Integrated sanitization failed: {e}") # Fallback - create basic sanitized response with system_info if it existed fallback_result = self._error_sanitization_service.sanitize_error( - error_data, - user_id=context.user_id, - source_ip=context.source_ip + error_data, user_id=context.user_id, source_ip=context.source_ip ).dict() - + # Add minimal system info if it was in original - if 'system_info' in error_data: - fallback_result['system_info'] = {'sanitization_error': True, 'access_level': 'basic'} - + if "system_info" in error_data: + fallback_result["system_info"] = { + "sanitization_error": True, + "access_level": "basic", + } + return fallback_result - + def _determine_access_level(self, context: SystemInfoSanitizationContext) -> SystemInfoLevel: """Determine appropriate system information access level""" - + # Admin users get full access (if not reconnaissance) - if context.is_admin and context.user_role in ['SUPER_ADMIN', 'SECURITY_ADMIN']: + if context.is_admin and context.user_role in ["SUPER_ADMIN", "SECURITY_ADMIN"]: return SystemInfoLevel.ADMIN - + # Operational users get operational info - if context.user_role in ['SYSTEM_ADMIN', 'SCAN_OPERATOR']: + if context.user_role in ["SYSTEM_ADMIN", "SCAN_OPERATOR"]: return SystemInfoLevel.OPERATIONAL - + # Compliance users get compliance info - if context.compliance_only or context.user_role in ['COMPLIANCE_OFFICER']: + if context.compliance_only or context.user_role in ["COMPLIANCE_OFFICER"]: return SystemInfoLevel.COMPLIANCE - + # Default to basic (most restrictive) return SystemInfoLevel.BASIC - + def _create_system_filter(self, access_level: SystemInfoLevel) -> SystemInfoFilter: """Create system information filter based on access level""" - + if access_level == SystemInfoLevel.ADMIN: return SystemInfoFilter( allow_os_version=True, @@ -336,7 +337,7 @@ def _create_system_filter(self, access_level: SystemInfoLevel) -> SystemInfoFilt allow_network_config=True, allow_service_info=True, allow_detailed_errors=True, - sanitization_level=access_level + sanitization_level=access_level, ) elif access_level == SystemInfoLevel.OPERATIONAL: return SystemInfoFilter( @@ -346,7 +347,7 @@ def _create_system_filter(self, access_level: SystemInfoLevel) -> SystemInfoFilt allow_network_config=False, allow_service_info=True, # Service status only allow_detailed_errors=False, - sanitization_level=access_level + sanitization_level=access_level, ) elif access_level == SystemInfoLevel.COMPLIANCE: return SystemInfoFilter( @@ -356,210 +357,209 @@ def _create_system_filter(self, access_level: SystemInfoLevel) -> SystemInfoFilt allow_network_config=False, allow_service_info=False, allow_detailed_errors=False, - sanitization_level=access_level + sanitization_level=access_level, ) else: # Basic - most restrictive - return SystemInfoFilter( - sanitization_level=access_level - ) - + return SystemInfoFilter(sanitization_level=access_level) + def _detect_reconnaissance_patterns( - self, - system_info: Dict[str, Any] + self, system_info: Dict[str, Any] ) -> Tuple[bool, List[str]]: """Detect potential reconnaissance patterns in system information""" - + triggered_patterns = [] - + # Convert system info to searchable text system_text = json.dumps(system_info, default=str).lower() - + for pattern in self.RECONNAISSANCE_PATTERNS: if re.search(pattern.regex_pattern, system_text, re.IGNORECASE): triggered_patterns.append(pattern.pattern_id) logger.warning( f"Reconnaissance pattern detected: {pattern.pattern_id} - {pattern.description}" ) - + reconnaissance_detected = len(triggered_patterns) > 0 - + if reconnaissance_detected: # Log as a warning with security context (since security_warning doesn't exist) - security_logger = logging.getLogger('security_audit') + security_logger = logging.getLogger("security_audit") security_logger.warning( f"System reconnaissance detected: {len(triggered_patterns)} patterns triggered", - extra={'patterns': triggered_patterns, 'system_info_keys': list(system_info.keys())} + extra={ + "patterns": triggered_patterns, + "system_info_keys": list(system_info.keys()), + }, ) logger.warning(f"Reconnaissance patterns detected: {triggered_patterns}") - + return reconnaissance_detected, triggered_patterns - + def _sanitize_for_admin( - self, - raw_info: Dict[str, Any], - info_filter: SystemInfoFilter + self, raw_info: Dict[str, Any], info_filter: SystemInfoFilter ) -> Dict[str, Any]: """Sanitize system information for admin access (full details)""" - + # Admins get full access but with audit logging return { - 'system_details': raw_info.get('system_details', ''), - 'os_info': raw_info.get('os_info', {}), - 'kernel_info': raw_info.get('kernel_info', {}), - 'service_status': raw_info.get('service_status', {}), - 'resource_info': raw_info.get('resource_info', {}), - 'network_info': raw_info.get('network_info', {}), - 'security_info': raw_info.get('security_info', {}), - 'compliance_status': raw_info.get('compliance_status', {}), - 'access_level': 'admin' + "system_details": raw_info.get("system_details", ""), + "os_info": raw_info.get("os_info", {}), + "kernel_info": raw_info.get("kernel_info", {}), + "service_status": raw_info.get("service_status", {}), + "resource_info": raw_info.get("resource_info", {}), + "network_info": raw_info.get("network_info", {}), + "security_info": raw_info.get("security_info", {}), + "compliance_status": raw_info.get("compliance_status", {}), + "access_level": "admin", } - + def _sanitize_for_operational( - self, - raw_info: Dict[str, Any], - info_filter: SystemInfoFilter + self, raw_info: Dict[str, Any], info_filter: SystemInfoFilter ) -> Dict[str, Any]: """Sanitize system information for operational access""" - + return { - 'os_family': self._extract_safe_os_family(raw_info.get('system_details', '')), - 'service_status': self._sanitize_service_status(raw_info.get('service_status', {})), - 'resource_availability': self._sanitize_resource_info(raw_info.get('resource_info', {})), - 'compliance_status': raw_info.get('compliance_status', {}), - 'access_level': 'operational' + "os_family": self._extract_safe_os_family(raw_info.get("system_details", "")), + "service_status": self._sanitize_service_status(raw_info.get("service_status", {})), + "resource_availability": self._sanitize_resource_info( + raw_info.get("resource_info", {}) + ), + "compliance_status": raw_info.get("compliance_status", {}), + "access_level": "operational", } - + def _sanitize_for_compliance( - self, - raw_info: Dict[str, Any], - info_filter: SystemInfoFilter + self, raw_info: Dict[str, Any], info_filter: SystemInfoFilter ) -> Dict[str, Any]: """Sanitize system information for compliance access""" - + return { - 'os_family': self._extract_safe_os_family(raw_info.get('system_details', '')), - 'compliance_status': raw_info.get('compliance_status', {}), - 'scan_capability': self._assess_scan_capability(raw_info), - 'security_features': self._extract_safe_security_features(raw_info), - 'access_level': 'compliance' + "os_family": self._extract_safe_os_family(raw_info.get("system_details", "")), + "compliance_status": raw_info.get("compliance_status", {}), + "scan_capability": self._assess_scan_capability(raw_info), + "security_features": self._extract_safe_security_features(raw_info), + "access_level": "compliance", } - + def _sanitize_for_basic( - self, - raw_info: Dict[str, Any], - info_filter: SystemInfoFilter + self, raw_info: Dict[str, Any], info_filter: SystemInfoFilter ) -> Dict[str, Any]: """Sanitize system information for basic access (most restrictive)""" - + return { - 'system_compatible': self._assess_system_compatibility(raw_info), - 'scan_supported': True, # Generic capability indication - 'access_level': 'basic' + "system_compatible": self._assess_system_compatibility(raw_info), + "scan_supported": True, # Generic capability indication + "access_level": "basic", } - + def _extract_safe_os_family(self, system_details: str) -> str: """Extract safe, generic OS family information""" - + system_details_lower = system_details.lower() - - if 'linux' in system_details_lower: - return 'Linux' - elif 'windows' in system_details_lower: - return 'Windows' - elif 'darwin' in system_details_lower or 'macos' in system_details_lower: - return 'macOS' - elif 'freebsd' in system_details_lower or 'openbsd' in system_details_lower or 'netbsd' in system_details_lower or 'unix' in system_details_lower: - return 'Unix' + + if "linux" in system_details_lower: + return "Linux" + elif "windows" in system_details_lower: + return "Windows" + elif "darwin" in system_details_lower or "macos" in system_details_lower: + return "macOS" + elif ( + "freebsd" in system_details_lower + or "openbsd" in system_details_lower + or "netbsd" in system_details_lower + or "unix" in system_details_lower + ): + return "Unix" else: - return 'Unknown' - + return "Unknown" + def _sanitize_service_status(self, service_status: Dict[str, Any]) -> Dict[str, str]: """Sanitize service status information""" - + sanitized = {} for service, status in service_status.items(): # Only expose generic status, not detailed info if isinstance(status, str): - if 'active' in status.lower() or 'running' in status.lower(): - sanitized[service] = 'enabled' - elif 'inactive' in status.lower() or 'stopped' in status.lower(): - sanitized[service] = 'disabled' + if "active" in status.lower() or "running" in status.lower(): + sanitized[service] = "enabled" + elif "inactive" in status.lower() or "stopped" in status.lower(): + sanitized[service] = "disabled" else: - sanitized[service] = 'unknown' + sanitized[service] = "unknown" else: - sanitized[service] = 'unknown' - + sanitized[service] = "unknown" + return sanitized - + def _sanitize_resource_info(self, resource_info: Dict[str, Any]) -> Dict[str, str]: """Sanitize resource availability information""" - + sanitized = {} - + # Disk space - convert to availability categories - if 'disk_space' in resource_info: - disk_mb = resource_info.get('disk_space', 0) + if "disk_space" in resource_info: + disk_mb = resource_info.get("disk_space", 0) if disk_mb > 1000: - sanitized['disk_space'] = 'adequate' + sanitized["disk_space"] = "adequate" elif disk_mb > 500: - sanitized['disk_space'] = 'limited' + sanitized["disk_space"] = "limited" else: - sanitized['disk_space'] = 'insufficient' - + sanitized["disk_space"] = "insufficient" + # Memory - convert to availability categories - if 'memory' in resource_info: - memory_mb = resource_info.get('memory', 0) + if "memory" in resource_info: + memory_mb = resource_info.get("memory", 0) if memory_mb > 1024: - sanitized['memory'] = 'adequate' + sanitized["memory"] = "adequate" elif memory_mb > 512: - sanitized['memory'] = 'limited' + sanitized["memory"] = "limited" else: - sanitized['memory'] = 'insufficient' - + sanitized["memory"] = "insufficient" + return sanitized - + def _assess_scan_capability(self, raw_info: Dict[str, Any]) -> str: """Assess scanning capability without exposing technical details""" - + # Look for indicators of scan capability - system_details = raw_info.get('system_details', '').lower() - - if 'linux' in system_details: - return 'linux_compatible' - elif 'windows' in system_details: - return 'windows_compatible' + system_details = raw_info.get("system_details", "").lower() + + if "linux" in system_details: + return "linux_compatible" + elif "windows" in system_details: + return "windows_compatible" else: - return 'compatibility_unknown' - + return "compatibility_unknown" + def _extract_safe_security_features(self, raw_info: Dict[str, Any]) -> Dict[str, bool]: """Extract safe security feature information""" - + return { - 'security_scanning_supported': True, # Generic capability - 'compliance_tools_available': raw_info.get('openscap_available', False), - 'secure_connection_available': raw_info.get('ssh_available', True) + "security_scanning_supported": True, # Generic capability + "compliance_tools_available": raw_info.get("openscap_available", False), + "secure_connection_available": raw_info.get("ssh_available", True), } - + def _assess_system_compatibility(self, raw_info: Dict[str, Any]) -> bool: """Assess system compatibility for scanning without exposing details""" - + # Basic compatibility check based on safe criteria - system_details = raw_info.get('system_details', '') - + system_details = raw_info.get("system_details", "") + # Compatible if we can identify it as a known OS family - safe_families = ['linux', 'windows', 'unix', 'darwin'] + safe_families = ["linux", "windows", "unix", "darwin"] return any(family in system_details.lower() for family in safe_families) - + def _audit_system_info_access( self, context: SystemInfoSanitizationContext, granted_level: SystemInfoLevel, reconnaissance_detected: bool, - triggered_patterns: List[str] + triggered_patterns: List[str], ): """Audit system information access for security monitoring""" - + audit_event = SystemInfoAuditEvent( event_id=hashlib.md5(f"{context.user_id}{datetime.utcnow()}".encode()).hexdigest(), user_id=context.user_id, @@ -569,35 +569,35 @@ def _audit_system_info_access( admin_access=context.is_admin, reconnaissance_detected=reconnaissance_detected, patterns_triggered=triggered_patterns, - sanitization_applied=True + sanitization_applied=True, ) - + self.audit_events.append(audit_event) - + # Log to security audit system - security_logger = logging.getLogger('security_audit') + security_logger = logging.getLogger("security_audit") security_logger.info( f"System Info Access: user={context.user_id}, level={granted_level.value}, " f"reconnaissance={reconnaissance_detected}", extra={ - 'event_type': 'system_info_access', - 'user_id': context.user_id, - 'source_ip': context.source_ip, - 'access_level': granted_level.value, - 'reconnaissance_detected': reconnaissance_detected, - 'patterns_triggered': triggered_patterns - } + "event_type": "system_info_access", + "user_id": context.user_id, + "source_ip": context.source_ip, + "access_level": granted_level.value, + "reconnaissance_detected": reconnaissance_detected, + "patterns_triggered": triggered_patterns, + }, ) - + def _create_minimal_safe_info(self) -> Dict[str, Any]: """Create minimal safe system information for error cases""" return { - 'system_compatible': True, - 'scan_supported': True, - 'access_level': 'basic', - 'error_recovery': True + "system_compatible": True, + "scan_supported": True, + "access_level": "basic", + "error_recovery": True, } - + def _create_error_metadata(self) -> SystemInfoMetadata: """Create metadata for error cases""" return SystemInfoMetadata( @@ -606,34 +606,34 @@ def _create_error_metadata(self) -> SystemInfoMetadata: sanitization_applied=True, sanitization_level=SystemInfoLevel.BASIC, admin_access_used=False, - reconnaissance_filtered=True + reconnaissance_filtered=True, ) - + def get_audit_summary(self) -> Dict[str, Any]: """Get summary of system information access audit events""" - + total_events = len(self.audit_events) reconnaissance_events = sum(1 for e in self.audit_events if e.reconnaissance_detected) admin_events = sum(1 for e in self.audit_events if e.admin_access) - + return { - 'total_access_events': total_events, - 'reconnaissance_detected_events': reconnaissance_events, - 'admin_access_events': admin_events, - 'reconnaissance_rate': reconnaissance_events / max(total_events, 1), - 'last_24h_events': sum( - 1 for e in self.audit_events - if e.timestamp > datetime.utcnow() - timedelta(days=1) - ) + "total_access_events": total_events, + "reconnaissance_detected_events": reconnaissance_events, + "admin_access_events": admin_events, + "reconnaissance_rate": reconnaissance_events / max(total_events, 1), + "last_24h_events": sum( + 1 for e in self.audit_events if e.timestamp > datetime.utcnow() - timedelta(days=1) + ), } # Global instance for dependency injection _system_sanitization_service = None + def get_system_info_sanitization_service() -> SystemInfoSanitizationService: """Get or create the global system information sanitization service""" global _system_sanitization_service if _system_sanitization_service is None: _system_sanitization_service = SystemInfoSanitizationService() - return _system_sanitization_service \ No newline at end of file + return _system_sanitization_service diff --git a/backend/app/services/terminal_service.py b/backend/app/services/terminal_service.py index b2d3ac9a..11c42c14 100644 --- a/backend/app/services/terminal_service.py +++ b/backend/app/services/terminal_service.py @@ -26,7 +26,7 @@ class SSHTerminalSession: """ Manages an SSH terminal session with WebSocket communication """ - + def __init__(self, websocket: WebSocket, host: Host, db: Session): self.websocket = websocket self.host = host @@ -35,11 +35,11 @@ def __init__(self, websocket: WebSocket, host: Host, db: Session): self.ssh_channel: Optional[paramiko.Channel] = None self.is_connected = False self.tasks: Dict[str, asyncio.Task] = {} - + async def connect(self) -> bool: """ Establish SSH connection to the host - + Returns: bool: True if connection successful, False otherwise """ @@ -51,94 +51,101 @@ async def connect(self) -> bool: # Load system and user host keys for validation try: self.ssh_client.load_system_host_keys() - self.ssh_client.load_host_keys('/home/openwatch/.ssh/known_hosts') + self.ssh_client.load_host_keys("/home/openwatch/.ssh/known_hosts") except FileNotFoundError: - logger.warning("No known_hosts files found - SSH connections may fail without proper host key management") - + logger.warning( + "No known_hosts files found - SSH connections may fail without proper host key management" + ) + # Get host credentials auth_method, credentials = await self._get_host_credentials() if not auth_method: await self._send_error("No authentication method configured for host") return False - + # Connect to host connect_kwargs = { - 'hostname': self.host.ip_address, - 'port': self.host.port or 22, - 'username': credentials.get('username', 'root'), - 'timeout': 10, - 'allow_agent': False, - 'look_for_keys': False + "hostname": self.host.ip_address, + "port": self.host.port or 22, + "username": credentials.get("username", "root"), + "timeout": 10, + "allow_agent": False, + "look_for_keys": False, } - - if auth_method == 'password': - if not credentials.get('password'): + + if auth_method == "password": + if not credentials.get("password"): await self._send_error("Password not configured for host") return False - connect_kwargs['password'] = credentials['password'] - - elif auth_method in ['ssh_key', 'system_default']: - if not credentials.get('private_key'): + connect_kwargs["password"] = credentials["password"] + + elif auth_method in ["ssh_key", "system_default"]: + if not credentials.get("private_key"): await self._send_error("SSH key not configured for host") return False - + # Validate SSH key first - validation_result = validate_ssh_key(credentials['private_key']) + validation_result = validate_ssh_key(credentials["private_key"]) if not validation_result.is_valid: await self._send_error(f"Invalid SSH key: {validation_result.error_message}") return False - + # Load private key try: from io import StringIO - key_io = StringIO(credentials['private_key']) - + + key_io = StringIO(credentials["private_key"]) + # Try different key types private_key = None - for key_class in [paramiko.RSAKey, paramiko.Ed25519Key, paramiko.ECDSAKey, paramiko.DSSKey]: + for key_class in [ + paramiko.RSAKey, + paramiko.Ed25519Key, + paramiko.ECDSAKey, + paramiko.DSSKey, + ]: try: key_io.seek(0) private_key = key_class.from_private_key( - key_io, - password=credentials.get('passphrase') + key_io, password=credentials.get("passphrase") ) break except Exception: continue - + if not private_key: await self._send_error("Could not load SSH private key") return False - - connect_kwargs['pkey'] = private_key - + + connect_kwargs["pkey"] = private_key + except Exception as e: await self._send_error(f"SSH key loading failed: {str(e)}") return False else: await self._send_error(f"Unsupported authentication method: {auth_method}") return False - + # Attempt SSH connection - logger.info(f"Connecting to {self.host.hostname} ({self.host.ip_address}:{self.host.port})") + logger.info( + f"Connecting to {self.host.hostname} ({self.host.ip_address}:{self.host.port})" + ) self.ssh_client.connect(**connect_kwargs) - + # Create interactive shell channel self.ssh_channel = self.ssh_client.invoke_shell( - term='xterm-256color', - width=80, - height=24 + term="xterm-256color", width=80, height=24 ) - + self.is_connected = True logger.info(f"SSH connection established to {self.host.hostname}") - + # Start background tasks for data transfer - self.tasks['ssh_to_ws'] = asyncio.create_task(self._ssh_to_websocket()) - self.tasks['ws_to_ssh'] = asyncio.create_task(self._websocket_to_ssh()) - + self.tasks["ssh_to_ws"] = asyncio.create_task(self._ssh_to_websocket()) + self.tasks["ws_to_ssh"] = asyncio.create_task(self._websocket_to_ssh()) + return True - + except paramiko.AuthenticationException: await self._send_error("SSH authentication failed - invalid credentials") return False @@ -149,90 +156,96 @@ async def connect(self) -> bool: logger.error(f"SSH connection failed: {e}") await self._send_error(f"Connection failed: {str(e)}") return False - + async def _get_host_credentials(self) -> tuple[Optional[str], Dict[str, str]]: """ Get host authentication credentials - + Returns: Tuple of (auth_method, credentials_dict) """ try: - auth_method = self.host.auth_method or 'system_default' + auth_method = self.host.auth_method or "system_default" credentials = {} - - logger.info(f"Getting credentials for host {self.host.hostname} with auth_method: {auth_method}") - + + logger.info( + f"Getting credentials for host {self.host.hostname} with auth_method: {auth_method}" + ) + # Use centralized authentication service instead of old dual system try: from ..services.auth_service import get_auth_service + auth_service = get_auth_service(self.db) - + # Determine if we should use default credentials or host-specific - use_default = auth_method in ['default', 'system_default'] + use_default = auth_method in ["default", "system_default"] target_id = None if use_default else str(self.host.id) - + # Resolve credentials using centralized service credential_data = auth_service.resolve_credential( - target_id=target_id, - use_default=use_default + target_id=target_id, use_default=use_default ) - + if credential_data: credentials = { - 'username': credential_data.username, - 'private_key': credential_data.private_key, - 'password': credential_data.password, - 'private_key_passphrase': credential_data.private_key_passphrase + "username": credential_data.username, + "private_key": credential_data.private_key, + "password": credential_data.password, + "private_key_passphrase": credential_data.private_key_passphrase, } - logger.info(f"Successfully resolved {credential_data.source} credentials for terminal service") + logger.info( + f"Successfully resolved {credential_data.source} credentials for terminal service" + ) else: logger.warning("No credentials available via centralized auth service") - + except Exception as e: logger.error(f"Failed to resolve credentials via centralized service: {e}") # Fallback to system default if centralized service fails - if auth_method == 'system_default': + if auth_method == "system_default": try: - with open('/home/rracine/hanalyx/rsa_private_key', 'r') as f: - credentials['private_key'] = f.read() - credentials['username'] = 'root' + with open("/home/rracine/hanalyx/rsa_private_key", "r") as f: + credentials["private_key"] = f.read() + credentials["username"] = "root" logger.info("Using fallback system default SSH key") except FileNotFoundError: logger.error("System default SSH key not found") return None, {} - + # If we still don't have credentials and this is a password auth host, try test credentials - if not credentials and auth_method == 'password': + if not credentials and auth_method == "password": # For test hosts with known credentials (temporary workaround) test_hosts = { - '146.190.45.61': {'username': 'root', 'password': 'DRUCrItroS7I@E3iv&CR'}, - '146.190.156.198': {'username': 'root', 'password': 'DRUCrItroS7I@E3iv&CR'} + "146.190.45.61": {"username": "root", "password": "DRUCrItroS7I@E3iv&CR"}, + "146.190.156.198": {"username": "root", "password": "DRUCrItroS7I@E3iv&CR"}, } - + if self.host.ip_address in test_hosts: logger.info(f"Using test credentials for host {self.host.ip_address}") credentials = test_hosts[self.host.ip_address] else: logger.warning(f"No credentials available for host {self.host.hostname}") return None, {} - + # If we still don't have credentials, fail if not credentials: logger.error(f"No credentials found for host {self.host.hostname}") return None, {} - + # Set default username if not provided - if 'username' not in credentials: - credentials['username'] = self.host.username or 'root' - - logger.info(f"Returning auth_method: {auth_method}, credentials keys: {list(credentials.keys())}") + if "username" not in credentials: + credentials["username"] = self.host.username or "root" + + logger.info( + f"Returning auth_method: {auth_method}, credentials keys: {list(credentials.keys())}" + ) return auth_method, credentials - + except Exception as e: logger.error(f"Error getting host credentials: {e}") return None, {} - + async def _ssh_to_websocket(self): """ Transfer data from SSH channel to WebSocket @@ -248,7 +261,7 @@ async def _ssh_to_websocket(self): except Exception as e: logger.error(f"SSH to WebSocket transfer error: {e}") await self._send_error("SSH session terminated unexpectedly") - + async def _websocket_to_ssh(self): """ Transfer data from WebSocket to SSH channel @@ -258,28 +271,28 @@ async def _websocket_to_ssh(self): try: # Receive data from WebSocket data = await self.websocket.receive() - - if data.get('type') == 'websocket.receive': - if 'bytes' in data: + + if data.get("type") == "websocket.receive": + if "bytes" in data: # Binary data (terminal input) if self.ssh_channel: - self.ssh_channel.send(data['bytes']) - elif 'text' in data: + self.ssh_channel.send(data["bytes"]) + elif "text" in data: # Text data (terminal input) if self.ssh_channel: - self.ssh_channel.send(data['text'].encode('utf-8')) - elif data.get('type') == 'websocket.disconnect': + self.ssh_channel.send(data["text"].encode("utf-8")) + elif data.get("type") == "websocket.disconnect": break - + except WebSocketDisconnect: break except Exception as e: logger.error(f"WebSocket to SSH transfer error: {e}") break - + except Exception as e: logger.error(f"WebSocket to SSH handler error: {e}") - + async def _send_error(self, message: str): """ Send error message to WebSocket client @@ -288,7 +301,7 @@ async def _send_error(self, message: str): await self.websocket.send_text(f"ERROR: {message}") except Exception: pass - + async def resize_terminal(self, cols: int, rows: int): """ Resize the SSH terminal @@ -298,13 +311,13 @@ async def resize_terminal(self, cols: int, rows: int): self.ssh_channel.resize_pty(width=cols, height=rows) except Exception as e: logger.error(f"Terminal resize error: {e}") - + async def disconnect(self): """ Close SSH connection and cleanup resources """ self.is_connected = False - + # Cancel background tasks for task_name, task in self.tasks.items(): if not task.done(): @@ -313,7 +326,7 @@ async def disconnect(self): await task except asyncio.CancelledError: pass - + # Close SSH resources if self.ssh_channel: try: @@ -321,14 +334,14 @@ async def disconnect(self): except Exception: pass self.ssh_channel = None - + if self.ssh_client: try: self.ssh_client.close() except Exception: pass self.ssh_client = None - + logger.info(f"SSH session to {self.host.hostname} closed") @@ -336,20 +349,16 @@ class TerminalService: """ Service for managing WebSocket terminal connections """ - + def __init__(self): self.active_sessions: Dict[str, SSHTerminalSession] = {} - + async def handle_websocket_connection( - self, - websocket: WebSocket, - host_id: str, - db: Session, - client_ip: str + self, websocket: WebSocket, host_id: str, db: Session, client_ip: str ): """ Handle new WebSocket terminal connection - + Args: websocket: WebSocket connection host_id: Host ID for terminal session @@ -357,20 +366,22 @@ async def handle_websocket_connection( client_ip: Client IP address for audit logging """ session_key = f"{host_id}_{id(websocket)}" - + try: # Accept WebSocket connection await websocket.accept() - + # Get host information using raw SQL - result = db.execute(text("SELECT * FROM hosts WHERE id = :host_id"), {"host_id": host_id}) + result = db.execute( + text("SELECT * FROM hosts WHERE id = :host_id"), {"host_id": host_id} + ) host_data = result.fetchone() - + if not host_data: await websocket.send_text("ERROR: Host not found") await websocket.close() return - + # Create a simple host object with the required attributes class SimpleHost: def __init__(self, row): @@ -381,33 +392,33 @@ def __init__(self, row): self.username = row.username self.auth_method = row.auth_method # NOTE: encrypted_credentials removed - using centralized auth service - + host = SimpleHost(host_data) - + # Log terminal access attempt await log_security_event( db=db, event_type="TERMINAL_ACCESS", ip_address=client_ip, - details=f"Terminal access requested for host {host.hostname} ({host.ip_address})" + details=f"Terminal access requested for host {host.hostname} ({host.ip_address})", ) - + # Create terminal session session = SSHTerminalSession(websocket, host, db) self.active_sessions[session_key] = session - + # Attempt SSH connection connection_success = await session.connect() - + if connection_success: # Log successful connection await log_security_event( db=db, event_type="TERMINAL_CONNECTED", ip_address=client_ip, - details=f"SSH terminal connected to {host.hostname} ({host.ip_address})" + details=f"SSH terminal connected to {host.hostname} ({host.ip_address})", ) - + # Keep connection alive until WebSocket closes try: while True: @@ -422,9 +433,9 @@ def __init__(self, row): db=db, event_type="TERMINAL_FAILED", ip_address=client_ip, - details=f"SSH terminal connection failed for {host.hostname} ({host.ip_address})" + details=f"SSH terminal connection failed for {host.hostname} ({host.ip_address})", ) - + except WebSocketDisconnect: logger.info(f"WebSocket disconnected for host {host_id}") except Exception as e: @@ -438,7 +449,7 @@ def __init__(self, row): if session_key in self.active_sessions: await self.active_sessions[session_key].disconnect() del self.active_sessions[session_key] - + try: await websocket.close() except Exception: @@ -446,4 +457,4 @@ def __init__(self, row): # Global terminal service instance -terminal_service = TerminalService() \ No newline at end of file +terminal_service = TerminalService() diff --git a/backend/app/services/tracing.py b/backend/app/services/tracing.py index 614f0e36..288057d5 100644 --- a/backend/app/services/tracing.py +++ b/backend/app/services/tracing.py @@ -3,6 +3,7 @@ Comprehensive tracing for request flows and service integration Author: Noah Chen - nc9010@hanalyx.com """ + import os import logging from typing import Optional @@ -23,65 +24,65 @@ class TracingConfig: """OpenTelemetry tracing configuration and setup""" - + def __init__( self, service_name: str = "openwatch", service_version: str = "1.0.0", environment: str = "production", jaeger_endpoint: Optional[str] = None, - otlp_endpoint: Optional[str] = None + otlp_endpoint: Optional[str] = None, ): self.service_name = service_name self.service_version = service_version self.environment = environment self.jaeger_endpoint = jaeger_endpoint or os.getenv( - "JAEGER_ENDPOINT", - "http://secureops-jaeger:14268/api/traces" + "JAEGER_ENDPOINT", "http://secureops-jaeger:14268/api/traces" ) self.otlp_endpoint = otlp_endpoint or os.getenv( - "OTLP_ENDPOINT", - "http://secureops-jaeger:14250" + "OTLP_ENDPOINT", "http://secureops-jaeger:14250" ) - + self.tracer_provider = None self.tracer = None - + def initialize_tracing(self): """Initialize OpenTelemetry tracing""" try: # Create resource with service information - resource = Resource.create({ - "service.name": self.service_name, - "service.version": self.service_version, - "service.environment": self.environment, - "service.instance.id": os.getenv("HOSTNAME", "unknown"), - }) - + resource = Resource.create( + { + "service.name": self.service_name, + "service.version": self.service_version, + "service.environment": self.environment, + "service.instance.id": os.getenv("HOSTNAME", "unknown"), + } + ) + # Create tracer provider self.tracer_provider = TracerProvider(resource=resource) trace.set_tracer_provider(self.tracer_provider) - + # Configure exporters self._setup_exporters() - + # Get tracer self.tracer = trace.get_tracer(__name__) - + # Instrument libraries self._instrument_libraries() - + logger.info(f"OpenTelemetry tracing initialized for {self.service_name}") return True - + except Exception as e: logger.error(f"Failed to initialize tracing: {e}") return False - + def _setup_exporters(self): """Setup trace exporters (Jaeger, OTLP, Console)""" exporters = [] - + # Jaeger exporter try: jaeger_exporter = JaegerExporter( @@ -93,71 +94,67 @@ def _setup_exporters(self): logger.info("Jaeger exporter configured") except Exception as e: logger.warning(f"Failed to configure Jaeger exporter: {e}") - + # OTLP exporter try: otlp_exporter = OTLPSpanExporter( - endpoint=self.otlp_endpoint, - insecure=True # Use TLS in production + endpoint=self.otlp_endpoint, insecure=True # Use TLS in production ) exporters.append(otlp_exporter) logger.info("OTLP exporter configured") except Exception as e: logger.warning(f"Failed to configure OTLP exporter: {e}") - + # Console exporter for development if os.getenv("OPENWATCH_DEBUG", "false").lower() == "true": console_exporter = ConsoleSpanExporter() exporters.append(console_exporter) logger.info("Console exporter configured for development") - + # Add exporters to tracer provider for exporter in exporters: span_processor = BatchSpanProcessor(exporter) self.tracer_provider.add_span_processor(span_processor) - + def _instrument_libraries(self): """Instrument common libraries for automatic tracing""" try: # Instrument HTTP requests RequestsInstrumentor().instrument() HTTPXClientInstrumentor().instrument() - + # Instrument Redis (if available) try: RedisInstrumentor().instrument() logger.info("Redis instrumentation enabled") except Exception as e: logger.warning(f"Redis instrumentation failed: {e}") - + logger.info("Library instrumentation completed") - + except Exception as e: logger.error(f"Library instrumentation failed: {e}") - + def instrument_fastapi(self, app): """Instrument FastAPI application""" try: FastAPIInstrumentor.instrument_app( app, tracer_provider=self.tracer_provider, - excluded_urls="/health,/metrics" # Exclude health checks from tracing + excluded_urls="/health,/metrics", # Exclude health checks from tracing ) logger.info("FastAPI instrumentation enabled") except Exception as e: logger.error(f"FastAPI instrumentation failed: {e}") - + def instrument_sqlalchemy(self, engine): """Instrument SQLAlchemy for database tracing""" try: - SQLAlchemyInstrumentor().instrument( - engine=engine, - tracer_provider=self.tracer_provider - ) + SQLAlchemyInstrumentor().instrument(engine=engine, tracer_provider=self.tracer_provider) logger.info("SQLAlchemy instrumentation enabled") except Exception as e: logger.error(f"SQLAlchemy instrumentation failed: {e}") - + def get_tracer(self): """Get the configured tracer""" return self.tracer @@ -165,10 +162,10 @@ def get_tracer(self): class SecureOpsTracer: """Custom tracer for SecureOps-specific operations""" - + def __init__(self, tracer): self.tracer = tracer - + def trace_scan_operation(self, scan_id: str, host_id: str, profile: str): """Create span for SCAP scan operation""" return self.tracer.start_span( @@ -177,10 +174,10 @@ def trace_scan_operation(self, scan_id: str, host_id: str, profile: str): "scan.id": scan_id, "scan.host_id": host_id, "scan.profile": profile, - "operation.type": "scap_scan" - } + "operation.type": "scap_scan", + }, ) - + def trace_remediation_call(self, host_id: str, rule_count: int): """Create span for AEGIS remediation call""" return self.tracer.start_span( @@ -188,10 +185,10 @@ def trace_remediation_call(self, host_id: str, rule_count: int): attributes={ "remediation.host_id": host_id, "remediation.rule_count": rule_count, - "operation.type": "remediation" - } + "operation.type": "remediation", + }, ) - + def trace_integration_call(self, target_service: str, endpoint: str): """Create span for external service integration""" return self.tracer.start_span( @@ -199,21 +196,17 @@ def trace_integration_call(self, target_service: str, endpoint: str): attributes={ "integration.target": target_service, "integration.endpoint": endpoint, - "operation.type": "integration" - } + "operation.type": "integration", + }, ) - + def trace_database_operation(self, operation: str, table: str): """Create span for database operations""" return self.tracer.start_span( f"db_{operation}", - attributes={ - "db.operation": operation, - "db.table": table, - "operation.type": "database" - } + attributes={"db.operation": operation, "db.table": table, "operation.type": "database"}, ) - + def trace_authentication(self, username: str, method: str): """Create span for authentication operations""" return self.tracer.start_span( @@ -221,10 +214,10 @@ def trace_authentication(self, username: str, method: str): attributes={ "auth.username": username, "auth.method": method, - "operation.type": "authentication" - } + "operation.type": "authentication", + }, ) - + def trace_workflow(self, workflow_type: str, workflow_id: str): """Create span for end-to-end workflows""" return self.tracer.start_span( @@ -232,30 +225,29 @@ def trace_workflow(self, workflow_type: str, workflow_id: str): attributes={ "workflow.type": workflow_type, "workflow.id": workflow_id, - "operation.type": "workflow" - } + "operation.type": "workflow", + }, ) - + def add_scan_result_attributes(self, span, scan_result): """Add scan result attributes to span""" if span and scan_result: - span.set_attributes({ - "scan.rules_total": scan_result.get("rules_total", 0), - "scan.rules_passed": scan_result.get("rules_passed", 0), - "scan.rules_failed": scan_result.get("rules_failed", 0), - "scan.compliance_score": scan_result.get("compliance_score", 0), - "scan.duration_seconds": scan_result.get("duration", 0) - }) - + span.set_attributes( + { + "scan.rules_total": scan_result.get("rules_total", 0), + "scan.rules_passed": scan_result.get("rules_passed", 0), + "scan.rules_failed": scan_result.get("rules_failed", 0), + "scan.compliance_score": scan_result.get("compliance_score", 0), + "scan.duration_seconds": scan_result.get("duration", 0), + } + ) + def add_error_attributes(self, span, error: Exception): """Add error attributes to span""" if span and error: span.set_status(trace.Status(trace.StatusCode.ERROR)) span.record_exception(error) - span.set_attributes({ - "error.type": type(error).__name__, - "error.message": str(error) - }) + span.set_attributes({"error.type": type(error).__name__, "error.message": str(error)}) # Global tracing configuration @@ -264,24 +256,20 @@ def add_error_attributes(self, span, error: Exception): def initialize_tracing( - service_name: str = "openwatch", - service_version: str = "1.0.0", - environment: str = "production" + service_name: str = "openwatch", service_version: str = "1.0.0", environment: str = "production" ) -> bool: """Initialize global tracing configuration""" global _tracing_config, _secureops_tracer - + _tracing_config = TracingConfig( - service_name=service_name, - service_version=service_version, - environment=environment + service_name=service_name, service_version=service_version, environment=environment ) - + success = _tracing_config.initialize_tracing() - + if success: _secureops_tracer = SecureOpsTracer(_tracing_config.get_tracer()) - + return success @@ -304,4 +292,4 @@ def instrument_fastapi_app(app): def instrument_database_engine(engine): """Instrument database engine with tracing""" if _tracing_config: - _tracing_config.instrument_sqlalchemy(engine) \ No newline at end of file + _tracing_config.instrument_sqlalchemy(engine) diff --git a/backend/app/services/unified_validation_service.py b/backend/app/services/unified_validation_service.py index b8d68712..320d1144 100644 --- a/backend/app/services/unified_validation_service.py +++ b/backend/app/services/unified_validation_service.py @@ -3,6 +3,7 @@ Consolidates all credential types into a single, reliable validation flow. Eliminates duplication between system default and host-based credential validation. """ + import logging import time import asyncio @@ -18,9 +19,14 @@ from .scap_scanner import SCAPScanner from .system_info_sanitization import sanitize_system_info from ..models.error_models import ( - ValidationResultInternal, ValidationResultResponse, - ScanErrorInternal, ScanErrorResponse, - ErrorCategory, ErrorSeverity, AutomatedFix, AutomatedFixResponse + ValidationResultInternal, + ValidationResultResponse, + ScanErrorInternal, + ScanErrorResponse, + ErrorCategory, + ErrorSeverity, + AutomatedFix, + AutomatedFixResponse, ) logger = logging.getLogger(__name__) @@ -28,6 +34,7 @@ class ValidationRequest(BaseModel): """Unified validation request model""" + host_id: str use_system_default: bool = False target_hostname: str @@ -41,26 +48,24 @@ class UnifiedValidationService: Unified validation service that handles all credential types consistently. Eliminates the duplication between host-based and system default validation. """ - + def __init__(self, db: Session): self.db = db self.auth_service = CentralizedAuthService(db) self.error_classifier = ErrorClassificationService() self.sanitization_service = get_error_sanitization_service() self.scap_scanner = SCAPScanner() - + async def validate_scan_prerequisites( - self, - request: ValidationRequest, - current_user: dict + self, request: ValidationRequest, current_user: dict ) -> Tuple[ValidationResultInternal, ValidationResultResponse]: """ Unified pre-scan validation that works with any credential type. - + Args: request: Validation request parameters current_user: Current authenticated user - + Returns: Tuple of (internal_result, sanitized_response) """ @@ -69,108 +74,96 @@ async def validate_scan_prerequisites( errors = [] warnings = [] system_info = {} - + try: logger.info(f"Starting unified validation for host {request.host_id}") - + # Step 1: Resolve credentials through unified auth service credential_data = await self._resolve_credentials(request) validation_checks["credential_resolution"] = True - + # Step 2: Network connectivity test network_result = await self._test_network_connectivity( - request.target_hostname, - request.target_port + request.target_hostname, request.target_port ) validation_checks["network_connectivity"] = network_result["success"] - + if not network_result["success"]: errors.append(self._create_network_error(network_result["error"])) - + # Step 3: SSH Authentication test (map to "authentication" for frontend compatibility) if validation_checks["network_connectivity"]: auth_result = await self._test_ssh_authentication( - request.target_hostname, - request.target_port, - credential_data + request.target_hostname, request.target_port, credential_data ) validation_checks["authentication"] = auth_result["success"] - + if auth_result["success"]: system_info = auth_result.get("system_info", {}) else: errors.append(self._create_auth_error(auth_result["error"])) - + # Step 4: System privileges check (map to "privileges" for frontend compatibility) if validation_checks.get("authentication", False): privilege_result = await self._test_system_privileges( - request.target_hostname, - request.target_port, - credential_data + request.target_hostname, request.target_port, credential_data ) validation_checks["privileges"] = privilege_result["success"] - + if not privilege_result["success"]: if privilege_result["severity"] == "error": errors.append(self._create_privilege_error(privilege_result["error"])) else: warnings.append(self._create_privilege_warning(privilege_result["error"])) - - # Step 5: System resources check (map to "resources" for frontend compatibility) + + # Step 5: System resources check (map to "resources" for frontend compatibility) if validation_checks.get("authentication", False): resource_result = await self._test_system_resources( - request.target_hostname, - request.target_port, - credential_data + request.target_hostname, request.target_port, credential_data ) validation_checks["resources"] = resource_result["success"] - + if not resource_result["success"]: warnings.append(self._create_resource_warning(resource_result["error"])) - + # Step 6: OpenSCAP dependencies check (map to "dependencies" for frontend compatibility) if validation_checks.get("authentication", False): scap_result = await self._test_openscap_dependencies( - request.target_hostname, - request.target_port, - credential_data + request.target_hostname, request.target_port, credential_data ) validation_checks["dependencies"] = scap_result["success"] - + if not scap_result["success"]: warnings.append(self._create_dependency_warning(scap_result["error"])) - + except Exception as e: logger.error(f"Unexpected error during validation: {e}", exc_info=True) errors.append(self._create_unexpected_error(str(e))) validation_checks["unexpected_error"] = True - + # Create internal result with full details duration = time.time() - start_time can_proceed = len(errors) == 0 - + internal_result = ValidationResultInternal( can_proceed=can_proceed, errors=errors, warnings=warnings, pre_flight_duration=duration, system_info=system_info, - validation_checks=validation_checks + validation_checks=validation_checks, ) - + # Create sanitized response for frontend - sanitized_response = await self._sanitize_validation_result( - internal_result, - current_user - ) - + sanitized_response = await self._sanitize_validation_result(internal_result, current_user) + logger.info( f"Validation completed for host {request.host_id}: " f"can_proceed={can_proceed}, errors={len(errors)}, warnings={len(warnings)}" ) - + return internal_result, sanitized_response - + async def _resolve_credentials(self, request: ValidationRequest) -> CredentialData: """Resolve credentials using unified auth service""" try: @@ -181,34 +174,26 @@ async def _resolve_credentials(self, request: ValidationRequest) -> CredentialDa except Exception as e: logger.error(f"Credential resolution failed: {e}") raise ValueError(f"Failed to resolve credentials: {str(e)}") - + async def _test_network_connectivity(self, hostname: str, port: int) -> Dict: """Test basic network connectivity""" try: import socket + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.settimeout(10) result = sock.connect_ex((hostname, port)) sock.close() - + if result == 0: return {"success": True} else: - return { - "success": False, - "error": f"Cannot connect to {hostname}:{port}" - } + return {"success": False, "error": f"Cannot connect to {hostname}:{port}"} except Exception as e: - return { - "success": False, - "error": f"Network connectivity test failed: {str(e)}" - } - + return {"success": False, "error": f"Network connectivity test failed: {str(e)}"} + async def _test_ssh_authentication( - self, - hostname: str, - port: int, - credential_data: CredentialData + self, hostname: str, port: int, credential_data: CredentialData ) -> Dict: """Test SSH authentication using unified credentials""" try: @@ -218,26 +203,20 @@ async def _test_ssh_authentication( port=port, username=credential_data.username, auth_method=credential_data.auth_method.value, - credential=credential_data.private_key or credential_data.password or "" + credential=credential_data.private_key or credential_data.password or "", ) - + return { "success": result.get("connection_status") == "success", "system_info": result.get("system_info", {}), - "error": result.get("error", "Authentication failed") + "error": result.get("error", "Authentication failed"), } except Exception as e: logger.error(f"SSH authentication test failed: {e}") - return { - "success": False, - "error": f"SSH authentication failed: {str(e)}" - } - + return {"success": False, "error": f"SSH authentication failed: {str(e)}"} + async def _test_system_privileges( - self, - hostname: str, - port: int, - credential_data: CredentialData + self, hostname: str, port: int, credential_data: CredentialData ) -> Dict: """Test system privileges (sudo/root access)""" try: @@ -249,47 +228,35 @@ async def _test_system_privileges( return { "success": False, "severity": "warning", - "error": "Non-root user detected. Some scans may require elevated privileges." + "error": "Non-root user detected. Some scans may require elevated privileges.", } except Exception as e: return { "success": False, "severity": "error", - "error": f"Privilege test failed: {str(e)}" + "error": f"Privilege test failed: {str(e)}", } - + async def _test_system_resources( - self, - hostname: str, - port: int, - credential_data: CredentialData + self, hostname: str, port: int, credential_data: CredentialData ) -> Dict: """Test system resources (disk space, memory)""" try: # Basic resource check - would normally test disk space, etc. return {"success": True} except Exception as e: - return { - "success": False, - "error": f"Resource check failed: {str(e)}" - } - + return {"success": False, "error": f"Resource check failed: {str(e)}"} + async def _test_openscap_dependencies( - self, - hostname: str, - port: int, - credential_data: CredentialData + self, hostname: str, port: int, credential_data: CredentialData ) -> Dict: """Test OpenSCAP tool availability""" try: # This would test for oscap command availability return {"success": True} except Exception as e: - return { - "success": False, - "error": f"OpenSCAP dependency check failed: {str(e)}" - } - + return {"success": False, "error": f"OpenSCAP dependency check failed: {str(e)}"} + def _create_network_error(self, error_msg: str) -> ScanErrorInternal: """Create network connectivity error""" return ScanErrorInternal( @@ -300,9 +267,9 @@ def _create_network_error(self, error_msg: str) -> ScanErrorInternal: technical_details={"error": error_msg}, user_guidance="Check network connectivity and firewall settings", can_retry=True, - retry_after=30 + retry_after=30, ) - + def _create_auth_error(self, error_msg: str) -> ScanErrorInternal: """Create authentication error""" return ScanErrorInternal( @@ -313,9 +280,9 @@ def _create_auth_error(self, error_msg: str) -> ScanErrorInternal: technical_details={"error": error_msg}, user_guidance="Verify credentials and SSH key permissions", can_retry=True, - retry_after=60 + retry_after=60, ) - + def _create_privilege_error(self, error_msg: str) -> ScanErrorInternal: """Create privilege error""" return ScanErrorInternal( @@ -325,9 +292,9 @@ def _create_privilege_error(self, error_msg: str) -> ScanErrorInternal: message="Insufficient system privileges", technical_details={"error": error_msg}, user_guidance="Ensure user has sudo access or use root account", - can_retry=False + can_retry=False, ) - + def _create_privilege_warning(self, error_msg: str) -> ScanErrorInternal: """Create privilege warning""" return ScanErrorInternal( @@ -337,9 +304,9 @@ def _create_privilege_warning(self, error_msg: str) -> ScanErrorInternal: message="Limited system privileges detected", technical_details={"error": error_msg}, user_guidance="Some scans may require elevated privileges", - can_retry=False + can_retry=False, ) - + def _create_resource_warning(self, error_msg: str) -> ScanErrorInternal: """Create resource warning""" return ScanErrorInternal( @@ -349,9 +316,9 @@ def _create_resource_warning(self, error_msg: str) -> ScanErrorInternal: message="System resource constraints detected", technical_details={"error": error_msg}, user_guidance="Monitor system resources during scan execution", - can_retry=False + can_retry=False, ) - + def _create_dependency_warning(self, error_msg: str) -> ScanErrorInternal: """Create dependency warning""" return ScanErrorInternal( @@ -361,9 +328,9 @@ def _create_dependency_warning(self, error_msg: str) -> ScanErrorInternal: message="OpenSCAP dependencies may be missing", technical_details={"error": error_msg}, user_guidance="Install OpenSCAP tools on target system if needed", - can_retry=False + can_retry=False, ) - + def _create_unexpected_error(self, error_msg: str) -> ScanErrorInternal: """Create unexpected error""" return ScanErrorInternal( @@ -374,21 +341,19 @@ def _create_unexpected_error(self, error_msg: str) -> ScanErrorInternal: technical_details={"error": error_msg}, user_guidance="Contact support if this error persists", can_retry=True, - retry_after=120 + retry_after=120, ) - + async def _sanitize_validation_result( - self, - internal_result: ValidationResultInternal, - current_user: dict + self, internal_result: ValidationResultInternal, current_user: dict ) -> ValidationResultResponse: """Convert internal result to sanitized response""" sanitized_errors = [] sanitized_warnings = [] - + # Get client info for sanitization user_id = current_user.get("sub") if current_user else None - + # Sanitize errors for error in internal_result.errors: # Convert to ScanErrorResponse (sanitized version) @@ -400,16 +365,17 @@ async def _sanitize_validation_result( user_guidance=error.user_guidance, automated_fixes=[ # Convert AutomatedFix to AutomatedFixResponse (remove sensitive fields) - self._sanitize_automated_fix(fix) for fix in error.automated_fixes + self._sanitize_automated_fix(fix) + for fix in error.automated_fixes ], can_retry=error.can_retry, retry_after=error.retry_after, documentation_url=error.documentation_url, - timestamp=error.timestamp + timestamp=error.timestamp, ) sanitized_errors.append(sanitized_error) - - # Sanitize warnings + + # Sanitize warnings for warning in internal_result.warnings: sanitized_warning = ScanErrorResponse( error_code=warning.error_code, @@ -423,22 +389,22 @@ async def _sanitize_validation_result( can_retry=warning.can_retry, retry_after=warning.retry_after, documentation_url=warning.documentation_url, - timestamp=warning.timestamp + timestamp=warning.timestamp, ) sanitized_warnings.append(sanitized_warning) - + # Sanitize system info sanitized_system_info = sanitize_system_info(internal_result.system_info) - + return ValidationResultResponse( can_proceed=internal_result.can_proceed, errors=sanitized_errors, warnings=sanitized_warnings, pre_flight_duration=internal_result.pre_flight_duration, system_info=sanitized_system_info, - validation_checks=internal_result.validation_checks + validation_checks=internal_result.validation_checks, ) - + def _sanitize_automated_fix(self, fix: AutomatedFix) -> AutomatedFixResponse: """Convert AutomatedFix to sanitized AutomatedFixResponse""" return AutomatedFixResponse( @@ -446,11 +412,11 @@ def _sanitize_automated_fix(self, fix: AutomatedFix) -> AutomatedFixResponse: description=fix.description, requires_sudo=fix.requires_sudo, estimated_time=fix.estimated_time, - is_safe=fix.is_safe + is_safe=fix.is_safe, # Note: command and rollback_command are omitted for security ) def get_unified_validation_service(db: Session) -> UnifiedValidationService: """Factory function to get unified validation service instance""" - return UnifiedValidationService(db) \ No newline at end of file + return UnifiedValidationService(db) diff --git a/backend/app/services/webhook_security.py b/backend/app/services/webhook_security.py index 41f5ea93..9f474faf 100644 --- a/backend/app/services/webhook_security.py +++ b/backend/app/services/webhook_security.py @@ -3,6 +3,7 @@ HMAC-SHA256 signature generation and verification for webhooks Compatible with AEGIS webhook security implementation """ + import hmac import hashlib import json @@ -16,31 +17,31 @@ class WebhookSecurity: """HMAC-SHA256 webhook signature generation and verification""" - + def __init__(self, secret: Optional[str] = None): """ Initialize webhook security with shared secret - + Args: secret: Webhook secret key. If None, uses settings """ self.settings = get_settings() - self.secret = secret or getattr(self.settings, 'webhook_secret', None) - + self.secret = secret or getattr(self.settings, "webhook_secret", None) + if not self.secret and not self.settings.debug: logger.error("Webhook secret not configured") raise ValueError("Webhook secret is required for signature generation") - + def generate_signature(self, payload: Union[Dict[str, Any], str, bytes]) -> str: """ Generate HMAC-SHA256 signature for webhook payload - + Args: payload: Webhook payload (dict, string, or bytes) - + Returns: Signature string in format "sha256=" - + Raises: ValueError: If secret is not configured """ @@ -49,41 +50,35 @@ def generate_signature(self, payload: Union[Dict[str, Any], str, bytes]) -> str: logger.warning("Webhook signature generation skipped - no secret in debug mode") return "sha256=debug-signature" raise ValueError("Webhook secret not configured") - + # Normalize payload to bytes if isinstance(payload, dict): - message = json.dumps(payload, separators=(',', ':'), sort_keys=True).encode('utf-8') + message = json.dumps(payload, separators=(",", ":"), sort_keys=True).encode("utf-8") elif isinstance(payload, str): - message = payload.encode('utf-8') + message = payload.encode("utf-8") elif isinstance(payload, bytes): message = payload else: raise ValueError(f"Unsupported payload type: {type(payload)}") - + # Generate HMAC-SHA256 signature - signature = hmac.new( - self.secret.encode('utf-8'), - message, - hashlib.sha256 - ).hexdigest() - + signature = hmac.new(self.secret.encode("utf-8"), message, hashlib.sha256).hexdigest() + return f"sha256={signature}" - + def verify_signature( - self, - payload: Union[Dict[str, Any], str, bytes], - received_signature: str + self, payload: Union[Dict[str, Any], str, bytes], received_signature: str ) -> bool: """ Verify HMAC-SHA256 signature for webhook payload - + Args: payload: Webhook payload (dict, string, or bytes) received_signature: Signature from webhook header - + Returns: True if signature is valid - + Raises: ValueError: If secret is not configured """ @@ -92,157 +87,140 @@ def verify_signature( logger.warning("Webhook signature verification skipped - no secret in debug mode") return True raise ValueError("Webhook secret not configured") - + try: # Generate expected signature expected_signature = self.generate_signature(payload) - + # Normalize received signature format - if not received_signature.startswith('sha256='): + if not received_signature.startswith("sha256="): received_signature = f"sha256={received_signature}" - + # Use constant-time comparison to prevent timing attacks is_valid = hmac.compare_digest(expected_signature, received_signature) - + if is_valid: logger.debug("Webhook signature verified successfully") else: logger.warning( "Webhook signature verification failed", expected=expected_signature[:16] + "...", - received=received_signature[:16] + "..." + received=received_signature[:16] + "...", ) - + return is_valid - + except Exception as e: - logger.error( - "Error during webhook signature verification", - error=str(e), - exc_info=True - ) + logger.error("Error during webhook signature verification", error=str(e), exc_info=True) return False - + def create_webhook_headers( - self, + self, payload: Union[Dict[str, Any], str, bytes], event_type: Optional[str] = None, - delivery_id: Optional[str] = None + delivery_id: Optional[str] = None, ) -> Dict[str, str]: """ Create HTTP headers for webhook delivery - + Args: payload: Webhook payload event_type: Type of event (e.g., 'scan.completed') delivery_id: Unique delivery ID for tracking - + Returns: Dictionary of headers to include in HTTP request """ signature = self.generate_signature(payload) - + headers = { "Content-Type": "application/json", "User-Agent": "OpenWatch-Webhook/1.0", "X-OpenWatch-Signature": signature, "X-Hub-Signature-256": signature, # GitHub compatible } - + if event_type: headers["X-OpenWatch-Event"] = event_type - + if delivery_id: headers["X-OpenWatch-Delivery"] = delivery_id - + return headers - + def extract_signature_from_headers(self, headers: Dict[str, str]) -> Optional[str]: """ Extract webhook signature from HTTP headers - + Args: headers: HTTP headers dictionary (case-insensitive) - + Returns: Signature string if found, None otherwise """ # Convert headers to lowercase for case-insensitive lookup lower_headers = {k.lower(): v for k, v in headers.items()} - + # Try common webhook signature header names possible_headers = [ "x-openwatch-signature", "x-hub-signature-256", "x-webhook-signature", - "x-signature-256" + "x-signature-256", ] - + for header_name in possible_headers: signature = lower_headers.get(header_name) if signature: logger.debug(f"Found webhook signature in header: {header_name}") return signature - + logger.debug("No webhook signature found in headers") return None - + def create_event_payload( - self, - event_type: str, - data: Dict[str, Any], - timestamp: Optional[str] = None + self, event_type: str, data: Dict[str, Any], timestamp: Optional[str] = None ) -> Dict[str, Any]: """ Create standardized webhook event payload - + Args: event_type: Type of event (e.g., 'scan.completed') data: Event-specific data timestamp: Event timestamp (ISO format) - + Returns: Standardized event payload """ from datetime import datetime - + if not timestamp: timestamp = datetime.utcnow().isoformat() - - return { - "event": event_type, - "timestamp": timestamp, - "data": data - } - + + return {"event": event_type, "timestamp": timestamp, "data": data} + def sign_api_request( - self, - method: str, - url: str, - payload: Optional[Union[Dict[str, Any], str, bytes]] = None + self, method: str, url: str, payload: Optional[Union[Dict[str, Any], str, bytes]] = None ) -> Dict[str, str]: """ Create signature headers for API requests to external services - + Args: method: HTTP method url: Request URL payload: Request payload - + Returns: Headers with signature information """ # For API requests, we might need different signing logic # This is a placeholder for future API request signing headers = {"User-Agent": "OpenWatch-API/1.0"} - + if payload: signature = self.generate_signature(payload) - headers.update({ - "X-OpenWatch-Signature": signature, - "Content-Type": "application/json" - }) - + headers.update({"X-OpenWatch-Signature": signature, "Content-Type": "application/json"}) + return headers @@ -262,27 +240,24 @@ def get_webhook_security() -> WebhookSecurity: def generate_webhook_signature(payload: Union[Dict[str, Any], str, bytes]) -> str: """ Generate signature for webhook payload - + Args: payload: Webhook payload - + Returns: Signature string in format "sha256=" """ return get_webhook_security().generate_signature(payload) -def verify_webhook_signature( - payload: Union[Dict[str, Any], str, bytes], - signature: str -) -> bool: +def verify_webhook_signature(payload: Union[Dict[str, Any], str, bytes], signature: str) -> bool: """ Verify webhook signature - + Args: payload: Webhook payload signature: Signature from webhook header - + Returns: True if signature is valid """ @@ -290,68 +265,52 @@ def verify_webhook_signature( def create_webhook_headers( - payload: Union[Dict[str, Any], str, bytes], - event_type: str, - delivery_id: Optional[str] = None + payload: Union[Dict[str, Any], str, bytes], event_type: str, delivery_id: Optional[str] = None ) -> Dict[str, str]: """ Create headers for webhook delivery - + Args: payload: Webhook payload event_type: Type of event delivery_id: Unique delivery ID - + Returns: Dictionary of headers """ return get_webhook_security().create_webhook_headers(payload, event_type, delivery_id) -def create_scan_completed_payload( - scan_id: str, - scan_data: Dict[str, Any] -) -> Dict[str, Any]: +def create_scan_completed_payload(scan_id: str, scan_data: Dict[str, Any]) -> Dict[str, Any]: """ Create scan.completed webhook payload - + Args: scan_id: Scan identifier scan_data: Scan result data - + Returns: Standardized webhook payload """ return get_webhook_security().create_event_payload( - "scan.completed", - { - "scan_id": scan_id, - **scan_data - } + "scan.completed", {"scan_id": scan_id, **scan_data} ) def create_scan_failed_payload( - scan_id: str, - scan_data: Dict[str, Any], - error_message: str + scan_id: str, scan_data: Dict[str, Any], error_message: str ) -> Dict[str, Any]: """ Create scan.failed webhook payload - + Args: scan_id: Scan identifier scan_data: Scan data error_message: Failure reason - + Returns: Standardized webhook payload """ return get_webhook_security().create_event_payload( - "scan.failed", - { - "scan_id": scan_id, - "error_message": error_message, - **scan_data - } - ) \ No newline at end of file + "scan.failed", {"scan_id": scan_id, "error_message": error_message, **scan_data} + ) diff --git a/backend/app/tasks/__init__.py b/backend/app/tasks/__init__.py index 8d22f84b..2c309a23 100644 --- a/backend/app/tasks/__init__.py +++ b/backend/app/tasks/__init__.py @@ -1 +1 @@ -# OpenWatch Celery Tasks Module \ No newline at end of file +# OpenWatch Celery Tasks Module diff --git a/backend/app/tasks/group_scan_tasks.py b/backend/app/tasks/group_scan_tasks.py index 725c9e56..5ad5bd77 100644 --- a/backend/app/tasks/group_scan_tasks.py +++ b/backend/app/tasks/group_scan_tasks.py @@ -1,6 +1,7 @@ """ Celery tasks for group scan orchestration """ + import logging import asyncio from typing import List, Dict, Any @@ -19,30 +20,30 @@ def execute_group_scan_task(session_id: str, group_id: int, scan_config: dict): Updates progress in real-time via database """ db = SessionLocal() - + try: logger.info(f"Starting group scan task for session {session_id}") - + # Initialize group scan service group_scan_service = GroupScanService(db) - + # Start the group scan execution result = asyncio.run(group_scan_service.start_group_scan_execution(session_id)) - + if result: logger.info(f"Group scan task completed successfully for session {session_id}") else: logger.warning(f"Group scan task had no pending hosts for session {session_id}") - + except Exception as e: logger.error(f"Group scan task failed for session {session_id}: {e}") - + # Update session status to failed try: asyncio.run(_update_session_status_to_failed(db, session_id, str(e))) except Exception as update_error: logger.error(f"Failed to update session status to failed: {update_error}") - + raise finally: db.close() @@ -52,38 +53,48 @@ async def _update_session_status_to_failed(db: Session, session_id: str, error_m """Update group scan session status to failed""" from sqlalchemy import text from datetime import datetime - + try: - db.execute(text(""" + db.execute( + text( + """ UPDATE group_scan_sessions SET status = 'failed', updated_at = :updated_at, completed_at = :completed_at, metadata = COALESCE(metadata, '{}'::jsonb) || :error_metadata::jsonb WHERE session_id = :session_id - """), { - "session_id": session_id, - "updated_at": datetime.utcnow(), - "completed_at": datetime.utcnow(), - "error_metadata": f'{{"error": "{error_message}"}}' - }) - + """ + ), + { + "session_id": session_id, + "updated_at": datetime.utcnow(), + "completed_at": datetime.utcnow(), + "error_metadata": f'{{"error": "{error_message}"}}', + }, + ) + # Also update any pending host statuses to failed - db.execute(text(""" + db.execute( + text( + """ UPDATE group_scan_host_progress SET status = 'failed', error_message = :error_message, updated_at = :updated_at WHERE session_id = :session_id AND status = 'pending' - """), { - "session_id": session_id, - "error_message": f"Group scan failed: {error_message}", - "updated_at": datetime.utcnow() - }) - + """ + ), + { + "session_id": session_id, + "error_message": f"Group scan failed: {error_message}", + "updated_at": datetime.utcnow(), + }, + ) + db.commit() logger.info(f"Updated session {session_id} status to failed") - + except Exception as e: logger.error(f"Failed to update session status: {e}") db.rollback() @@ -92,7 +103,7 @@ async def _update_session_status_to_failed(db: Session, session_id: str, error_m # Celery task wrapper (if Celery is available) try: from celery import current_app - + @current_app.task(bind=True) def execute_group_scan_celery_task(self, session_id: str, group_id: int, scan_config: dict): """Celery task wrapper for group scan execution""" @@ -101,31 +112,37 @@ def execute_group_scan_celery_task(self, session_id: str, group_id: int, scan_co db = SessionLocal() try: from sqlalchemy import text - db.execute(text(""" + + db.execute( + text( + """ UPDATE group_scan_sessions SET metadata = COALESCE(metadata, '{}'::jsonb) || :task_metadata::jsonb WHERE session_id = :session_id - """), { - "session_id": session_id, - "task_metadata": f'{{"celery_task_id": "{self.request.id}"}}' - }) + """ + ), + { + "session_id": session_id, + "task_metadata": f'{{"celery_task_id": "{self.request.id}"}}', + }, + ) db.commit() finally: db.close() - + # Execute group scan execute_group_scan_task(session_id, group_id, scan_config) - + except Exception as e: logger.error(f"Celery group scan task failed for session {session_id}: {e}") - + # Update session with failure db = SessionLocal() try: asyncio.run(_update_session_status_to_failed(db, session_id, str(e))) finally: db.close() - + raise except ImportError: @@ -138,22 +155,22 @@ def monitor_group_scan_progress(session_id: str): Can be used to handle cleanup, notifications, etc. """ db = SessionLocal() - + try: logger.debug(f"Monitoring group scan progress for session {session_id}") - + group_scan_service = GroupScanService(db) progress = asyncio.run(group_scan_service.get_scan_progress(session_id)) - + # Check if scan is complete and send notifications - if progress.status.value in ['completed', 'failed', 'cancelled']: + if progress.status.value in ["completed", "failed", "cancelled"]: logger.info(f"Group scan {session_id} finished with status: {progress.status.value}") - + # Here you could add webhook notifications, email alerts, etc. # For now, just log the completion - + return progress.dict() - + except Exception as e: logger.error(f"Error monitoring group scan progress: {e}") return None @@ -163,6 +180,7 @@ def monitor_group_scan_progress(session_id: str): # Celery task wrapper for monitoring try: + @current_app.task def monitor_group_scan_progress_celery_task(session_id: str): """Celery task wrapper for group scan monitoring""" @@ -170,4 +188,4 @@ def monitor_group_scan_progress_celery_task(session_id: str): except (ImportError, NameError): # Celery not available - pass \ No newline at end of file + pass diff --git a/backend/app/tasks/monitoring_tasks.py b/backend/app/tasks/monitoring_tasks.py index 1a5a9e13..318a47a2 100644 --- a/backend/app/tasks/monitoring_tasks.py +++ b/backend/app/tasks/monitoring_tasks.py @@ -1,6 +1,7 @@ """ Background tasks for host monitoring """ + import logging from celery import Celery from ..database import get_db @@ -8,6 +9,7 @@ logger = logging.getLogger(__name__) + def periodic_host_monitoring(): """ Periodic task to monitor all hosts @@ -15,32 +17,36 @@ def periodic_host_monitoring(): """ try: logger.info("Starting periodic host monitoring...") - + # Get database session db = next(get_db()) - + # Monitor all hosts import asyncio + results = asyncio.run(host_monitor.monitor_all_hosts(db)) - + # Log results - online_count = sum(1 for r in results if r['status'] == 'online') + online_count = sum(1 for r in results if r["status"] == "online") total_count = len(results) - + logger.info(f"Host monitoring completed: {online_count}/{total_count} hosts online") - + # Log any status changes for result in results: - if result.get('error_message'): - logger.warning(f"Host {result['hostname']} ({result['ip_address']}): {result['error_message']}") - + if result.get("error_message"): + logger.warning( + f"Host {result['hostname']} ({result['ip_address']}): {result['error_message']}" + ) + db.close() return f"Monitored {total_count} hosts, {online_count} online" - + except Exception as e: logger.error(f"Error in periodic host monitoring: {e}") return f"Error: {str(e)}" + # Example function to set up periodic monitoring with APScheduler def setup_host_monitoring_scheduler(): """ @@ -50,21 +56,21 @@ def setup_host_monitoring_scheduler(): try: from apscheduler.schedulers.background import BackgroundScheduler import atexit - + scheduler = BackgroundScheduler() - + # Don't auto-start or add jobs here - let restore_scheduler_state() handle it # This allows database configuration to control the scheduler behavior logger.info("Host monitoring scheduler instance created (not started)") - + # Shut down the scheduler when exiting the app atexit.register(lambda: scheduler.shutdown()) - + return scheduler - + except ImportError: logger.warning("APScheduler not available, periodic monitoring disabled") return None except Exception as e: logger.error(f"Failed to setup monitoring scheduler: {e}") - return None \ No newline at end of file + return None diff --git a/backend/app/tasks/scan_tasks.py b/backend/app/tasks/scan_tasks.py index 8bc1477d..8c0728cd 100644 --- a/backend/app/tasks/scan_tasks.py +++ b/backend/app/tasks/scan_tasks.py @@ -1,6 +1,7 @@ """ Celery tasks for SCAP scanning operations """ + import os import json import logging @@ -26,41 +27,45 @@ error_service = ErrorClassificationService() -def execute_scan_task(scan_id: str, host_data: Dict, content_path: str, - profile_id: str, scan_options: Dict): +def execute_scan_task( + scan_id: str, host_data: Dict, content_path: str, profile_id: str, scan_options: Dict +): """ Execute SCAP scan task This is designed to work with or without Celery """ db = SessionLocal() - + try: logger.info(f"Starting scan task: {scan_id}") - + # Update scan status to running - db.execute(text(""" + db.execute( + text( + """ UPDATE scans SET status = 'running', progress = 5 WHERE id = :scan_id - """), {"scan_id": scan_id}) + """ + ), + {"scan_id": scan_id}, + ) db.commit() - + # Check if this is part of a group scan and update group scan progress group_scan_session_id = scan_options.get("session_id") if scan_options else None if group_scan_session_id and scan_options.get("group_scan"): try: from ..services.group_scan_service import GroupScanProgressTracker + progress_tracker = GroupScanProgressTracker(db) progress_tracker.update_host_status( - group_scan_session_id, - host_data.get("id", "unknown"), - "scanning", - scan_id + group_scan_session_id, host_data.get("id", "unknown"), "scanning", scan_id ) logger.debug(f"Updated group scan progress for session {group_scan_session_id}") except Exception as e: logger.error(f"Failed to update group scan progress: {e}") # Don't fail the entire scan for group progress tracking errors - + # Decrypt credentials (handle both formats for compatibility) credentials = {} if host_data["hostname"] == "localhost": @@ -70,51 +75,58 @@ def execute_scan_task(scan_id: str, host_data: Dict, content_path: str, # Use centralized authentication service for all credential resolution try: from ..services.auth_service import get_auth_service + auth_service = get_auth_service(db) - + # Determine if we should use default credentials or host-specific use_default = host_data.get("auth_method") in ["default", "system_default"] target_id = None if use_default else host_data.get("id") - - logger.info(f"Resolving credentials for scan {scan_id}: use_default={use_default}, target_id={target_id}") - + + logger.info( + f"Resolving credentials for scan {scan_id}: use_default={use_default}, target_id={target_id}" + ) + # Resolve credentials using centralized service credential_data = auth_service.resolve_credential( - target_id=target_id, - use_default=use_default + target_id=target_id, use_default=use_default ) - + if not credential_data: logger.error(f"No credentials available for scan {scan_id}") _update_scan_error(db, scan_id, "No credentials available for host") return - + # Convert to format expected by scan tasks credentials = { "username": credential_data.username, "auth_method": credential_data.auth_method.value, "password": credential_data.password, "private_key": credential_data.private_key, # ✅ Consistent field naming - "private_key_passphrase": credential_data.private_key_passphrase + "private_key_passphrase": credential_data.private_key_passphrase, } - + # Update host_data to use resolved credentials host_data["username"] = credential_data.username host_data["auth_method"] = credential_data.auth_method.value - + logger.info(f"✅ Resolved {credential_data.source} credentials for scan {scan_id}") - + except Exception as e: logger.error(f"Failed to resolve credentials for scan {scan_id}: {e}") _update_scan_error(db, scan_id, f"Credential resolution failed: {str(e)}", e) return - + # Update progress - db.execute(text(""" + db.execute( + text( + """ UPDATE scans SET progress = 10 WHERE id = :scan_id - """), {"scan_id": scan_id}) + """ + ), + {"scan_id": scan_id}, + ) db.commit() - + # Extract the appropriate credential based on auth method if host_data["auth_method"] == "password": credential_value = credentials.get("password", "") @@ -122,22 +134,28 @@ def execute_scan_task(scan_id: str, host_data: Dict, content_path: str, credential_value = credentials.get("private_key", "") else: credential_value = credentials.get("credential", "") - + # Check if demo mode is enabled demo_mode = os.getenv("OPENWATCH_DEMO_MODE", "true").lower() == "true" - + if demo_mode: logger.info(f"Demo mode enabled - simulating scan execution for {scan_id}") - + # Simulate scan progression import time + for progress in [20, 40, 60, 80, 95]: - db.execute(text(""" + db.execute( + text( + """ UPDATE scans SET progress = :progress WHERE id = :scan_id - """), {"scan_id": scan_id, "progress": progress}) + """ + ), + {"scan_id": scan_id, "progress": progress}, + ) db.commit() time.sleep(0.5) # Simulate work - + # Create mock results result_data = { "scan_id": scan_id, @@ -151,24 +169,36 @@ def execute_scan_task(scan_id: str, host_data: Dict, content_path: str, "score": 80.0, "scan_time": "2025-08-06T03:45:00", "findings": [ - {"rule_id": "demo_rule_1", "severity": "high", "status": "fail", "description": "Sample security finding"}, - {"rule_id": "demo_rule_2", "severity": "medium", "status": "pass", "description": "Security control passed"} - ] + { + "rule_id": "demo_rule_1", + "severity": "high", + "status": "fail", + "description": "Sample security finding", + }, + { + "rule_id": "demo_rule_2", + "severity": "medium", + "status": "pass", + "description": "Security control passed", + }, + ], } - + # Update scan as completed - db.execute(text(""" + db.execute( + text( + """ UPDATE scans SET status = 'completed', progress = 100, completed_at = :completed_at WHERE id = :scan_id - """), { - "scan_id": scan_id, - "completed_at": datetime.utcnow() - }) + """ + ), + {"scan_id": scan_id, "completed_at": datetime.utcnow()}, + ) db.commit() - + # Send webhook notification for demo completion try: webhook_data = { @@ -179,23 +209,23 @@ def execute_scan_task(scan_id: str, host_data: Dict, content_path: str, "passed_rules": 120, "failed_rules": 25, "score": 80.0, - "completed_at": datetime.utcnow().isoformat() + "completed_at": datetime.utcnow().isoformat(), } - - asyncio.create_task( - send_scan_completed_webhook(scan_id, webhook_data) - ) + + asyncio.create_task(send_scan_completed_webhook(scan_id, webhook_data)) logger.debug(f"Webhook notification queued for demo scan: {scan_id}") except Exception as webhook_error: - logger.error(f"Failed to send demo completion webhook for scan {scan_id}: {webhook_error}") - + logger.error( + f"Failed to send demo completion webhook for scan {scan_id}: {webhook_error}" + ) + logger.info(f"Demo scan completed successfully: {scan_id}") return - + # Production mode - Test SSH connection first if host_data["hostname"] != "localhost": logger.info(f"Testing SSH connection for scan {scan_id}") - + # Extract the appropriate credential based on auth method if host_data["auth_method"] == "password": credential_value = credentials.get("password", "") @@ -203,55 +233,67 @@ def execute_scan_task(scan_id: str, host_data: Dict, content_path: str, credential_value = credentials.get("private_key", "") else: credential_value = credentials.get("credential", "") - + ssh_test = scap_scanner.test_ssh_connection( hostname=host_data["hostname"], port=host_data["port"], username=host_data["username"], auth_method=host_data["auth_method"], - credential=credential_value + credential=credential_value, ) - + if not ssh_test["success"]: logger.error(f"SSH connection failed for scan {scan_id}: {ssh_test['message']}") # Create a synthetic exception for SSH failure ssh_error = Exception(f"SSH connection failed: {ssh_test['message']}") - _update_scan_error(db, scan_id, f"SSH connection failed: {ssh_test['message']}", ssh_error) + _update_scan_error( + db, scan_id, f"SSH connection failed: {ssh_test['message']}", ssh_error + ) return - + if not ssh_test.get("oscap_available", False): logger.warning(f"OpenSCAP not available on remote host for scan {scan_id}") # Create a synthetic exception for missing dependency dep_error = Exception("OpenSCAP not available on remote host") _update_scan_error(db, scan_id, "OpenSCAP not available on remote host", dep_error) return - + # Update progress - db.execute(text(""" + db.execute( + text( + """ UPDATE scans SET progress = 20 WHERE id = :scan_id - """), {"scan_id": scan_id}) + """ + ), + {"scan_id": scan_id}, + ) db.commit() - + # Execute scan logger.info(f"Executing SCAP scan: {scan_id}") - + try: # Update progress to indicate scan execution has started - db.execute(text(""" + db.execute( + text( + """ UPDATE scans SET progress = 30 WHERE id = :scan_id - """), {"scan_id": scan_id}) + """ + ), + {"scan_id": scan_id}, + ) db.commit() - + # Extract rule_id from scan_options for rule-specific rescans rule_id = scan_options.get("rule_id") if scan_options else None - + if host_data["hostname"] == "localhost": # Local scan scan_results = scap_scanner.execute_local_scan( content_path=content_path, profile_id=profile_id, scan_id=scan_id, - rule_id=rule_id + rule_id=rule_id, ) else: # Remote scan @@ -264,15 +306,20 @@ def execute_scan_task(scan_id: str, host_data: Dict, content_path: str, content_path=content_path, profile_id=profile_id, scan_id=scan_id, - rule_id=rule_id + rule_id=rule_id, ) - + # Update progress after scan execution - db.execute(text(""" + db.execute( + text( + """ UPDATE scans SET progress = 90 WHERE id = :scan_id - """), {"scan_id": scan_id}) + """ + ), + {"scan_id": scan_id}, + ) db.commit() - + # Check for scan errors if "error" in scan_results: logger.error(f"Scan execution failed for {scan_id}: {scan_results['error']}") @@ -283,55 +330,65 @@ def execute_scan_task(scan_id: str, host_data: Dict, content_path: str, logger.error(f"Scan execution failed for {scan_id}: {str(e)}", exc_info=True) _update_scan_error(db, scan_id, f"Scan execution error: {str(e)}", e) return - + # Update scan record with results - db.execute(text(""" + db.execute( + text( + """ UPDATE scans SET status = 'completed', progress = 100, completed_at = :completed_at, result_file = :result_file, report_file = :report_file WHERE id = :scan_id - """), { - "scan_id": scan_id, - "completed_at": datetime.utcnow(), - "result_file": scan_results.get("xml_result"), - "report_file": scan_results.get("html_report") - }) + """ + ), + { + "scan_id": scan_id, + "completed_at": datetime.utcnow(), + "result_file": scan_results.get("xml_result"), + "report_file": scan_results.get("html_report"), + }, + ) db.commit() - + # Save scan results summary _save_scan_results(db, scan_id, scan_results) - + # Update group scan progress if this is part of a group scan if group_scan_session_id and scan_options.get("group_scan"): try: from ..services.group_scan_service import GroupScanProgressTracker + progress_tracker = GroupScanProgressTracker(db) - + # Get scan result ID for linking scan_result_id = None - result_query = db.execute(text("SELECT id FROM scan_results WHERE scan_id = :scan_id"), - {"scan_id": scan_id}).fetchone() + result_query = db.execute( + text("SELECT id FROM scan_results WHERE scan_id = :scan_id"), + {"scan_id": scan_id}, + ).fetchone() if result_query: scan_result_id = str(result_query.id) - + progress_tracker.update_host_status( - group_scan_session_id, - host_data.get("id", "unknown"), - "completed", + group_scan_session_id, + host_data.get("id", "unknown"), + "completed", scan_id, - scan_result_id + scan_result_id, + ) + logger.debug( + f"Updated group scan progress to completed for session {group_scan_session_id}" ) - logger.debug(f"Updated group scan progress to completed for session {group_scan_session_id}") except Exception as e: logger.error(f"Failed to update group scan completion progress: {e}") - + # Process scan with semantic intelligence try: asyncio.run(_process_semantic_intelligence(db, scan_id, scan_results, host_data)) except Exception as e: logger.error(f"Semantic intelligence processing failed for {scan_id}: {e}") # Continue with normal flow - don't break existing functionality - + # Send webhook notification for successful completion try: # Prepare scan data for webhook @@ -343,19 +400,17 @@ def execute_scan_task(scan_id: str, host_data: Dict, content_path: str, "passed_rules": scan_results.get("rules_passed", 0), "failed_rules": scan_results.get("rules_failed", 0), "score": scan_results.get("score", 0), - "completed_at": datetime.utcnow().isoformat() + "completed_at": datetime.utcnow().isoformat(), } - + # Send webhook asynchronously - asyncio.create_task( - send_scan_completed_webhook(scan_id, webhook_data) - ) + asyncio.create_task(send_scan_completed_webhook(scan_id, webhook_data)) logger.debug(f"Webhook notification queued for completed scan: {scan_id}") except Exception as webhook_error: logger.error(f"Failed to send completion webhook for scan {scan_id}: {webhook_error}") - + logger.info(f"Scan completed successfully: {scan_id}") - + except ScanExecutionError as e: logger.error(f"Scan execution error for {scan_id}: {e}") _update_scan_error(db, scan_id, str(e), e) @@ -366,7 +421,9 @@ def execute_scan_task(scan_id: str, host_data: Dict, content_path: str, db.close() -def _update_scan_error(db: Session, scan_id: str, error_message: str, original_exception: Exception = None): +def _update_scan_error( + db: Session, scan_id: str, error_message: str, original_exception: Exception = None +): """Update scan with error status and set progress to 100% to indicate completion""" try: # Classify error if original exception provided @@ -374,57 +431,76 @@ def _update_scan_error(db: Session, scan_id: str, error_message: str, original_e if original_exception: try: import asyncio - classified_error = asyncio.run(error_service.classify_error(original_exception, {"scan_id": scan_id})) + + classified_error = asyncio.run( + error_service.classify_error(original_exception, {"scan_id": scan_id}) + ) # Use classified error message if available if classified_error: - error_message = f"{classified_error.message} (Code: {classified_error.error_code})" - logger.info(f"Error classified for scan {scan_id}: {classified_error.category.value} - {classified_error.error_code}") + error_message = ( + f"{classified_error.message} (Code: {classified_error.error_code})" + ) + logger.info( + f"Error classified for scan {scan_id}: {classified_error.category.value} - {classified_error.error_code}" + ) except Exception as e: logger.warning(f"Failed to classify error for scan {scan_id}: {e}") # Continue with original error message - + # Get scan data for webhook notification and check for group scan - scan_result = db.execute(text(""" + scan_result = db.execute( + text( + """ SELECT s.id, h.hostname, s.profile_id, s.scan_options, s.host_id FROM scans s JOIN hosts h ON s.host_id = h.id WHERE s.id = :scan_id - """), {"scan_id": scan_id}) - + """ + ), + {"scan_id": scan_id}, + ) + scan_data = scan_result.fetchone() - + # Check if this is part of a group scan and update progress if scan_data and scan_data.scan_options: try: import json + scan_options = json.loads(scan_data.scan_options) group_scan_session_id = scan_options.get("session_id") - + if group_scan_session_id and scan_options.get("group_scan"): from ..services.group_scan_service import GroupScanProgressTracker + progress_tracker = GroupScanProgressTracker(db) - asyncio.run(progress_tracker.update_host_status( - group_scan_session_id, - str(scan_data.host_id), - "failed", - scan_id, - error_message=error_message - )) - logger.debug(f"Updated group scan progress to failed for session {group_scan_session_id}") + asyncio.run( + progress_tracker.update_host_status( + group_scan_session_id, + str(scan_data.host_id), + "failed", + scan_id, + error_message=error_message, + ) + ) + logger.debug( + f"Updated group scan progress to failed for session {group_scan_session_id}" + ) except Exception as e: logger.error(f"Failed to update group scan failure progress: {e}") - - db.execute(text(""" + + db.execute( + text( + """ UPDATE scans SET status = 'failed', progress = 100, completed_at = :completed_at, error_message = :error_message WHERE id = :scan_id - """), { - "scan_id": scan_id, - "completed_at": datetime.utcnow(), - "error_message": error_message - }) + """ + ), + {"scan_id": scan_id, "completed_at": datetime.utcnow(), "error_message": error_message}, + ) db.commit() - + # Send webhook notification for scan failure if scan_data: try: @@ -432,16 +508,14 @@ def _update_scan_error(db: Session, scan_id: str, error_message: str, original_e "hostname": scan_data.hostname, "profile_id": scan_data.profile_id, "status": "failed", - "completed_at": datetime.utcnow().isoformat() + "completed_at": datetime.utcnow().isoformat(), } - - asyncio.create_task( - send_scan_failed_webhook(scan_id, webhook_data, error_message) - ) + + asyncio.create_task(send_scan_failed_webhook(scan_id, webhook_data, error_message)) logger.debug(f"Webhook notification queued for failed scan: {scan_id}") except Exception as webhook_error: logger.error(f"Failed to send failure webhook for scan {scan_id}: {webhook_error}") - + except Exception as e: logger.error(f"Failed to update scan error status: {e}") @@ -454,9 +528,11 @@ def _save_scan_results(db: Session, scan_id: str, scan_results: Dict): severity_high = len([r for r in failed_rules if r.get("severity") == "high"]) severity_medium = len([r for r in failed_rules if r.get("severity") == "medium"]) severity_low = len([r for r in failed_rules if r.get("severity") == "low"]) - + # Insert scan results - db.execute(text(""" + db.execute( + text( + """ INSERT INTO scan_results (scan_id, total_rules, passed_rules, failed_rules, error_rules, unknown_rules, not_applicable_rules, score, severity_high, @@ -464,24 +540,27 @@ def _save_scan_results(db: Session, scan_id: str, scan_results: Dict): VALUES (:scan_id, :total_rules, :passed_rules, :failed_rules, :error_rules, :unknown_rules, :not_applicable_rules, :score, :severity_high, :severity_medium, :severity_low, :created_at) - """), { - "scan_id": scan_id, - "total_rules": scan_results.get("rules_total", 0), - "passed_rules": scan_results.get("rules_passed", 0), - "failed_rules": scan_results.get("rules_failed", 0), - "error_rules": scan_results.get("rules_error", 0), - "unknown_rules": scan_results.get("rules_unknown", 0), - "not_applicable_rules": scan_results.get("rules_notapplicable", 0), - "score": f"{scan_results.get('score', 0):.1f}%", - "severity_high": severity_high, - "severity_medium": severity_medium, - "severity_low": severity_low, - "created_at": datetime.utcnow() - }) + """ + ), + { + "scan_id": scan_id, + "total_rules": scan_results.get("rules_total", 0), + "passed_rules": scan_results.get("rules_passed", 0), + "failed_rules": scan_results.get("rules_failed", 0), + "error_rules": scan_results.get("rules_error", 0), + "unknown_rules": scan_results.get("rules_unknown", 0), + "not_applicable_rules": scan_results.get("rules_notapplicable", 0), + "score": f"{scan_results.get('score', 0):.1f}%", + "severity_high": severity_high, + "severity_medium": severity_medium, + "severity_low": severity_low, + "created_at": datetime.utcnow(), + }, + ) db.commit() - + logger.info(f"Scan results saved for {scan_id}") - + except Exception as e: logger.error(f"Failed to save scan results for {scan_id}: {e}") @@ -489,23 +568,29 @@ def _save_scan_results(db: Session, scan_id: str, scan_results: Dict): # Celery task wrapper (if Celery is available) try: from celery import current_app - + @current_app.task(bind=True) - def execute_scan_celery_task(self, scan_id: str, host_data: Dict, content_path: str, - profile_id: str, scan_options: Dict): + def execute_scan_celery_task( + self, scan_id: str, host_data: Dict, content_path: str, profile_id: str, scan_options: Dict + ): """Celery task wrapper for scan execution""" try: # Update task ID in database db = SessionLocal() - db.execute(text(""" + db.execute( + text( + """ UPDATE scans SET celery_task_id = :task_id WHERE id = :scan_id - """), {"task_id": self.request.id, "scan_id": scan_id}) + """ + ), + {"task_id": self.request.id, "scan_id": scan_id}, + ) db.commit() db.close() - + # Execute scan execute_scan_task(scan_id, host_data, content_path, profile_id, scan_options) - + except Exception as e: logger.error(f"Celery task failed for scan {scan_id}: {e}") # Update scan with failure @@ -519,19 +604,16 @@ def execute_scan_celery_task(self, scan_id: str, host_data: Dict, content_path: async def _process_semantic_intelligence( - db: Session, - scan_id: str, - scan_results: Dict[str, Any], - host_data: Dict[str, Any] + db: Session, scan_id: str, scan_results: Dict[str, Any], host_data: Dict[str, Any] ): """Process scan results with semantic intelligence""" - + try: logger.info(f"Starting semantic intelligence processing for scan: {scan_id}") - + # Get semantic SCAP engine semantic_engine = get_semantic_scap_engine() - + # Build host information for semantic processing host_info = { "host_id": host_data.get("host_id"), @@ -540,77 +622,84 @@ async def _process_semantic_intelligence( "distribution_version": host_data.get("distribution_version"), "os_version": host_data.get("os_version", ""), "package_manager": host_data.get("package_manager"), - "service_manager": host_data.get("service_manager") + "service_manager": host_data.get("service_manager"), } - + # Process scan with semantic intelligence intelligent_result = await semantic_engine.process_scan_with_intelligence( - scan_results=scan_results, - scan_id=scan_id, - host_info=host_info + scan_results=scan_results, scan_id=scan_id, host_info=host_info ) - + # Update scan record with semantic analysis information frameworks_analyzed = list(intelligent_result.framework_compliance_matrix.keys()) semantic_rules_count = len(intelligent_result.semantic_rules) - - db.execute(text(""" + + db.execute( + text( + """ UPDATE scans SET semantic_analysis_completed = true, semantic_rules_count = :semantic_rules_count, frameworks_analyzed = :frameworks_analyzed, remediation_strategy = :remediation_strategy WHERE id = :scan_id - """), { - "scan_id": scan_id, - "semantic_rules_count": semantic_rules_count, - "frameworks_analyzed": frameworks_analyzed, - "remediation_strategy": json.dumps(intelligent_result.remediation_strategy) - }) + """ + ), + { + "scan_id": scan_id, + "semantic_rules_count": semantic_rules_count, + "frameworks_analyzed": frameworks_analyzed, + "remediation_strategy": json.dumps(intelligent_result.remediation_strategy), + }, + ) db.commit() - + # Send enhanced webhook with semantic intelligence await _send_enhanced_semantic_webhook(scan_id, intelligent_result, host_data) - - logger.info(f"Semantic intelligence processing completed for scan {scan_id}: " - f"{semantic_rules_count} semantic rules, " - f"{len(frameworks_analyzed)} frameworks analyzed") - + + logger.info( + f"Semantic intelligence processing completed for scan {scan_id}: " + f"{semantic_rules_count} semantic rules, " + f"{len(frameworks_analyzed)} frameworks analyzed" + ) + except Exception as e: logger.error(f"Error in semantic intelligence processing: {e}", exc_info=True) # Don't re-raise - we want to continue with normal scan processing async def _send_enhanced_semantic_webhook( - scan_id: str, - intelligent_result: 'IntelligentScanResult', - host_data: Dict[str, Any] + scan_id: str, intelligent_result: "IntelligentScanResult", host_data: Dict[str, Any] ): """Send enhanced webhook with semantic intelligence data""" - + try: from .webhook_tasks import deliver_webhook - + # Get active webhook endpoints for semantic events db = SessionLocal() try: - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT id, url, secret_hash FROM webhook_endpoints WHERE is_active = true AND ( event_types::jsonb ? 'semantic.analysis.completed' OR event_types::jsonb ? 'scan.completed' ) - """)) - + """ + ) + ) + webhooks = result.fetchall() finally: db.close() - + if not webhooks: logger.debug("No active webhooks configured for semantic events") return - + # Create enhanced webhook payload with semantic intelligence webhook_data = { "event": "semantic.analysis.completed", @@ -624,11 +713,13 @@ async def _send_enhanced_semantic_webhook( "distribution_name": host_data.get("distribution_name", "unknown"), "distribution_version": host_data.get("distribution_version", "unknown"), "package_manager": host_data.get("package_manager", "unknown"), - "service_manager": host_data.get("service_manager", "unknown") + "service_manager": host_data.get("service_manager", "unknown"), }, "semantic_analysis": { "semantic_rules_count": len(intelligent_result.semantic_rules), - "frameworks_analyzed": list(intelligent_result.framework_compliance_matrix.keys()), + "frameworks_analyzed": list( + intelligent_result.framework_compliance_matrix.keys() + ), "framework_compliance_matrix": intelligent_result.framework_compliance_matrix, "remediation_strategy": intelligent_result.remediation_strategy, "semantic_rules": [ @@ -642,87 +733,86 @@ async def _send_enhanced_semantic_webhook( "frameworks": rule.frameworks, "remediation_complexity": rule.remediation_complexity, "estimated_fix_time": rule.estimated_fix_time, - "remediation_available": rule.remediation_available + "remediation_available": rule.remediation_available, } - for rule in intelligent_result.semantic_rules[:10] # Limit to avoid large payloads - ] + for rule in intelligent_result.semantic_rules[ + :10 + ] # Limit to avoid large payloads + ], }, "original_scan_results": { "total_rules": intelligent_result.original_results.get("rules_total", 0), "passed_rules": intelligent_result.original_results.get("rules_passed", 0), "failed_rules": intelligent_result.original_results.get("rules_failed", 0), - "score": intelligent_result.original_results.get("score", 0) - } - } + "score": intelligent_result.original_results.get("score", 0), + }, + }, } - + # Send to all configured endpoints for webhook in webhooks: try: await deliver_webhook( - webhook.url, - webhook.secret_hash, - webhook_data, - str(webhook.id) + webhook.url, webhook.secret_hash, webhook_data, str(webhook.id) ) except Exception as e: logger.error(f"Failed to deliver semantic webhook to {webhook.url}: {e}") - + except Exception as e: logger.error(f"Error sending enhanced semantic webhook: {e}") def _extract_host_distribution_info(host_data: Dict[str, Any]) -> Dict[str, str]: """Extract and normalize host distribution information""" - + # Try to extract distribution info from various sources hostname = host_data.get("hostname", "") os_version = host_data.get("os_version", "").lower() - + # Default values distribution_info = { "distribution_family": "unknown", - "distribution_name": "unknown", + "distribution_name": "unknown", "distribution_version": "unknown", "package_manager": "unknown", - "service_manager": "systemd" # Most modern systems use systemd + "service_manager": "systemd", # Most modern systems use systemd } - + # Detect from OS version string if "rhel" in os_version or "red hat" in os_version: distribution_info["distribution_family"] = "redhat" distribution_info["distribution_name"] = "rhel" distribution_info["package_manager"] = "dnf" - + # Extract version - version_match = re.search(r'(\d+)', os_version) + version_match = re.search(r"(\d+)", os_version) if version_match: distribution_info["distribution_version"] = version_match.group(1) - + elif "ubuntu" in os_version: distribution_info["distribution_family"] = "debian" distribution_info["distribution_name"] = "ubuntu" distribution_info["package_manager"] = "apt" - + # Extract version - version_match = re.search(r'(\d+\.\d+)', os_version) + version_match = re.search(r"(\d+\.\d+)", os_version) if version_match: distribution_info["distribution_version"] = version_match.group(1) - + elif "centos" in os_version: distribution_info["distribution_family"] = "redhat" distribution_info["distribution_name"] = "centos" distribution_info["package_manager"] = "yum" - + elif "oracle" in os_version: distribution_info["distribution_family"] = "redhat" distribution_info["distribution_name"] = "oracle" distribution_info["package_manager"] = "dnf" - + # Update host_data with distribution info (non-destructive) for key, value in distribution_info.items(): if key not in host_data or not host_data[key]: host_data[key] = value - + return distribution_info - pass \ No newline at end of file + pass diff --git a/backend/app/tasks/webhook_tasks.py b/backend/app/tasks/webhook_tasks.py index 8c7aa3df..03f74098 100644 --- a/backend/app/tasks/webhook_tasks.py +++ b/backend/app/tasks/webhook_tasks.py @@ -2,6 +2,7 @@ Webhook Delivery Tasks Background tasks for delivering webhooks to AEGIS and other integrations """ + import json import uuid import hashlib @@ -16,9 +17,9 @@ from ..database import get_db from ..services.http_client import get_webhook_client from ..services.webhook_security import ( - create_webhook_headers, + create_webhook_headers, create_scan_completed_payload, - create_scan_failed_payload + create_scan_failed_payload, ) from ..services.integration_metrics import record_webhook_delivery, time_webhook_delivery @@ -26,149 +27,153 @@ async def deliver_webhook( - url: str, - secret_hash: str, - event_data: Dict[str, Any], - webhook_id: str, - max_retries: int = 3 + url: str, secret_hash: str, event_data: Dict[str, Any], webhook_id: str, max_retries: int = 3 ) -> bool: """ Deliver webhook to endpoint with signature verification - + Args: url: Target webhook URL secret_hash: Hashed webhook secret for signature generation event_data: Event payload to send webhook_id: Webhook endpoint ID for tracking max_retries: Maximum retry attempts - + Returns: bool: True if delivery successful, False otherwise """ # Create delivery record delivery_id = str(uuid.uuid4()) - + try: db = next(get_db()) try: - db.execute(text(""" + db.execute( + text( + """ INSERT INTO webhook_deliveries (id, webhook_id, event_type, event_data, delivery_status, created_at) VALUES (:id, :webhook_id, :event_type, :event_data, :delivery_status, :created_at) - """), { - "id": delivery_id, - "webhook_id": webhook_id, - "event_type": event_data.get("event_type", "unknown"), - "event_data": json.dumps(event_data), - "delivery_status": "pending", - "created_at": datetime.utcnow() - }) + """ + ), + { + "id": delivery_id, + "webhook_id": webhook_id, + "event_type": event_data.get("event_type", "unknown"), + "event_data": json.dumps(event_data), + "delivery_status": "pending", + "created_at": datetime.utcnow(), + }, + ) db.commit() finally: db.close() except Exception as e: - logger.error("Failed to create webhook delivery record", error=str(e), webhook_id=webhook_id) + logger.error( + "Failed to create webhook delivery record", error=str(e), webhook_id=webhook_id + ) return False - + # Create webhook headers with signature headers = create_webhook_headers( - event_data, - event_data.get("event_type", "unknown"), - delivery_id + event_data, event_data.get("event_type", "unknown"), delivery_id ) - + # Get webhook client webhook_client = await get_webhook_client() - + # Time the webhook delivery operation start_time = time.time() success = False error_msg = None - + # Attempt delivery using enhanced HTTP client (it has built-in retries) try: - response = await webhook_client.deliver_webhook( - url, - event_data, - headers - ) - + response = await webhook_client.deliver_webhook(url, event_data, headers) + success = True duration = time.time() - start_time - + # Record successful delivery metrics record_webhook_delivery( success=True, duration=duration, - target_service=url.split('/')[2], # Extract domain from URL - error=None + target_service=url.split("/")[2], # Extract domain from URL + error=None, ) - + # Update delivery record with success db = next(get_db()) try: - db.execute(text(""" + db.execute( + text( + """ UPDATE webhook_deliveries SET delivery_status = 'delivered', http_status_code = :status_code, response_body = :response_body, delivered_at = :delivered_at WHERE id = :id - """), { - "id": delivery_id, - "status_code": response.status_code, - "response_body": response.text[:1000], # Truncate long responses - "delivered_at": datetime.utcnow() - }) + """ + ), + { + "id": delivery_id, + "status_code": response.status_code, + "response_body": response.text[:1000], # Truncate long responses + "delivered_at": datetime.utcnow(), + }, + ) db.commit() finally: db.close() - + logger.info( "Webhook delivered successfully", webhook_id=webhook_id, delivery_id=delivery_id, status_code=response.status_code, - duration_ms=round(duration * 1000, 2) + duration_ms=round(duration * 1000, 2), ) return True - + except Exception as e: # Record failed delivery metrics error_msg = str(e) duration = time.time() - start_time - + record_webhook_delivery( success=False, duration=duration, - target_service=url.split('/')[2], # Extract domain from URL - error=error_msg + target_service=url.split("/")[2], # Extract domain from URL + error=error_msg, ) - + # Update delivery record with failure db = next(get_db()) try: - db.execute(text(""" + db.execute( + text( + """ UPDATE webhook_deliveries SET delivery_status = 'failed', error_message = :error_message WHERE id = :id - """), { - "id": delivery_id, - "error_message": error_msg - }) + """ + ), + {"id": delivery_id, "error_message": error_msg}, + ) db.commit() finally: db.close() - + logger.error( "Webhook delivery failed", webhook_id=webhook_id, delivery_id=delivery_id, error=error_msg, - duration_ms=round(duration * 1000, 2) + duration_ms=round(duration * 1000, 2), ) - + return False @@ -178,39 +183,38 @@ async def send_scan_completed_webhook(scan_id: str, scan_data: Dict[str, Any]): # Get active webhook endpoints that listen for scan.completed events db = next(get_db()) try: - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT id, url, secret_hash FROM webhook_endpoints WHERE is_active = true AND event_types::jsonb ? 'scan.completed' - """)) - + """ + ) + ) + webhooks = result.fetchall() finally: db.close() - + if not webhooks: logger.info("No active webhooks configured for scan.completed events") return - + # Create standardized event payload event_data = create_scan_completed_payload(scan_id, scan_data) - + # Send to all registered endpoints for webhook in webhooks: try: - await deliver_webhook( - webhook.url, - webhook.secret_hash, - event_data, - str(webhook.id) - ) + await deliver_webhook(webhook.url, webhook.secret_hash, event_data, str(webhook.id)) except Exception as e: logger.error( "Failed to deliver scan.completed webhook", webhook_id=str(webhook.id), - error=str(e) + error=str(e), ) - + except Exception as e: logger.error("Failed to process scan.completed webhooks", scan_id=scan_id, error=str(e)) @@ -221,38 +225,37 @@ async def send_scan_failed_webhook(scan_id: str, scan_data: Dict[str, Any], erro # Get active webhook endpoints that listen for scan.failed events db = next(get_db()) try: - result = db.execute(text(""" + result = db.execute( + text( + """ SELECT id, url, secret_hash FROM webhook_endpoints WHERE is_active = true AND event_types::jsonb ? 'scan.failed' - """)) - + """ + ) + ) + webhooks = result.fetchall() finally: db.close() - + if not webhooks: logger.info("No active webhooks configured for scan.failed events") return - + # Create standardized event payload event_data = create_scan_failed_payload(scan_id, scan_data, error_message) - + # Send to all registered endpoints for webhook in webhooks: try: - await deliver_webhook( - webhook.url, - webhook.secret_hash, - event_data, - str(webhook.id) - ) + await deliver_webhook(webhook.url, webhook.secret_hash, event_data, str(webhook.id)) except Exception as e: logger.error( "Failed to deliver scan.failed webhook", webhook_id=str(webhook.id), - error=str(e) + error=str(e), ) - + except Exception as e: - logger.error("Failed to process scan.failed webhooks", scan_id=scan_id, error=str(e)) \ No newline at end of file + logger.error("Failed to process scan.failed webhooks", scan_id=scan_id, error=str(e)) diff --git a/docs/security/SECURITY.md b/docs/security/SECURITY.md new file mode 100644 index 00000000..84e7a2b0 --- /dev/null +++ b/docs/security/SECURITY.md @@ -0,0 +1,193 @@ +# OpenWatch Security Architecture + +## Overview + +OpenWatch implements a comprehensive FIPS 140-2 compliant security architecture for secure OpenSCAP scanning operations. This document outlines the security controls, cryptographic implementations, and compliance measures. + +## FIPS 140-2 Compliance + +### Cryptographic Modules +- **AES-256-GCM**: Data encryption at rest and in transit +- **RSA-2048**: Digital signatures for JWT tokens +- **SHA-256**: Hash functions and key derivation +- **Argon2id**: Password hashing (FIPS approved) +- **PBKDF2**: Key derivation with 100,000 iterations + +### Validation Status +- OpenSSL FIPS module validation required for production +- Cryptographic operations use only FIPS-approved algorithms +- Runtime FIPS mode validation on application startup + +## Security Architecture + +### Transport Security +``` +┌─────────────────────────────────────────────────────────────┐ +│ TLS 1.3 Layer │ +│ ┌─────────────────────────────────────────────────────────┤ +│ │ Application Security Layer │ +│ │ ┌─────────────────────────────────────────────────────┤ +│ │ │ Data Security Layer │ +│ │ │ ┌─────────────────────────────────────────────────┤ +│ │ │ │ Infrastructure Security │ +└──┴──┴──┴─────────────────────────────────────────────────────┘ +``` + +### Authentication Flow +1. **User Authentication**: RSA-2048 signed JWT tokens +2. **Service Authentication**: Mutual TLS between services +3. **SSH Authentication**: Encrypted private keys for remote scans +4. **Database Authentication**: SCRAM-SHA-256 with TLS + +### Authorization Model +- **Role-Based Access Control (RBAC)** + - `admin`: Full system access + - `user`: Limited scan operations +- **Resource-Level Permissions**: Host and scan access controls +- **API Endpoint Protection**: JWT token validation required + +## Data Protection + +### Encryption at Rest +- **Database**: PostgreSQL with TDE (Transparent Data Encryption) +- **Credentials**: AES-256-GCM encryption for SSH keys and passwords +- **Files**: SCAP content and results encrypted on disk +- **Logs**: Audit logs with integrity protection + +### Encryption in Transit +- **HTTPS/TLS 1.3**: All client communications +- **Database TLS**: Encrypted PostgreSQL connections +- **Redis TLS**: Secure Celery message passing +- **SSH**: OpenSCAP remote scanning operations + +### Key Management +- **Master Key**: Environment-based encryption key +- **JWT Keys**: RSA-2048 key pair for token signing +- **TLS Certificates**: X.509 certificates for service communication +- **SSH Keys**: Per-host encrypted private keys + +## Network Security + +### Network Segmentation +``` +Internet ┌──────────────────────────────────────────────────────┐ + ↓ │ Load Balancer │ +┌─────────┼──────────────────────────────────────────────────────┤ +│ DMZ │ Frontend (HTTPS) │ +├─────────┼──────────────────────────────────────────────────────┤ +│ App Tier│ Backend API (mTLS) │ +├─────────┼──────────────────────────────────────────────────────┤ +│Data Tier│ Database + Redis (Encrypted) │ +└─────────┴──────────────────────────────────────────────────────┘ +``` + +### Security Headers +- **HSTS**: HTTP Strict Transport Security +- **CSP**: Content Security Policy +- **X-Frame-Options**: Clickjacking protection +- **X-Content-Type-Options**: MIME sniffing protection + +## Audit and Compliance + +### Security Logging +- **Authentication Events**: Login attempts and failures +- **Authorization Events**: Access control decisions +- **Scan Operations**: All OpenSCAP operations logged +- **System Events**: Configuration changes and errors + +### Audit Trail +- **Tamper Evident**: Cryptographic integrity protection +- **Non-Repudiation**: Digital signatures on critical events +- **Retention**: Configurable log retention periods +- **Export**: SIEM integration capabilities + +### Compliance Reporting +- **FIPS Validation**: Real-time compliance status +- **Security Metrics**: Authentication and authorization metrics +- **Vulnerability Scanning**: Regular security assessments +- **Penetration Testing**: Periodic security validation + +## Secure Development + +### Security Testing +- **Static Analysis**: Bandit security linting +- **Dependency Scanning**: Safety vulnerability checks +- **Secret Detection**: Pre-commit hook scanning +- **Dynamic Analysis**: Runtime security testing + +### Code Security +- **Input Validation**: All user inputs sanitized +- **SQL Injection**: Parameterized queries only +- **XSS Protection**: Output encoding and CSP +- **CSRF Protection**: Token-based protection + +## Operational Security + +### Container Security +- **Base Images**: FIPS-compliant Red Hat UBI +- **Vulnerability Scanning**: Regular image updates +- **Runtime Security**: Non-root container execution +- **Resource Limits**: CPU and memory constraints + +### Infrastructure Security +- **Secrets Management**: Environment-based configuration +- **Access Control**: Principle of least privilege +- **Network Policies**: Kubernetes network segmentation +- **Monitoring**: Security event monitoring + +## Incident Response + +### Security Monitoring +- **Failed Authentication**: Account lockout after 5 attempts +- **Unusual Activity**: Anomaly detection and alerting +- **System Health**: Continuous security posture monitoring +- **Threat Detection**: Real-time security event analysis + +### Response Procedures +1. **Detection**: Automated security event detection +2. **Analysis**: Security team investigation +3. **Containment**: Threat isolation and mitigation +4. **Eradication**: Root cause remediation +5. **Recovery**: Service restoration +6. **Lessons Learned**: Post-incident review + +## Configuration Management + +### Security Baselines +- **CIS Benchmarks**: Container and OS hardening +- **NIST Guidelines**: Security control implementation +- **DISA STIGs**: Military security requirements +- **Custom Policies**: Organization-specific controls + +### Secure Defaults +- **Encryption Enabled**: All data encrypted by default +- **Strong Authentication**: Multi-factor authentication required +- **Least Privilege**: Minimal permission grants +- **Audit Logging**: Comprehensive security logging + +## Disaster Recovery + +### Backup Security +- **Encrypted Backups**: AES-256 encryption for all backups +- **Key Escrow**: Secure key recovery procedures +- **Offsite Storage**: Geographically distributed backups +- **Recovery Testing**: Regular disaster recovery drills + +### Business Continuity +- **High Availability**: Multi-instance deployments +- **Failover Procedures**: Automated service failover +- **Data Replication**: Real-time data synchronization +- **Recovery Objectives**: RTO < 4 hours, RPO < 1 hour + +## Security Contacts + +- **Security Team**: security@openwatch.example.com +- **Incident Response**: incident@openwatch.example.com +- **Vulnerability Reports**: security-issues@openwatch.example.com + +## References + +- [NIST SP 800-53](https://csrc.nist.gov/publications/detail/sp/800-53/rev-5/final) +- [FIPS 140-2](https://csrc.nist.gov/publications/detail/fips/140/2/final) +- [OWASP Security Guidelines](https://owasp.org/www-project-application-security-verification-standard/) +- [CIS Controls](https://www.cisecurity.org/controls) \ No newline at end of file diff --git a/fix_async_issues.py b/fix_async_issues.py new file mode 100644 index 00000000..c6e6170f --- /dev/null +++ b/fix_async_issues.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python3 +""" +Script to fix async/await issues reported by SonarCloud +Removes async keyword from functions that don't use async features +""" +import re +import os +from pathlib import Path + +# Files with async issues based on SonarCloud report +FILES_TO_FIX = { + "backend/app/audit_db.py": [14], + "backend/app/auth.py": [277, 289, 408], + "backend/app/celery_app.py": [185], + "backend/app/database.py": [400, 410], + "backend/app/middleware/authorization_middleware.py": [209, 356, 376, 427], + "backend/app/middleware/metrics.py": [104, 164], + "backend/app/plugins/interface.py": [110, 145, 164, 189, 208, 231, 250, 269], + "backend/app/plugins/manager.py": [187, 196, 319, 341, 381], + "backend/app/rbac.py": [331], + "backend/app/routes/audit.py": [255], + "backend/app/routes/capabilities.py": [195, 241, 295, 306, 373], + "backend/app/routes/mfa.py": [89], + "backend/app/routes/rule_scanning.py": [368, 404, 443], + "backend/app/routes/system_settings.py": [545], + "backend/app/routes/system_settings_unified.py": [815], + "backend/app/routes/v1/remediation.py": [559], + "backend/app/services/authorization_service.py": [326, 460, 524, 568, 583, 648, 816, 882], + "backend/app/services/bulk_scan_orchestrator.py": [412, 460, 491, 521, 553, 595, 729, 772, 804], + "backend/app/services/command_sandbox.py": [146, 168, 191, 380, 408], +} + +def remove_async_from_function(content: str, line_number: int) -> str: + """Remove async keyword from a specific function""" + lines = content.split('\n') + + # Adjust for 0-based indexing + idx = line_number - 1 + + if idx < len(lines): + line = lines[idx] + # Check if this line has async def + if 'async def' in line: + # Replace async def with def + lines[idx] = line.replace('async def', 'def') + print(f" Fixed line {line_number}: {line.strip()[:60]}...") + + return '\n'.join(lines) + +def fix_file(filepath: str, line_numbers: list): + """Fix async issues in a single file""" + full_path = Path(filepath) + if not full_path.exists(): + print(f"⚠️ File not found: {filepath}") + return + + print(f"\n📄 Processing {filepath}") + + # Read the file + with open(full_path, 'r') as f: + content = f.read() + + # Fix each line + for line_num in sorted(line_numbers, reverse=True): + content = remove_async_from_function(content, line_num) + + # Write back + with open(full_path, 'w') as f: + f.write(content) + + print(f"✅ Fixed {len(line_numbers)} async issues") + +def main(): + """Main function to fix all async issues""" + print("🔧 Fixing async/await issues reported by SonarCloud\n") + + total_issues = sum(len(lines) for lines in FILES_TO_FIX.values()) + print(f"Total issues to fix: {total_issues}") + + for filepath, line_numbers in FILES_TO_FIX.items(): + fix_file(filepath, line_numbers) + + print(f"\n✨ Completed fixing {total_issues} async/await issues!") + print("\nNext steps:") + print("1. Review the changes with: git diff") + print("2. Run tests to ensure nothing broke") + print("3. Commit the changes") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 23ad1d2e..84b8b62d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,14 +12,14 @@ pydantic[email]==2.5.0 pydantic-settings==2.1.0 python-jose[cryptography]==3.3.0 passlib[bcrypt]==1.7.4 -python-multipart==0.0.6 +python-multipart==0.0.18 aiofiles==23.2.1 aiosmtplib==3.0.1 -cryptography==41.0.7 +cryptography==44.0.1 pyjwt==2.8.0 argon2-cffi==23.1.0 httpx==0.25.2 -aiohttp==3.9.1 +aiohttp==3.12.14 lxml==4.9.3 xmltodict==0.13.0 python-magic==0.4.27