Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/murfey/instrument_server/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,15 @@ def stop_multigrid_watcher(session_id: MurfeySessionID, label: str):
return {"success": True}


@router.get("/sessions/{session_id}/multigrid_controller/status")
def check_multigrid_controller_exists(
session_id: MurfeySessionID,
):
if controllers.get(session_id, None) is not None:
return {"exists": True}
return {"exists": False}


@router.post("/sessions/{session_id}/multigrid_controller/visit_end_time")
def update_multigrid_controller_visit_end_time(
session_id: MurfeySessionID, end_time: datetime
Expand Down
87 changes: 56 additions & 31 deletions src/murfey/server/api/instrument.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import datetime
import logging
from pathlib import Path
from typing import Annotated, List, Optional
from typing import Annotated, Any, List, Optional
from urllib.parse import quote

import aiohttp
Expand Down Expand Up @@ -101,6 +101,31 @@
return {"active": response.status == 200}


@router.get("/sessions/{session_id}/multigrid_controller/status")
async def check_multigrid_controller_exists(session_id: MurfeySessionID, db=murfey_db):
session = db.exec(select(Session).where(Session.id == session_id)).one()
instrument_name = session.instrument_name
machine_config = get_machine_config(instrument_name=instrument_name)[
instrument_name
]
if machine_config.instrument_server_url:
log.debug(
f"Submitting request to inspect multigrid controller for session {session_id}"
Comment thread Dismissed
)
async with aiohttp.ClientSession() as clientsession:
async with clientsession.get(
f"{machine_config.instrument_server_url}{url_path_for('api.router', 'check_multigrid_controller_exists', session_id=session_id)}",
headers={
"Authorization": f"Bearer {instrument_server_tokens[session_id]['access_token']}"
},
) as resp:
data: dict[str, Any] = await resp.json()
else:
data = {"detail": "No instrument server URL found"}
log.debug(f"Received response: {data}")
return data


@router.post("/sessions/{session_id}/multigrid_watcher")
async def setup_multigrid_watcher(
session_id: MurfeySessionID, watcher_spec: MultigridWatcherSetup, db=murfey_db
Expand Down Expand Up @@ -165,6 +190,36 @@
return data


@router.post("/sessions/{session_id}/multigrid_controller/visit_end_time")
async def update_visit_end_time(
session_id: MurfeySessionID, end_time: datetime.datetime, db=murfey_db
):
# Load data for session
session_entry = db.exec(select(Session).where(Session.id == session_id)).one()
instrument_name = session_entry.instrument_name

# Update visit end time in database
session_entry.visit_end_time = end_time
db.add(session_entry)
db.commit()

# Update the multigrid controller
data = {}
machine_config = get_machine_config(instrument_name=instrument_name)[
instrument_name
]
if machine_config.instrument_server_url:
async with aiohttp.ClientSession() as clientsession:
async with clientsession.post(
f"{machine_config.instrument_server_url}{url_path_for('api.router', 'update_multigrid_controller_visit_end_time', session_id=session_id)}?end_time={quote(end_time.isoformat())}",
headers={
"Authorization": f"Bearer {instrument_server_tokens[session_id]['access_token']}"
},
) as resp:
data = await resp.json()
return data


class ProvidedProcessingParameters(BaseModel):
dose_per_frame: float
extract_downscale: bool = True
Expand Down Expand Up @@ -397,36 +452,6 @@
return data


@router.post("/sessions/{session_id}/multigrid_controller/visit_end_time")
async def update_visit_end_time(
session_id: MurfeySessionID, end_time: datetime.datetime, db=murfey_db
):
# Load data for session
session_entry = db.exec(select(Session).where(Session.id == session_id)).one()
instrument_name = session_entry.instrument_name

# Update visit end time in database
session_entry.visit_end_time = end_time
db.add(session_entry)
db.commit()

# Update the multigrid controller
data = {}
machine_config = get_machine_config(instrument_name=instrument_name)[
instrument_name
]
if machine_config.instrument_server_url:
async with aiohttp.ClientSession() as clientsession:
async with clientsession.post(
f"{machine_config.instrument_server_url}{url_path_for('api.router', 'update_multigrid_controller_visit_end_time', session_id=session_id)}?end_time={quote(end_time.isoformat())}",
headers={
"Authorization": f"Bearer {instrument_server_tokens[session_id]['access_token']}"
},
) as resp:
data = await resp.json()
return data


@router.post("/sessions/{session_id}/abandon_session")
async def abandon_session(session_id: MurfeySessionID, db=murfey_db):
data = {}
Expand Down
10 changes: 10 additions & 0 deletions src/murfey/util/route_manifest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ murfey.instrument_server.api.router:
path_params: []
methods:
- POST
- path: /sessions/{session_id}/multigrid_controller/status
function: check_multigrid_controller_exists
path_params: []
methods:
- GET
- path: /sessions/{session_id}/stop_rsyncer
function: stop_rsyncer
path_params: []
Expand Down Expand Up @@ -503,6 +508,11 @@ murfey.server.api.instrument.router:
path_params: []
methods:
- POST
- path: /instrument_server/sessions/{session_id}/multigrid_controller/status
function: check_multigrid_controller_exists
path_params: []
methods:
- GET
- path: /instrument_server/sessions/{session_id}/provided_processing_parameters
function: pass_proc_params_to_instrument_server
path_params: []
Expand Down
130 changes: 130 additions & 0 deletions tests/server/api/test_instrument.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
from typing import Literal
from unittest import mock
from unittest.mock import AsyncMock, MagicMock

from fastapi import FastAPI
from fastapi.testclient import TestClient
from pytest_mock import MockerFixture

from murfey.server.api.auth import validate_frontend_session_access, validate_token
from murfey.server.api.instrument import router as backend_router
from murfey.server.murfey_db import murfey_db_session
from murfey.util.api import url_path_for


def mock_aiohttp_clientsession(
mocker: MockerFixture,
method: Literal["get", "post", "delete"] = "get",
json_data={},
status=200,
):
"""
Helper function to patch a aiohttp.ClientSession GET request. This returns a
mocked async context manager with a mocked response that, in turn, returns
the given JSON data and status.

Returns the mocked ClientSession, which can then be inspected to assert that
the expected calls were made.
"""

# Mock out the async response
mock_response = MagicMock()
mock_response.json = AsyncMock(return_value=json_data)
mock_response.status = status

# Mock out the context manager returned by clientsession.get()
mock_context_manager = MagicMock()
mock_context_manager.__aenter__ = AsyncMock(return_value=mock_response)
mock_context_manager.__aexit__ = AsyncMock(return_value=None)

# Mock the client session
mock_clientsession = MagicMock()
mock_clientsession.__aenter__ = AsyncMock(return_value=mock_clientsession)
mock_clientsession.__aexit__ = AsyncMock(return_value=None)

# Assign the context manager to the request method being tested
getattr(mock_clientsession, method.lower()).return_value = mock_context_manager

# Patch 'aiohttp.ClientSession' to return the mocked client session
mocker.patch("aiohttp.ClientSession", return_value=mock_clientsession)

return mock_clientsession, mock_response


def test_check_multigrid_controller_exists(mocker: MockerFixture):
# Set up the objects to mock
instrument_name = "test"
session_id = 1
instrment_server_url = "https://murfey.instrument-server.test"

# Override the database session generator
mock_session = MagicMock()
mock_session.instrument_name = instrument_name
mock_query_result = MagicMock()
mock_query_result.one.return_value = mock_session
mock_db_session = MagicMock()
mock_db_session.exec.return_value = mock_query_result

def mock_get_db_session():
yield mock_db_session

# Mock the machine config
mock_machine_config = MagicMock()
mock_machine_config.instrument_server_url = instrment_server_url
mock_get_machine_config = mocker.patch(
"murfey.server.api.instrument.get_machine_config"
)
mock_get_machine_config.return_value = {
instrument_name: mock_machine_config,
}

# Mock the instrument server tokens dictionary
mock_tokens = mocker.patch(
"murfey.server.api.instrument.instrument_server_tokens",
{session_id: {"access_token": mock.sentinel}},
)

# Mock out the async GET request in the endpoint
mock_clientsession, _ = mock_aiohttp_clientsession(
mocker,
method="get",
json_data={"exists": True},
status=200,
)

# Set up the backend server
backend_app = FastAPI()

# Override validation and database dependencies
backend_app.dependency_overrides[validate_token] = lambda: None
backend_app.dependency_overrides[validate_frontend_session_access] = (
lambda: session_id
)
backend_app.dependency_overrides[murfey_db_session] = mock_get_db_session
backend_app.include_router(backend_router)
backend_server = TestClient(backend_app)

# Construct the URL paths for poking and sending to
backend_url_path = url_path_for(
"api.instrument.router",
"check_multigrid_controller_exists",
session_id=session_id,
)
client_url_path = url_path_for(
"api.router",
"check_multigrid_controller_exists",
session_id=session_id,
)

# Poke the backend
response = backend_server.get(backend_url_path)

# Check that the expected calls were made
mock_db_session.exec.assert_called_once()
mock_get_machine_config.assert_called_once_with(instrument_name=instrument_name)
mock_clientsession.get.assert_called_once_with(
f"{instrment_server_url}{client_url_path}",
headers={"Authorization": f"Bearer {mock_tokens[session_id]['access_token']}"},
)
assert response.status_code == 200
assert response.json() == {"exists": True}