Skip to content

Commit bc5c117

Browse files
committed
Improving security
1 parent e30aaad commit bc5c117

3 files changed

Lines changed: 427 additions & 19 deletions

File tree

app/main.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,61 @@
1-
from fastapi import FastAPI, status
1+
from fastapi import FastAPI, status, Request
22
from fastapi.exceptions import RequestValidationError
33
from fastapi.encoders import jsonable_encoder
4+
from fastapi.middleware.cors import CORSMiddleware
45
from starlette.responses import JSONResponse
6+
from starlette.middleware.base import BaseHTTPMiddleware
57
from app.routes.compile import compile_endpoint
8+
from app.process_monitor import process_monitor
69

710

811
app = FastAPI()
912

1013

14+
# Start the process monitor when app starts
15+
@app.on_event("startup")
16+
async def startup_event():
17+
process_monitor.start()
18+
print("Process monitor started - will kill compilation processes older than 8 seconds")
19+
20+
21+
@app.on_event("shutdown")
22+
async def shutdown_event():
23+
process_monitor.stop()
24+
25+
26+
# Security Headers Middleware
27+
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
28+
async def dispatch(self, request: Request, call_next):
29+
response = await call_next(request)
30+
31+
# Add security headers
32+
response.headers["X-Content-Type-Options"] = "nosniff"
33+
response.headers["X-Frame-Options"] = "DENY"
34+
response.headers["X-XSS-Protection"] = "1; mode=block"
35+
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
36+
37+
# Only add HSTS for HTTPS connections
38+
if request.url.scheme == "https":
39+
response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
40+
41+
return response
42+
43+
44+
# Add security headers middleware
45+
app.add_middleware(SecurityHeadersMiddleware)
46+
47+
# Configure CORS - permissive for compatibility with Hasura
48+
# You can restrict allow_origins later if needed
49+
app.add_middleware(
50+
CORSMiddleware,
51+
allow_origins=["*"], # Allow all origins for now - restrict this in production
52+
allow_credentials=True,
53+
allow_methods=["*"], # Allow all methods
54+
allow_headers=["*"], # Hasura sends various headers
55+
max_age=3600,
56+
)
57+
58+
1159
@app.exception_handler(RequestValidationError)
1260
async def validation_exception_handler(exc: RequestValidationError):
1361
return JSONResponse(
@@ -16,4 +64,10 @@ async def validation_exception_handler(exc: RequestValidationError):
1664
)
1765

1866

67+
# Health check endpoint for monitoring
68+
@app.get("/health")
69+
async def health_check():
70+
return {"status": "healthy", "service": "zxbasic-compiler"}
71+
72+
1973
app.include_router(compile_endpoint, prefix='/compile')

app/process_monitor.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
"""
2+
Background thread that monitors and kills long-running compilation processes.
3+
Runs as part of the FastAPI application.
4+
"""
5+
import os
6+
import time
7+
import signal
8+
import threading
9+
import subprocess
10+
from datetime import datetime
11+
12+
# Configuration
13+
MAX_PROCESS_AGE = 8 # Kill compilation processes older than 8 seconds
14+
CHECK_INTERVAL = 2 # Check every 2 seconds
15+
16+
17+
class ProcessMonitor:
18+
def __init__(self):
19+
self.running = False
20+
self.thread = None
21+
self.processes = {} # Track subprocess PIDs we start
22+
23+
def register_process(self, pid):
24+
"""Register a compilation process to monitor"""
25+
self.processes[pid] = time.time()
26+
print(f"[MONITOR] Tracking compilation process PID {pid}")
27+
28+
def start(self):
29+
"""Start the monitor thread"""
30+
if not self.running:
31+
self.running = True
32+
self.thread = threading.Thread(target=self._monitor_loop, daemon=True)
33+
self.thread.start()
34+
print("[MONITOR] Process monitor started")
35+
36+
def stop(self):
37+
"""Stop the monitor thread"""
38+
self.running = False
39+
40+
def _monitor_loop(self):
41+
"""Main monitoring loop that runs in background"""
42+
while self.running:
43+
try:
44+
current_time = time.time()
45+
pids_to_remove = []
46+
47+
# Check tracked processes
48+
for pid, start_time in list(self.processes.items()):
49+
age = current_time - start_time
50+
51+
try:
52+
# Check if process still exists
53+
os.kill(pid, 0) # Signal 0 just checks if process exists
54+
55+
if age > MAX_PROCESS_AGE:
56+
print(f"[MONITOR] Killing stuck process PID {pid} (age: {age:.1f}s)")
57+
try:
58+
# Try graceful termination first
59+
os.kill(pid, signal.SIGTERM)
60+
time.sleep(0.5)
61+
# Check if still alive
62+
os.kill(pid, 0)
63+
# If still alive, force kill
64+
os.kill(pid, signal.SIGKILL)
65+
print(f"[MONITOR] Force killed PID {pid}")
66+
except OSError:
67+
pass # Process already dead
68+
pids_to_remove.append(pid)
69+
70+
except OSError:
71+
# Process doesn't exist anymore
72+
pids_to_remove.append(pid)
73+
74+
# Clean up dead processes from tracking
75+
for pid in pids_to_remove:
76+
del self.processes[pid]
77+
78+
# Also check for any zxbc processes we didn't start
79+
# (in case of threading issues)
80+
self._check_orphan_processes()
81+
82+
except Exception as e:
83+
print(f"[MONITOR] Error in monitor loop: {e}")
84+
85+
time.sleep(CHECK_INTERVAL)
86+
87+
def _check_orphan_processes(self):
88+
"""Check for compilation processes we might not be tracking"""
89+
try:
90+
# Use ps to find zxbc processes
91+
result = subprocess.run(
92+
["ps", "aux"],
93+
capture_output=True,
94+
text=True,
95+
timeout=1
96+
)
97+
98+
for line in result.stdout.split('\n'):
99+
if 'zxbc' in line and '-taB' in line:
100+
parts = line.split()
101+
if len(parts) > 1:
102+
pid = int(parts[1])
103+
# If we're not tracking this PID, it might be orphaned
104+
if pid not in self.processes:
105+
# Check process age using /proc if available
106+
try:
107+
stat_path = f"/proc/{pid}/stat"
108+
if os.path.exists(stat_path):
109+
with open(stat_path, 'r') as f:
110+
stat_data = f.read().split()
111+
# Field 21 is start time in jiffies
112+
start_jiffies = int(stat_data[21])
113+
# Rough age calculation
114+
age_seconds = (time.time() - os.path.getmtime(stat_path))
115+
if age_seconds > MAX_PROCESS_AGE:
116+
print(f"[MONITOR] Found orphan zxbc process PID {pid}, killing...")
117+
os.kill(pid, signal.SIGKILL)
118+
except:
119+
pass
120+
except:
121+
pass # PS might not be available or fail
122+
123+
124+
# Global monitor instance
125+
process_monitor = ProcessMonitor()

0 commit comments

Comments
 (0)