Skip to content

Commit 9781bd8

Browse files
committed
Add DagsterWebSocketProxyConsumer and update the routing to support websocket connections see HEA-752
1 parent bbe7a34 commit 9781bd8

6 files changed

Lines changed: 139 additions & 3 deletions

File tree

apps/common/consumers.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
import asyncio
2+
import logging
3+
import os
4+
5+
import websockets
6+
from channels.exceptions import DenyConnection
7+
from channels.generic.websocket import AsyncWebsocketConsumer
8+
from django.contrib.auth.models import AnonymousUser
9+
10+
logger = logging.getLogger(__name__)
11+
12+
13+
class DagsterWebSocketProxyConsumer(AsyncWebsocketConsumer):
14+
15+
async def connect(self):
16+
logger.info(f"WebSocket connection attempt: {self.scope['path']}")
17+
18+
# Authentication check
19+
if isinstance(self.scope["user"], AnonymousUser):
20+
logger.error("Authentication required")
21+
raise DenyConnection("Authentication required")
22+
23+
if not self.scope["user"].has_perm("common.access_dagster_ui"):
24+
logger.error(f"User {self.scope['user'].username} lacks permission")
25+
raise DenyConnection("Permission denied")
26+
27+
logger.info(f"User {self.scope['user'].username} authenticated")
28+
29+
# Build upstream URL
30+
dagster_url = os.environ.get("DAGSTER_WEBSERVER_URL", "http://localhost:3000")
31+
dagster_prefix = os.environ.get("DAGSTER_WEBSERVER_PREFIX", "")
32+
33+
path = self.scope["path"]
34+
if path.startswith("/pipelines/"):
35+
path = path[len("/pipelines/") :]
36+
37+
# Convert http to ws
38+
if dagster_url.startswith("https://"):
39+
ws_url = dagster_url.replace("https://", "wss://", 1)
40+
else:
41+
ws_url = dagster_url.replace("http://", "ws://", 1)
42+
43+
# Build target URL
44+
if dagster_prefix:
45+
target_url = f"{ws_url}/{dagster_prefix}/{path}"
46+
else:
47+
target_url = f"{ws_url}/{path}"
48+
49+
# Add query string
50+
if self.scope.get("query_string"):
51+
target_url += f"?{self.scope['query_string'].decode()}"
52+
53+
logger.info(f"Connecting to upstream: {target_url}")
54+
55+
# Get subprotocols from client
56+
subprotocols = self.scope.get("subprotocols", [])
57+
58+
try:
59+
self.websocket = await websockets.connect(
60+
target_url,
61+
max_size=2097152,
62+
ping_interval=20,
63+
subprotocols=subprotocols if subprotocols else None,
64+
)
65+
logger.info("Connected to upstream")
66+
except Exception as e:
67+
logger.error(f"Failed to connect: {e}")
68+
raise DenyConnection(f"Connection to upstream failed: {e}")
69+
70+
await self.accept(self.websocket.subprotocol)
71+
logger.info(f"Client accepted with subprotocol: {self.websocket.subprotocol}")
72+
73+
self.consumer_task = asyncio.create_task(self.consume_from_upstream())
74+
75+
async def disconnect(self, close_code):
76+
logger.info(f"Disconnecting with code {close_code}")
77+
if hasattr(self, "consumer_task"):
78+
self.consumer_task.cancel()
79+
try:
80+
await self.consumer_task
81+
except asyncio.CancelledError:
82+
pass
83+
if hasattr(self, "websocket"):
84+
await self.websocket.close()
85+
86+
async def receive(self, text_data=None, bytes_data=None):
87+
try:
88+
await self.websocket.send(bytes_data or text_data)
89+
except Exception as e:
90+
logger.error(f"Error forwarding to upstream: {e}")
91+
await self.close()
92+
93+
async def consume_from_upstream(self):
94+
try:
95+
async for message in self.websocket:
96+
if isinstance(message, bytes):
97+
await self.send(bytes_data=message)
98+
else:
99+
await self.send(text_data=message)
100+
except asyncio.CancelledError:
101+
pass
102+
except Exception as e:
103+
logger.error(f"Error consuming from upstream: {e}")
104+
await self.close()

apps/common/routing.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from django.urls import re_path
2+
3+
from common.consumers import DagsterWebSocketProxyConsumer
4+
5+
websocket_urlpatterns = [
6+
# Route WebSocket connections for Dagster proxy
7+
re_path(r"^pipelines/(?P<path>.*)$", DagsterWebSocketProxyConsumer.as_asgi()),
8+
]

docker/app/run_django.sh

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ touch log/django_sql.log
4040
chown -R django:django log/*
4141

4242
echo Starting Gunicorn with DJANGO_SETTINGS_MODULE=${DJANGO_SETTINGS_MODULE}
43-
gosu django gunicorn ${APP}.wsgi:application \
43+
gosu django gunicorn ${APP}.asgi:application \
4444
--name ${APP}${ENV} \
45+
--worker-class uvicorn.workers.UvicornWorker \
4546
--config $(dirname $(readlink -f "$0"))/gunicorn_config.py \
46-
$* 2>&1
47+
$* 2>&1

hea/asgi.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,25 @@
99

1010
import os
1111

12+
from channels.auth import AuthMiddlewareStack
13+
from channels.routing import ProtocolTypeRouter, URLRouter
14+
from channels.security.websocket import AllowedHostsOriginValidator
1215
from django.core.asgi import get_asgi_application
1316

1417
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "hea.settings")
1518

16-
application = get_asgi_application()
19+
django_asgi_app = get_asgi_application()
20+
21+
22+
django_asgi_app = get_asgi_application()
23+
24+
# Import routing after Django setup
25+
from common.routing import websocket_urlpatterns # noqa: E402
26+
27+
application = ProtocolTypeRouter(
28+
{
29+
"http": django_asgi_app,
30+
# WebSocket requests handled by Channels consumers
31+
"websocket": AllowedHostsOriginValidator(AuthMiddlewareStack(URLRouter(websocket_urlpatterns))),
32+
}
33+
)

hea/settings/base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@
109109
"rest_framework_gis",
110110
"revproxy",
111111
"corsheaders",
112+
"channels",
112113
]
113114
PROJECT_APPS = ["common", "metadata", "baseline"]
114115
INSTALLED_APPS = EXTERNAL_APPS + PROJECT_APPS
@@ -155,6 +156,9 @@
155156
"SEARCH_PARAM": "search",
156157
}
157158

159+
ASGI_APPLICATION = "hea.asgi.application"
160+
CHANNEL_LAYERS = {"default": {"BACKEND": "channels.layers.InMemoryChannelLayer"}}
161+
WEBSOCKET_ACCEPT_ALL = False # Require authentication
158162

159163
########## CORS CONFIGURATION
160164
# See: https://github.com/ottoyiu/django-cors-headers

requirements/base.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Do not pip freeze into this file. Add only the packages you need so pip can align dependencies.
2+
channels==4.3.1
23
dagster==1.11.7
34
dagster_aws==0.27.7
45
dagster-pipes==1.11.7
@@ -46,6 +47,7 @@ sqlparse==0.5.0
4647
tabulate==0.9.0
4748
# Need universal-pathlib > 0.2.0 for gdrivefs support
4849
universal-pathlib==0.2.1
50+
uvicorn==0.37.0
4951
whitenoise==6.4.0
5052
xlrd==2.0.1
5153
xlutils==2.0.0

0 commit comments

Comments
 (0)