|
8 | 8 | import logging |
9 | 9 |
|
10 | 10 | import aiohttp |
11 | | -from fastapi import APIRouter, Depends, Request, Response |
| 11 | +from fastapi import APIRouter, Depends, Request, Response, WebSocket |
12 | 12 | from fastapi.responses import JSONResponse, StreamingResponse |
13 | 13 | from starlette.background import BackgroundTask |
14 | 14 |
|
15 | 15 | from open_webui.utils.auth import get_verified_user |
16 | 16 | from open_webui.utils.access_control import has_connection_access |
17 | 17 | from open_webui.models.groups import Groups |
| 18 | +from open_webui.models.users import Users |
18 | 19 |
|
19 | 20 | log = logging.getLogger(__name__) |
20 | 21 |
|
@@ -149,3 +150,155 @@ async def cleanup(): |
149 | 150 | return JSONResponse( |
150 | 151 | {"error": f"Terminal proxy error: {error}"}, status_code=502 |
151 | 152 | ) |
| 153 | + |
| 154 | + |
| 155 | +# --------------------------------------------------------------------------- |
| 156 | +# WebSocket proxy for interactive terminal sessions |
| 157 | +# --------------------------------------------------------------------------- |
| 158 | + |
| 159 | + |
| 160 | +async def _resolve_authenticated_connection(ws: WebSocket, server_id: str): |
| 161 | + """Authenticate a WebSocket via first-message auth and resolve the terminal server. |
| 162 | +
|
| 163 | + The client must send ``{"type": "auth", "token": "<jwt>"}`` as its first |
| 164 | + message after connecting. |
| 165 | +
|
| 166 | + Returns ``(user, connection)`` on success, or ``None`` after closing *ws* |
| 167 | + with an appropriate error code. |
| 168 | + """ |
| 169 | + import asyncio |
| 170 | + import json |
| 171 | + from open_webui.utils.auth import decode_token |
| 172 | + |
| 173 | + # First-message authentication |
| 174 | + try: |
| 175 | + raw = await asyncio.wait_for(ws.receive_text(), timeout=10.0) |
| 176 | + payload = json.loads(raw) |
| 177 | + if payload.get("type") != "auth": |
| 178 | + await ws.close(code=4001, reason="Expected auth message") |
| 179 | + return None |
| 180 | + token = payload.get("token", "") |
| 181 | + data = decode_token(token) |
| 182 | + if data is None or "id" not in data: |
| 183 | + await ws.close(code=4001, reason="Invalid token") |
| 184 | + return None |
| 185 | + user = Users.get_user_by_id(data["id"]) |
| 186 | + if user is None: |
| 187 | + await ws.close(code=4001, reason="User not found") |
| 188 | + return None |
| 189 | + except (asyncio.TimeoutError, json.JSONDecodeError): |
| 190 | + await ws.close(code=4001, reason="Auth timeout or invalid payload") |
| 191 | + return None |
| 192 | + except Exception: |
| 193 | + await ws.close(code=4001, reason="Invalid token") |
| 194 | + return None |
| 195 | + |
| 196 | + # Resolve terminal server |
| 197 | + connections = ws.app.state.config.TERMINAL_SERVER_CONNECTIONS or [] |
| 198 | + connection = next((c for c in connections if c.get("id") == server_id), None) |
| 199 | + |
| 200 | + if connection is None: |
| 201 | + await ws.close(code=4004, reason="Terminal server not found") |
| 202 | + return None |
| 203 | + |
| 204 | + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)} |
| 205 | + if not has_connection_access(user, connection, user_group_ids): |
| 206 | + await ws.close(code=4003, reason="Access denied") |
| 207 | + return None |
| 208 | + |
| 209 | + return user, connection |
| 210 | + |
| 211 | + |
| 212 | +@router.websocket("/{server_id}/api/terminals/{session_id}") |
| 213 | +async def ws_terminal( |
| 214 | + ws: WebSocket, |
| 215 | + server_id: str, |
| 216 | + session_id: str, |
| 217 | +): |
| 218 | + """Proxy an interactive WebSocket terminal session to a terminal server. |
| 219 | +
|
| 220 | + Uses first-message auth: the client sends ``{"type": "auth", "token": "<jwt>"}`` |
| 221 | + as its first message. The proxy validates the JWT, then connects to the |
| 222 | + upstream terminal server and authenticates with the server's API key. |
| 223 | + """ |
| 224 | + await ws.accept() |
| 225 | + |
| 226 | + result = await _resolve_authenticated_connection(ws, server_id) |
| 227 | + if result is None: |
| 228 | + return |
| 229 | + user, connection = result |
| 230 | + |
| 231 | + base_url = (connection.get("url") or "").rstrip("/") |
| 232 | + if not base_url: |
| 233 | + await ws.close(code=4003, reason="Terminal server URL not configured") |
| 234 | + return |
| 235 | + |
| 236 | + # Build upstream WebSocket URL (no token in URL) |
| 237 | + ws_base = base_url.replace("https://", "wss://").replace("http://", "ws://") |
| 238 | + |
| 239 | + auth_type = connection.get("auth_type", "bearer") |
| 240 | + upstream_params = {} |
| 241 | + # For orchestrator-backed servers, pass user_id |
| 242 | + upstream_params["user_id"] = user.id |
| 243 | + |
| 244 | + import urllib.parse |
| 245 | + |
| 246 | + upstream_url = f"{ws_base}/api/terminals/{session_id}" |
| 247 | + if upstream_params: |
| 248 | + upstream_url += f"?{urllib.parse.urlencode(upstream_params)}" |
| 249 | + |
| 250 | + session = aiohttp.ClientSession() |
| 251 | + try: |
| 252 | + async with session.ws_connect(upstream_url) as upstream: |
| 253 | + import asyncio |
| 254 | + import json as _json |
| 255 | + |
| 256 | + # First-message auth to upstream terminal server |
| 257 | + auth_type = connection.get("auth_type", "bearer") |
| 258 | + if auth_type == "bearer": |
| 259 | + key = connection.get("key", "") |
| 260 | + await upstream.send_str(_json.dumps({"type": "auth", "token": key})) |
| 261 | + |
| 262 | + async def _client_to_upstream(): |
| 263 | + """Forward client → upstream.""" |
| 264 | + try: |
| 265 | + while True: |
| 266 | + msg = await ws.receive() |
| 267 | + if msg["type"] == "websocket.disconnect": |
| 268 | + break |
| 269 | + elif "bytes" in msg and msg["bytes"]: |
| 270 | + await upstream.send_bytes(msg["bytes"]) |
| 271 | + elif "text" in msg and msg["text"]: |
| 272 | + await upstream.send_str(msg["text"]) |
| 273 | + except Exception: |
| 274 | + pass |
| 275 | + |
| 276 | + async def _upstream_to_client(): |
| 277 | + """Forward upstream → client.""" |
| 278 | + try: |
| 279 | + async for msg in upstream: |
| 280 | + if msg.type == aiohttp.WSMsgType.BINARY: |
| 281 | + await ws.send_bytes(msg.data) |
| 282 | + elif msg.type == aiohttp.WSMsgType.TEXT: |
| 283 | + await ws.send_text(msg.data) |
| 284 | + elif msg.type in ( |
| 285 | + aiohttp.WSMsgType.CLOSE, |
| 286 | + aiohttp.WSMsgType.ERROR, |
| 287 | + ): |
| 288 | + break |
| 289 | + except Exception: |
| 290 | + pass |
| 291 | + |
| 292 | + await asyncio.gather( |
| 293 | + _client_to_upstream(), |
| 294 | + _upstream_to_client(), |
| 295 | + return_exceptions=True, |
| 296 | + ) |
| 297 | + except Exception as e: |
| 298 | + log.exception("Terminal WebSocket proxy error: %s", e) |
| 299 | + finally: |
| 300 | + await session.close() |
| 301 | + try: |
| 302 | + await ws.close() |
| 303 | + except Exception: |
| 304 | + pass |
0 commit comments