Skip to content

Commit 3503808

Browse files
committed
{Search for wildcards function updated}
1 parent 861ca07 commit 3503808

3 files changed

Lines changed: 191 additions & 16 deletions

File tree

datashuttle/utils/folders.py

Lines changed: 56 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,14 @@
1717
from datashuttle.utils.custom_types import TopLevelFolder
1818

1919
import glob
20+
import re
21+
from datetime import datetime
2022
from pathlib import Path
2123

2224
from datashuttle.configs import canonical_folders, canonical_tags
2325
from datashuttle.utils import ssh, utils, validation
2426
from datashuttle.utils.custom_exceptions import NeuroBlueprintError
27+
from datashuttle.utils.utils import get_values_from_bids_formatted_name
2528

2629
# -----------------------------------------------------------------------------
2730
# Create Folders
@@ -401,27 +404,65 @@ def search_for_wildcards(
401404
"""
402405
new_all_names: List[str] = []
403406
for name in all_names:
404-
if canonical_tags.tags("*") in name:
405-
name = name.replace(canonical_tags.tags("*"), "*")
406-
407-
matching_names: List[str]
407+
if canonical_tags.tags("*") in name or "@DATETO@" in name:
408+
search_str = name.replace(canonical_tags.tags("*"), "*")
409+
# If a date-range tag is present, extract dates and update the search string.
410+
if "@DATETO@" in name:
411+
m = re.search(r"(\d{8})@DATETO@(\d{8})", name)
412+
if not m:
413+
raise ValueError(
414+
"Invalid date range format in name: " + name
415+
)
416+
start_str, end_str = m.groups()
417+
try:
418+
start_date = datetime.strptime(start_str, "%Y%m%d")
419+
end_date = datetime.strptime(end_str, "%Y%m%d")
420+
except ValueError as e:
421+
raise ValueError("Invalid date in date range: " + str(e))
422+
# Replace the date-range substring with "date-*"
423+
search_str = re.sub(r"\d{8}@DATETO@\d{8}", "date-*", name)
424+
# Use the helper function to perform the glob search.
408425
if sub:
409-
matching_names = search_sub_or_ses_level( # type: ignore
410-
cfg, base_folder, local_or_central, sub, search_str=name
426+
matching_names: List[str] = search_sub_or_ses_level(
427+
cfg,
428+
base_folder,
429+
local_or_central,
430+
sub,
431+
search_str=search_str,
411432
)[0]
412433
else:
413-
matching_names = search_sub_or_ses_level( # type: ignore
414-
cfg, base_folder, local_or_central, search_str=name
434+
matching_names = search_sub_or_ses_level(
435+
cfg, base_folder, local_or_central, search_str=search_str
415436
)[0]
416-
437+
# If a date-range tag was provided, further filter the results.
438+
if "@DATETO@" in name:
439+
filtered_names: List[str] = []
440+
for candidate in matching_names:
441+
candidate_basename = (
442+
candidate
443+
if isinstance(candidate, str)
444+
else candidate.name
445+
)
446+
values_list = get_values_from_bids_formatted_name(
447+
[candidate_basename], "date"
448+
)
449+
if not values_list:
450+
continue
451+
candidate_date_str = values_list[0]
452+
try:
453+
candidate_date = datetime.strptime(
454+
candidate_date_str, "%Y%m%d"
455+
)
456+
except ValueError:
457+
continue
458+
if start_date <= candidate_date <= end_date:
459+
filtered_names.append(candidate)
460+
matching_names = filtered_names
417461
new_all_names += matching_names
418462
else:
419463
new_all_names += [name]
420-
421-
new_all_names = list(
422-
set(new_all_names)
423-
) # remove duplicate names in case of wildcard overlap
424-
464+
# Remove duplicates in case of wildcard overlap.
465+
new_all_names = list(set(new_all_names))
425466
return new_all_names
426467

427468

@@ -440,7 +481,7 @@ def search_sub_or_ses_level(
440481
search_str: str = "*",
441482
verbose: bool = True,
442483
return_full_path: bool = False,
443-
) -> Tuple[List[str] | List[Path], List[str]]:
484+
) -> Tuple[Union[List[str], List[Path]], List[str]]:
444485
"""
445486
Search project folder at the subject or session level.
446487
Only returns folders

datashuttle/utils/validation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ def replace_tags_in_regexp(regexp: str) -> str:
321321
Note `replace_date_time_tags_in_name()` operates in place on a list.
322322
"""
323323
regexp_list = [regexp]
324-
date_regexp = "\d\d\d\d\d\d\d\d"
324+
date_regexp = r"\d{8}"
325325
time_regexp = "\d\d\d\d\d\d"
326326

327327
formatting.replace_date_time_tags_in_name(

tests/test_date_search_range.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
import glob
2+
import os
3+
import re
4+
import shutil
5+
import tempfile
6+
from pathlib import Path
7+
from typing import List
8+
9+
import pytest
10+
11+
from datashuttle.utils.folders import search_for_wildcards
12+
13+
14+
# Dummy implementation for canonical_tags
15+
class DummyCanonicalTags:
    """Minimal stand-in for `canonical_tags` exposing only `tags`."""

    @staticmethod
    def tags(x: str) -> str:
        """Map the wildcard key "*" to "@*@"; echo any other key back."""
        return "@*@" if x == "*" else x
21+
22+
23+
# Patch canonical_tags so that tags("*") returns "@*@"
24+
@pytest.fixture(autouse=True)
def patch_canonical_tags(monkeypatch):
    """
    Replace `canonical_tags.tags` with `DummyCanonicalTags.tags`.

    Declared with `autouse=True`, so it is applied to every test in
    this module without being requested explicitly.
    """
    from datashuttle.configs import canonical_tags as tags_module

    monkeypatch.setattr(tags_module, "tags", DummyCanonicalTags.tags)
29+
30+
31+
# Dummy implementation for search_sub_or_ses_level that simply performs globbing.
32+
def dummy_search_sub_or_ses_level(
    cfg, base_folder: Path, local_or_central: str, *args, search_str: str
):
    """
    Glob-only replacement for `search_sub_or_ses_level`.

    `cfg`, `local_or_central` and any extra positional arguments are
    accepted but ignored; only `base_folder` and `search_str` are used.

    Returns
    -------
    A 1-tuple whose single element is the sorted list of paths under
    `base_folder` that match the glob pattern `search_str`.
    """
    found: List[str] = glob.glob(os.path.join(str(base_folder), search_str))
    found.sort()
    return (found,)
38+
39+
40+
# Patch search_sub_or_ses_level in the module where search_for_wildcards is defined.
41+
@pytest.fixture(autouse=True)
42+
def patch_search_sub_or_ses_level(monkeypatch):
43+
monkeypatch.setattr(
44+
"datashuttle.utils.folders.search_sub_or_ses_level",
45+
dummy_search_sub_or_ses_level,
46+
)
47+
48+
49+
# Dummy implementation for get_values_from_bids_formatted_name.
50+
def dummy_get_values_from_bids_formatted_name(
    all_names: List[str], key: str
) -> List[str]:
    """
    Test double for `get_values_from_bids_formatted_name`.

    Mirrors the way `search_for_wildcards` calls the helper: it is
    handed a list of names plus a key, and indexes the result with
    `[0]`, so a list of extracted values is returned.

    Parameters
    ----------
    all_names
        Names (or path strings) expected to contain "<key>-YYYYMMDD",
        e.g. "sub-01_date-20250306".
    key
        The key whose 8-digit value is extracted, e.g. "date".

    Returns
    -------
    The 8-digit value of `key` for each name that contains one, in
    input order. Names without the key contribute nothing, so the
    result may be empty.
    """
    pattern = re.compile(rf"{re.escape(key)}-(\d{{8}})")
    found = (pattern.search(name) for name in all_names)
    return [m.group(1) for m in found if m]


# Patch get_values_from_bids_formatted_name where it is looked up.
@pytest.fixture(autouse=True)
def patch_get_values_from_bids(monkeypatch):
    """
    Substitute the dummy helper inside `datashuttle.utils.folders`.

    `folders.py` binds the helper with a `from ... import`, so the
    replacement has to be installed on the `folders` module itself;
    patching `datashuttle.utils.utils` would leave the name used by
    `search_for_wildcards` untouched.
    """
    monkeypatch.setattr(
        "datashuttle.utils.folders.get_values_from_bids_formatted_name",
        dummy_get_values_from_bids_formatted_name,
    )
65+
66+
67+
# Fixture to create a temporary directory with a simulated folder structure.
68+
@pytest.fixture
def temp_project_dir() -> Path:  # type: ignore
    """
    Yield a temporary directory populated with dated subject folders.

    Six folders named "sub-01_date-YYYYMMDD" are created, covering
    20250305 through 20250310. The directory is removed on teardown,
    and also if folder creation fails part-way through setup.
    """
    temp_dir = Path(tempfile.mkdtemp())
    folder_dates = [
        "20250305",
        "20250306",
        "20250307",
        "20250308",
        "20250309",
        "20250310",
    ]
    # try/finally ensures the directory is cleaned up even when setup
    # raises before the `yield` is reached.
    try:
        for date_str in folder_dates:
            (temp_dir / f"sub-01_date-{date_str}").mkdir()
        yield temp_dir
    finally:
        shutil.rmtree(temp_dir)
85+
86+
87+
def test_date_range_wildcard(temp_project_dir: Path):
    """
    A date-range pattern such as "sub-01_20250306@DATETO@20250309"
    must return only the folders whose embedded date lies between
    20250306 and 20250309, both ends included.
    """

    class Configs:
        pass

    search_name = "sub-01_20250306@DATETO@20250309"
    result = search_for_wildcards(
        Configs(), temp_project_dir, "local", [search_name]
    )

    # Collect the date embedded in each returned folder name.
    date_matches = (
        re.search(r"date-(\d{8})", os.path.basename(folder))
        for folder in result
    )
    found_dates = {m.group(1) for m in date_matches if m}

    assert found_dates == {"20250306", "20250307", "20250308", "20250309"}
115+
116+
117+
def test_simple_wildcard(temp_project_dir: Path):
    """
    A plain wildcard pattern such as "sub-01_@*@" must return every
    folder created by the fixture.
    """

    class Configs:
        pass

    result = search_for_wildcards(
        Configs(), temp_project_dir, "local", ["sub-01_@*@"]
    )

    # The fixture creates six dated folders; all of them should match.
    assert len(result) == 6

0 commit comments

Comments
 (0)