Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions integration_tests/cors_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""
Standalone Robyn app with ALLOW_CORS enabled, used by CORS integration tests.
Runs on a separate port (8083) so it doesn't conflict with base_routes.py.
"""

import os

from robyn import ALLOW_CORS, Robyn

app = Robyn(__file__)

ALLOWED_ORIGINS = ["http://localhost:3000", "https://frontend.example.com"]
ALLOW_CORS(app, origins=ALLOWED_ORIGINS)


@app.get("/")
def index():
return "OK"


@app.post("/data")
def post_data(request):
return "created"


@app.get("/custom-header")
def custom_header(request):
from robyn import Response

return Response(
status_code=200,
headers={"x-custom": "hello"},
description="custom",
)


if __name__ == "__main__":
port = int(os.getenv("ROBYN_PORT", "8083"))
app.start(port=port, _check_port=False)
175 changes: 175 additions & 0 deletions integration_tests/test_cors_preflight.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
"""
Integration tests for CORS preflight (OPTIONS) handling.

Regression test for https://github.com/sparckles/robyn/issues/1346
These tests spin up a real Robyn server with ALLOW_CORS enabled and make
actual HTTP requests — no TestClient.
"""

import os
import pathlib
import signal
import socket
import subprocess
import time

import pytest
import requests

CORS_PORT = 8083
CORS_HOST = "127.0.0.1"
CORS_BASE_URL = f"http://{CORS_HOST}:{CORS_PORT}"
REQUEST_TIMEOUT = 5

ALLOWED_ORIGIN = "http://localhost:3000"
DISALLOWED_ORIGIN = "http://evil.example.com"


def _start_cors_server():
app_path = os.path.join(pathlib.Path(__file__).parent.resolve(), "cors_app.py")
env = os.environ.copy()
env["ROBYN_HOST"] = CORS_HOST
env["ROBYN_PORT"] = str(CORS_PORT)

process = subprocess.Popen(
["python3", app_path],
env=env,
preexec_fn=os.setsid,
)

timeout = 15
start = time.time()
while True:
if process.poll() is not None:
raise RuntimeError(f"CORS server exited early with code {process.returncode}")
if time.time() - start > timeout:
os.killpg(os.getpgid(process.pid), signal.SIGKILL)
raise ConnectionError(f"CORS server didn't start on {CORS_HOST}:{CORS_PORT}")
try:
sock = socket.create_connection((CORS_HOST, CORS_PORT), timeout=2)
sock.close()
break
except Exception:
time.sleep(0.5)

time.sleep(1)
return process
Comment thread
sansyrox marked this conversation as resolved.


@pytest.fixture(scope="module")
def cors_server():
process = _start_cors_server()
yield
try:
os.killpg(os.getpgid(process.pid), signal.SIGKILL)
except ProcessLookupError:
pass


# ---------------------------------------------------------------------------
# Preflight (OPTIONS) tests
# ---------------------------------------------------------------------------


def test_options_preflight_returns_204_with_cors_headers(cors_server):
"""Browser sends OPTIONS preflight; server must return 204 with all CORS headers."""
resp = requests.options(
f"{CORS_BASE_URL}/data",
headers={
"Origin": ALLOWED_ORIGIN,
"Access-Control-Request-Method": "POST",
"Access-Control-Request-Headers": "Content-Type, Authorization",
},
timeout=REQUEST_TIMEOUT,
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
assert resp.status_code == 204
assert resp.headers["Access-Control-Allow-Origin"] == ALLOWED_ORIGIN
assert "POST" in resp.headers["Access-Control-Allow-Methods"]
assert resp.headers["Access-Control-Allow-Credentials"] == "true"


def test_options_preflight_no_duplicate_allow_origin(cors_server):
"""
Regression: ensure Access-Control-Allow-Origin appears exactly once.
Duplicate values cause browsers to reject the preflight (Fetch spec).
"""
resp = requests.options(
f"{CORS_BASE_URL}/",
headers={
"Origin": ALLOWED_ORIGIN,
"Access-Control-Request-Method": "GET",
},
timeout=REQUEST_TIMEOUT,
)
raw_headers = resp.raw.headers if hasattr(resp.raw, "headers") else resp.headers

origin_values = raw_headers.getall("Access-Control-Allow-Origin", None) if hasattr(raw_headers, "getall") else None

if origin_values is not None:
assert len(origin_values) == 1, f"Access-Control-Allow-Origin appeared {len(origin_values)} times: {origin_values}"

assert resp.headers["Access-Control-Allow-Origin"] == ALLOWED_ORIGIN
Comment thread
sansyrox marked this conversation as resolved.


def test_options_preflight_disallowed_origin_returns_403(cors_server):
"""Origins not in the allowed list should be rejected."""
resp = requests.options(
f"{CORS_BASE_URL}/data",
headers={
"Origin": DISALLOWED_ORIGIN,
"Access-Control-Request-Method": "POST",
},
timeout=REQUEST_TIMEOUT,
)
assert resp.status_code == 403


# ---------------------------------------------------------------------------
# Normal request tests (non-preflight)
# ---------------------------------------------------------------------------


def test_get_with_allowed_origin_has_cors_headers(cors_server):
"""Normal GET from an allowed origin should carry CORS response headers."""
resp = requests.get(
f"{CORS_BASE_URL}/",
headers={"Origin": ALLOWED_ORIGIN},
timeout=REQUEST_TIMEOUT,
)
assert resp.status_code == 200
allow_origin = resp.headers.get("Access-Control-Allow-Origin")
assert allow_origin is not None
assert "Access-Control-Allow-Methods" in resp.headers


def test_get_without_origin_still_has_global_cors_headers(cors_server):
"""Requests without Origin (e.g. curl, Postman) should still get global CORS headers."""
resp = requests.get(f"{CORS_BASE_URL}/", timeout=REQUEST_TIMEOUT)
assert resp.status_code == 200
assert "Access-Control-Allow-Methods" in resp.headers


def test_post_with_allowed_origin(cors_server):
"""POST from allowed origin should succeed with CORS headers."""
resp = requests.post(
f"{CORS_BASE_URL}/data",
headers={
"Origin": ALLOWED_ORIGIN,
"Content-Type": "application/json",
},
data="{}",
timeout=REQUEST_TIMEOUT,
)
assert resp.status_code == 200


def test_custom_response_headers_not_clobbered_by_globals(cors_server):
"""Route-level headers set by the handler should not be overwritten by globals."""
resp = requests.get(
f"{CORS_BASE_URL}/custom-header",
headers={"Origin": ALLOWED_ORIGIN},
timeout=REQUEST_TIMEOUT,
)
assert resp.status_code == 200
assert resp.headers.get("x-custom") == "hello"
assert "Access-Control-Allow-Methods" in resp.headers
2 changes: 1 addition & 1 deletion src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,7 @@ async fn index(
.is_some_and(|paths| paths.contains(&request.url.path));

if !is_excluded {
response.headers_mut().extend(&global_response_headers);
response.headers_mut().set_missing(&global_response_headers);
}

// After middleware
Expand Down
10 changes: 10 additions & 0 deletions src/types/headers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,16 @@ impl Headers {
}
}

/// Merge headers from `headers` into `self`, but only for keys not already present.
/// This gives middleware-set headers precedence over global defaults,
/// preventing duplicate `Access-Control-Allow-Origin` (and similar) violations.
pub fn set_missing(&mut self, headers: &Headers) {
for iter in headers.headers.iter() {
let (key, values) = iter.pair();
self.headers.entry(key.clone()).or_insert_with(|| values.clone());
}
}

pub fn from_actix_headers(req_headers: &HeaderMap) -> Self {
let headers = Headers::default();

Expand Down
Loading