Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 39 additions & 6 deletions src/murfey/server/api/session_shared.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os
from pathlib import Path
from typing import Dict, List

Expand Down Expand Up @@ -136,11 +137,40 @@ def get_foil_hole(session_id: int, fh_name: int, db) -> Dict[str, int]:
return {f[1].tag: f[0].id for f in foil_holes}


def find_upstream_visits(session_id: int, db: SQLModelSession):
def find_upstream_visits(session_id: int, db: SQLModelSession, max_depth: int = 2):
"""
Returns a nested dictionary, in which visits and the full paths to their directories
are further grouped by instrument name.
"""

def _recursive_search(
dirpath: str | Path,
search_string: str,
partial_match=True,
max_depth: int = 1,
):
# Stop recursing for this route once max depth hits 0
if max_depth == 0:
return
for entry in os.scandir(dirpath):
if entry.is_dir():
# Update dictionary with match and stop recursing for this route
if (
search_string in entry.name
if partial_match
else search_string == entry.name
):
current_upstream_visits[entry.name] = Path(entry.path)
Comment thread
tieneupin marked this conversation as resolved.
Outdated
else:
# Continue searching down this route until max depth is reached
_recursive_search(
dirpath=entry.path,
search_string=search_string,
partial_match=partial_match,
max_depth=max_depth - 1,
)
continue
Comment thread
tieneupin marked this conversation as resolved.
Outdated

murfey_session = db.exec(
select(MurfeySession).where(MurfeySession.id == session_id)
).one()
Expand All @@ -155,11 +185,14 @@ def find_upstream_visits(session_id: int, db: SQLModelSession):
upstream_instrument,
upstream_data_dir,
) in machine_config.upstream_data_directories.items():
# Looks for visit name in file path
current_upstream_visits = {}
for visit_path in Path(upstream_data_dir).glob(f"{visit_name.split('-')[0]}-*"):
if visit_path.is_dir():
current_upstream_visits[visit_path.name] = visit_path
# Recursively look for matching visit names under current directory
current_upstream_visits: dict[str, Path] = {}
_recursive_search(
dirpath=upstream_data_dir,
search_string=f"{visit_name.split('-')[0]}-",
partial_match=True,
max_depth=max_depth,
)
upstream_visits[upstream_instrument] = current_upstream_visits
return upstream_visits

Expand Down
8 changes: 6 additions & 2 deletions tests/server/api/test_session_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
from tests.conftest import ExampleVisit


@pytest.mark.parametrize("recurse", (True, False))
def test_find_upstream_visits(
mocker: MockerFixture,
tmp_path: Path,
# murfey_db_session,
recurse: bool,
):
# Get the visit, instrument name, and session ID
visit_name_root = f"{ExampleVisit.proposal_code}{ExampleVisit.proposal_number}"
Expand Down Expand Up @@ -40,7 +41,10 @@ def test_find_upstream_visits(
# Only directories should be picked up
upstream_visit.mkdir(parents=True, exist_ok=True)
upstream_visits[upstream_instrument] = {upstream_visit.stem: upstream_visit}
upstream_data_dirs[upstream_instrument] = upstream_visit.parent
# Check that the function can cope with recursive searching
upstream_data_dirs[upstream_instrument] = (
upstream_visit.parent.parent if recurse else upstream_visit.parent
)
else:
upstream_visit.parent.mkdir(parents=True, exist_ok=True)
upstream_visit.touch(exist_ok=True)
Expand Down