Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,16 @@ def __init__(self, service):

def select(self, service):
service.with_credentials = False
service.cors = CORS(service.app)
if service.allow_all_origins:
# User explicitly opted into allowing all origins (less secure)
service.cors = CORS(service.app)
else:
# Default: restrict CORS to localhost origins (any port)
origins = [
r"http://localhost:\d+",
r"https://localhost:\d+",
r"http://127.0.0.1:\d+",
r"https://127.0.0.1:\d+",
]
service.cors = CORS(service.app, origins=origins)
service.env_name = LOCAL
Comment thread
imatiach-msft marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,12 @@ def __init__(self, service):

def select(self, service):
service.with_credentials = False
service.cors = CORS(service.app)
if service.allow_all_origins:
# User explicitly opted into allowing all origins (less secure)
service.cors = CORS(service.app)
else:
# Default: restrict CORS to the same host over HTTP/HTTPS
# on any port (notebook may run on a different port)
origin_pattern = rf"https?://{service.ip}(:\d+)?"
service.cors = CORS(service.app, origins=[origin_pattern])
service.env_name = PUBLIC_VM
Comment thread
imatiach-msft marked this conversation as resolved.
3 changes: 2 additions & 1 deletion rai_core_flask/rai_core_flask/flask_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,15 @@ class FlaskHelper(object):
"""FlaskHelper is a class for common Flask utilities used in dashboards."""

def __init__(self, ip=None, port=None, with_credentials=False,
is_private_link=False):
is_private_link=False, allow_all_origins=False):
# The name passed to Flask needs to be unique per instance.
self.app = Flask(uuid.uuid4().hex)

self.port = port
self.ip = ip
self.with_credentials = with_credentials
self.is_private_link = is_private_link
self.allow_all_origins = allow_all_origins
Comment thread
imatiach-msft marked this conversation as resolved.
# dictionary to store arbitrary state for use by consuming classes
self.shared_state = {}
if self.ip is None:
Expand Down
27 changes: 27 additions & 0 deletions rai_core_flask/tests/test_environment_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,30 @@ def test_databricks(self):
assert isinstance(service.env, DatabricksEnvironment)
finally:
del os.environ[DATABRICKS_ENV_VAR]

def test_local_cors_restricted_by_default(self):
service = FlaskHelper()
assert isinstance(service.env, LocalIPythonEnvironment)
assert not service.allow_all_origins
# Verify CORS is configured (not wildcard)
assert hasattr(service, 'cors')

def test_local_cors_wildcard_when_opted_in(self):
service = FlaskHelper(allow_all_origins=True)
assert isinstance(service.env, LocalIPythonEnvironment)
assert service.allow_all_origins

def test_public_vm_cors_restricted_by_default(self, mocker):
mocker.patch('rai_core_flask.FlaskHelper._is_local_port_available',
return_value=True)
service = FlaskHelper(ip="10.0.0.5", with_credentials=False)
assert isinstance(service.env, PublicVMEnvironment)
assert not service.allow_all_origins

def test_public_vm_cors_wildcard_when_opted_in(self, mocker):
mocker.patch('rai_core_flask.FlaskHelper._is_local_port_available',
return_value=True)
service = FlaskHelper(ip="10.0.0.5", with_credentials=False,
allow_all_origins=True)
assert isinstance(service.env, PublicVMEnvironment)
assert service.allow_all_origins
4 changes: 3 additions & 1 deletion raiwidgets/raiwidgets/dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def __init__(self, *,
locale,
no_inline_dashboard=False,
is_private_link=False,
allow_all_origins=False,
**kwargs):
"""Initialize the dashboard."""

Expand All @@ -71,7 +72,8 @@ def __init__(self, *,
try:
self._service = FlaskHelper(ip=public_ip,
port=port,
is_private_link=is_private_link)
is_private_link=is_private_link,
allow_all_origins=allow_all_origins)
except Exception as e:
self._service = None
raise e
Expand Down
7 changes: 7 additions & 0 deletions raiwidgets/raiwidgets/responsibleai_dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,16 @@ class ResponsibleAIDashboard(Dashboard):
:param is_private_link: If the dashboard environment is
a private link AML workspace.
:type is_private_link: bool
:param allow_all_origins: If True, allows CORS requests from any
origin. Defaults to False for security. Only set to True if you
understand the security implications (e.g., cross-origin data
exfiltration risks).
:type allow_all_origins: bool
"""
def __init__(self, analysis: RAIInsights,
public_ip=None, port=None, locale=None,
cohort_list=None, is_private_link=False,
allow_all_origins=False,
**kwargs):
self.input = ResponsibleAIDashboardInput(
analysis, cohort_list=cohort_list)
Expand All @@ -48,6 +54,7 @@ def __init__(self, analysis: RAIInsights,
locale=locale,
no_inline_dashboard=True,
is_private_link=is_private_link,
allow_all_origins=allow_all_origins,
**kwargs)

def predict():
Expand Down
Loading