Skip to content

Commit 253a574

Browse files
committed
Fix CORS security vulnerability
- Restrict CORS origins in PublicVMEnvironment to same host, any port - Restrict CORS origins in LocalIPythonEnvironment to localhost, any port - Use regex patterns to allow notebook origins on different ports - Add allow_all_origins parameter threaded through ResponsibleAIDashboard, Dashboard, and FlaskHelper as an escape hatch for users who need wildcard CORS and understand the security implications - Add docstring for allow_all_origins parameter - Add unit tests for CORS configuration in both environments This fixes CWE-942 (Wildcard CORS) by defaulting to restrictive CORS policy while preserving backward compatibility via opt-in parameter.
1 parent 917a491 commit 253a574

6 files changed

Lines changed: 59 additions & 4 deletions

File tree

rai_core_flask/rai_core_flask/environments/local_ipython_environment.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,5 +31,16 @@ def __init__(self, service):
3131

3232
def select(self, service):
3333
service.with_credentials = False
34-
service.cors = CORS(service.app)
34+
if service.allow_all_origins:
35+
# User explicitly opted into allowing all origins (less secure)
36+
service.cors = CORS(service.app)
37+
else:
38+
# Default: restrict CORS to localhost origins (any port)
39+
origins = [
40+
r"http://localhost:\d+",
41+
r"https://localhost:\d+",
42+
r"http://127.0.0.1:\d+",
43+
r"https://127.0.0.1:\d+",
44+
]
45+
service.cors = CORS(service.app, origins=origins)
3546
service.env_name = LOCAL

rai_core_flask/rai_core_flask/environments/public_vm_environment.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,5 +31,12 @@ def __init__(self, service):
3131

3232
def select(self, service):
3333
service.with_credentials = False
34-
service.cors = CORS(service.app)
34+
if service.allow_all_origins:
35+
# User explicitly opted into allowing all origins (less secure)
36+
service.cors = CORS(service.app)
37+
else:
38+
# Default: restrict CORS to the same host over HTTP/HTTPS
39+
# on any port (notebook may run on a different port)
40+
origin_pattern = rf"https?://{service.ip}(:\d+)?"
41+
service.cors = CORS(service.app, origins=[origin_pattern])
3542
service.env_name = PUBLIC_VM

rai_core_flask/rai_core_flask/flask_helper.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,15 @@ class FlaskHelper(object):
2525
"""FlaskHelper is a class for common Flask utilities used in dashboards."""
2626

2727
def __init__(self, ip=None, port=None, with_credentials=False,
28-
is_private_link=False):
28+
is_private_link=False, allow_all_origins=False):
2929
# The name passed to Flask needs to be unique per instance.
3030
self.app = Flask(uuid.uuid4().hex)
3131

3232
self.port = port
3333
self.ip = ip
3434
self.with_credentials = with_credentials
3535
self.is_private_link = is_private_link
36+
self.allow_all_origins = allow_all_origins
3637
# dictionary to store arbitrary state for use by consuming classes
3738
self.shared_state = {}
3839
if self.ip is None:

rai_core_flask/tests/test_environment_detector.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,30 @@ def test_databricks(self):
6161
assert isinstance(service.env, DatabricksEnvironment)
6262
finally:
6363
del os.environ[DATABRICKS_ENV_VAR]
64+
65+
def test_local_cors_restricted_by_default(self):
66+
service = FlaskHelper()
67+
assert isinstance(service.env, LocalIPythonEnvironment)
68+
assert not service.allow_all_origins
69+
# Verify CORS is configured (not wildcard)
70+
assert hasattr(service, 'cors')
71+
72+
def test_local_cors_wildcard_when_opted_in(self):
73+
service = FlaskHelper(allow_all_origins=True)
74+
assert isinstance(service.env, LocalIPythonEnvironment)
75+
assert service.allow_all_origins
76+
77+
def test_public_vm_cors_restricted_by_default(self, mocker):
78+
mocker.patch('rai_core_flask.FlaskHelper._is_local_port_available',
79+
return_value=True)
80+
service = FlaskHelper(ip="10.0.0.5", with_credentials=False)
81+
assert isinstance(service.env, PublicVMEnvironment)
82+
assert not service.allow_all_origins
83+
84+
def test_public_vm_cors_wildcard_when_opted_in(self, mocker):
85+
mocker.patch('rai_core_flask.FlaskHelper._is_local_port_available',
86+
return_value=True)
87+
service = FlaskHelper(ip="10.0.0.5", with_credentials=False,
88+
allow_all_origins=True)
89+
assert isinstance(service.env, PublicVMEnvironment)
90+
assert service.allow_all_origins

raiwidgets/raiwidgets/dashboard.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def __init__(self, *,
6262
locale,
6363
no_inline_dashboard=False,
6464
is_private_link=False,
65+
allow_all_origins=False,
6566
**kwargs):
6667
"""Initialize the dashboard."""
6768

@@ -71,7 +72,8 @@ def __init__(self, *,
7172
try:
7273
self._service = FlaskHelper(ip=public_ip,
7374
port=port,
74-
is_private_link=is_private_link)
75+
is_private_link=is_private_link,
76+
allow_all_origins=allow_all_origins)
7577
except Exception as e:
7678
self._service = None
7779
raise e

raiwidgets/raiwidgets/responsibleai_dashboard.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,16 @@ class ResponsibleAIDashboard(Dashboard):
3232
:param is_private_link: If the dashboard environment is
3333
a private link AML workspace.
3434
:type is_private_link: bool
35+
:param allow_all_origins: If True, allows CORS requests from any
36+
origin. Defaults to False for security. Only set to True if you
37+
understand the security implications (e.g., cross-origin data
38+
exfiltration risks).
39+
:type allow_all_origins: bool
3540
"""
3641
def __init__(self, analysis: RAIInsights,
3742
public_ip=None, port=None, locale=None,
3843
cohort_list=None, is_private_link=False,
44+
allow_all_origins=False,
3945
**kwargs):
4046
self.input = ResponsibleAIDashboardInput(
4147
analysis, cohort_list=cohort_list)
@@ -48,6 +54,7 @@ def __init__(self, analysis: RAIInsights,
4854
locale=locale,
4955
no_inline_dashboard=True,
5056
is_private_link=is_private_link,
57+
allow_all_origins=allow_all_origins,
5158
**kwargs)
5259

5360
def predict():

0 commit comments

Comments
 (0)