Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion infra/main.bicep
Original file line number Diff line number Diff line change
Expand Up @@ -1112,7 +1112,7 @@ module containerAppBackend 'br/public:avm/res/app/container-app:0.19.0' = {
}
]
ingressTargetPort: 8000
ingressExternal: true
ingressExternal: !enablePrivateNetworking
scaleSettings: {
Comment on lines 1112 to 1116
// maxReplicas: enableScalability ? 3 : 1
maxReplicas: 1 // maxReplicas set to 1 (not 3) due to multiple agents created per type during WAF deployment
Expand Down
15 changes: 0 additions & 15 deletions infra/main.parameters.json
Original file line number Diff line number Diff line change
Expand Up @@ -46,21 +46,6 @@
},
"vmAdminPassword": {
"value": "${AZURE_ENV_VM_ADMIN_PASSWORD}"
},
"aiModelDeployments": {
"value": [
{
"name": "${AZURE_ENV_GPT_MODEL_NAME}",
"model": {
"name": "${AZURE_ENV_GPT_MODEL_NAME}",
"version": "${AZURE_ENV_GPT_MODEL_VERSION}"
},
"sku": {
"name": "${AZURE_ENV_MODEL_DEPLOYMENT_TYPE}",
"capacity": "${AZURE_ENV_GPT_MODEL_CAPACITY}"
}
}
]
}
}
}
2 changes: 1 addition & 1 deletion infra/main_custom.bicep
Original file line number Diff line number Diff line change
Expand Up @@ -1063,7 +1063,7 @@ module containerAppBackend 'br/public:avm/res/app/container-app:0.19.0' = {
}
]
ingressTargetPort: 8000
ingressExternal: true
ingressExternal: !enablePrivateNetworking
scaleSettings: {
// maxReplicas: enableScalability ? 3 : 1
maxReplicas: 1 // maxReplicas set to 1 (not 3) due to multiple agents created per type during WAF deployment
Expand Down
1 change: 1 addition & 0 deletions src/backend/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def create_app() -> FastAPI:
# app.include_router(agents_router, prefix="/api/agents", tags=["agents"])

@app.get("/health")
@app.get("/api/health")
async def health_check():
"""Health check endpoint."""
return {"status": "healthy"}
Expand Down
115 changes: 112 additions & 3 deletions src/frontend/frontend_server.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
import asyncio
import os

import httpx
import uvicorn
import websockets
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.responses import FileResponse, Response
from fastapi.staticfiles import StaticFiles

# Load environment variables from .env file
load_dotenv()

# Internal backend URL used by the server-side proxy.
# The browser never contacts this URL directly.
BACKEND_API_URL = os.getenv("API_URL", "http://localhost:8000").rstrip("/")

app = FastAPI()

app.add_middleware(
Expand Down Expand Up @@ -38,7 +45,11 @@ async def serve_index():
@app.get("/config")
async def get_config():
config = {
"API_URL": os.getenv("API_URL", "API_URL not set"),
# Return empty string so the browser uses relative /api/* paths
# which are proxied server-side to BACKEND_API_URL. This ensures
# backend Container Apps with internal-only ingress are never
# contacted directly from the browser.
"API_URL": "",
Comment thread
Ashwal-Microsoft marked this conversation as resolved.
Outdated
"REACT_APP_MSAL_AUTH_CLIENTID": os.getenv(
"REACT_APP_MSAL_AUTH_CLIENTID", "Client ID not set"
),
Expand All @@ -56,6 +67,104 @@ async def get_config():
return config


# ---------------------------------------------------------------------------
# Reverse proxy: WebSocket (must be declared before the HTTP catch-all below)
# ---------------------------------------------------------------------------

@app.websocket("/api/socket/{batch_id}")
async def proxy_websocket(websocket: WebSocket, batch_id: str):
"""Proxy WebSocket connections from the browser to the internal backend."""
await websocket.accept()

backend_ws_url = (
BACKEND_API_URL
.replace("https://", "wss://")
.replace("http://", "ws://")
)
backend_ws_url = f"{backend_ws_url}/api/socket/{batch_id}"

try:
async with websockets.connect(backend_ws_url) as backend_ws:

async def forward_to_backend():
try:
while True:
data = await websocket.receive_text()
await backend_ws.send(data)
except (WebSocketDisconnect, Exception):
pass

async def forward_to_client():
try:
async for message in backend_ws:
await websocket.send_text(message)
except (WebSocketDisconnect, Exception):
pass

await asyncio.gather(forward_to_backend(), forward_to_client())
Comment thread
Ashwal-Microsoft marked this conversation as resolved.
Outdated
except Exception:
pass
finally:
Comment thread
Ashwal-Microsoft marked this conversation as resolved.
Outdated
try:
await websocket.close()
except Exception:
pass
Comment on lines +101 to +144


# ---------------------------------------------------------------------------
# Reverse proxy: HTTP (all /api/* routes proxied to the internal backend)
# ---------------------------------------------------------------------------

_PROXY_CLIENT = httpx.AsyncClient(timeout=300.0)
Comment thread
Ashwal-Microsoft marked this conversation as resolved.
Outdated


@app.api_route(
"/api/{path:path}",
methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"],
)
async def proxy_api(request: Request, path: str):
Comment thread
Ashwal-Microsoft marked this conversation as resolved.
Outdated
"""Proxy HTTP API requests from the browser to the internal backend."""
target_url = f"{BACKEND_API_URL}/api/{path}"
if request.url.query:
target_url = f"{target_url}?{request.url.query}"

# Forward all headers except 'host' (would confuse the backend)
headers = {
k: v for k, v in request.headers.items()
if k.lower() != "host"
}

body = await request.body()

response = await _PROXY_CLIENT.request(
method=request.method,
url=target_url,
headers=headers,
content=body,
)
Comment on lines +168 to +175

# Strip hop-by-hop headers that must not be forwarded
excluded_headers = {
"content-encoding", "transfer-encoding", "connection",
"keep-alive", "proxy-authenticate", "proxy-authorization",
"te", "trailers", "upgrade",
}
forwarded_headers = {
k: v for k, v in response.headers.items()
if k.lower() not in excluded_headers
}
Comment thread
Ashwal-Microsoft marked this conversation as resolved.
Outdated

return Response(
content=response.content,
status_code=response.status_code,
headers=forwarded_headers,
)
Comment thread
Ashwal-Microsoft marked this conversation as resolved.
Outdated
Comment on lines +170 to +193


# ---------------------------------------------------------------------------
# SPA catch-all (must be last)
# ---------------------------------------------------------------------------

@app.get("/{full_path:path}")
async def serve_app(full_path: str):
# Remediation: normalize and check containment before serving
Expand Down
4 changes: 3 additions & 1 deletion src/frontend/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,6 @@ uvicorn[standard]
jinja2
azure-identity
python-dotenv
python-multipart
python-multipart
httpx
websockets
129 changes: 86 additions & 43 deletions src/frontend/src/api/WebSocketService.tsx
Original file line number Diff line number Diff line change
@@ -1,64 +1,107 @@
import { getApiUrl } from '../api/config';
import { getApiUrl, headerBuilder } from '../api/config';

// WebSocketService.ts
// Polling-based status stream service that preserves the existing event interface.
type EventHandler = (data: any) => void;

class WebSocketService {
private socket: WebSocket | null = null;
private pollInterval: ReturnType<typeof setInterval> | null = null;
private isConnected = false;
private activeBatchId: string | null = null;
private lastKnownStatus: Record<string, string> = {};
private eventHandlers: Record<string, EventHandler[]> = {};

connect(batch_id: string): void {
let apiUrl = getApiUrl();
console.log('API URL: websocket', apiUrl);
if (apiUrl) {
apiUrl = apiUrl.replace(/^https?/, match => match === "https" ? "wss" : "ws");
} else {
throw new Error('API URL is null');
private async pollBatchSummary(batchId: string): Promise<void> {
const apiUrl = getApiUrl();
if (!apiUrl) {
this._emit('error', new Error('API URL is null'));
return;
}
console.log('Connecting to WebSocket:', apiUrl);
if (this.socket) return; // Prevent duplicate connections
this.socket = new WebSocket(`${apiUrl}/socket/${batch_id}`);

this.socket.onopen = () => {
console.log('WebSocket connection opened.');
this._emit('open', undefined);
};

this.socket.onmessage = (event: MessageEvent) => {
try {
const data = JSON.parse(event.data);
this._emit('message', data);
} catch (err) {
console.error('Error parsing message:', err);

try {
const response = await fetch(`${apiUrl}/batch-summary/${batchId}`, {
headers: headerBuilder({}),
});

if (!response.ok) {
throw new Error(`Failed to fetch batch status: ${response.status}`);
}
};

this.socket.onerror = (error: Event) => {
console.error('WebSocket error:', error);
const payload = await response.json();
const files = payload?.files || [];
let allFilesTerminal = files.length > 0;

for (const file of files) {
const fileId = file?.file_id;
const status = (file?.status || '').toLowerCase();
if (!fileId || !status) {
continue;
}

if (!['completed', 'failed', 'error'].includes(status)) {
allFilesTerminal = false;
}

const previousStatus = this.lastKnownStatus[fileId];
if (previousStatus !== status) {
this.lastKnownStatus[fileId] = status;

this._emit('message', {
batch_id: batchId,
file_id: fileId,
agent_type: 'Polling agent',
agent_message: `Status changed to ${status}`,
process_status: status,
file_result: file?.file_result || null,
});
}
}

if (allFilesTerminal) {
this.disconnect();
}
} catch (error) {
this._emit('error', error);
};
}
}

connect(batch_id: string): void {
if (this.isConnected && this.activeBatchId === batch_id) return;

this.disconnect();

this.isConnected = true;
this.activeBatchId = batch_id;
this.lastKnownStatus = {};
this._emit('open', undefined);

this.socket.onclose = (event: CloseEvent) => {
console.log('WebSocket closed:', event);
this._emit('close', event);
this.socket = null;
};
// Poll once immediately, then at a fixed interval.
void this.pollBatchSummary(batch_id);
this.pollInterval = setInterval(() => {
if (this.isConnected && this.activeBatchId) {
void this.pollBatchSummary(this.activeBatchId);
}
}, 3000);
}
Comment on lines +67 to 84

disconnect(): void {
if (this.socket) {
this.socket.close();
this.socket = null;
console.log('WebSocket connection closed manually.');
if (this.pollInterval) {
clearInterval(this.pollInterval);
this.pollInterval = null;
}

const wasConnected = this.isConnected;
this.isConnected = false;
this.activeBatchId = null;
this.lastKnownStatus = {};

if (wasConnected) {
this._emit('close', { reason: 'polling_stopped' });
}
}

send(data: any): void {
if (this.socket && this.socket.readyState === WebSocket.OPEN) {
this.socket.send(JSON.stringify(data));
} else {
console.error('WebSocket is not open. Cannot send:', data);
}
// Polling transport is read-only from client perspective.
console.debug('send() is ignored in polling mode:', data);
}

on(event: string, handler: EventHandler): void {
Expand Down
Loading
Loading