Skip to content

Commit c7ff230

Browse files
committed
Improved dataloading functions with guaranteed matching RSRP and SINR filenames.
1 parent 3c38eab commit c7ff230

2 files changed

Lines changed: 26 additions & 32 deletions

File tree

src/ho_optim_drl/dataloader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def get_filenames(dir_: str, prefix: str) -> list[str]:
2727
list[str]
2828
A list of filenames in the specified directory.
2929
"""
30-
return [f for f in os.listdir(dir_) if prefix in f]
30+
return sorted([f for f in os.listdir(dir_) if prefix in f])
3131

3232

3333
def load_preprocess_dataset(

src/ho_optim_drl/utils.py

Lines changed: 25 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4,33 +4,15 @@
44
import numpy as np
55

66

7-
def extract_speed(filename: str) -> int:
8-
"""
9-
Extract speed from filename.
10-
11-
Parameters
12-
----------
13-
filename : str
14-
Filename string.
15-
16-
Returns
17-
-------
18-
int
19-
Speed in km/h.
20-
"""
21-
match = re.search(r"(\d+)kmh", filename)
22-
if not match:
23-
raise ValueError(f"Filename '{filename}' does not contain speed information.")
24-
return int(match.group(1))
25-
26-
277
def filenames_speed_filter(
288
rsrp_filenames: list[str],
299
sinr_filenames: list[str],
3010
use_speed_list: list[int],
3111
) -> tuple[list[str], list[str], list[int]]:
3212
"""
33-
Filter the files based on the speed list.
13+
Return sorted, speed-filtered RSRP- and SINR-filenames together with their speeds.
14+
Only pairs that exist **in both** input lists and whose speed is contained in
15+
``use_speed_list`` are kept.
3416
3517
Parameters
3618
----------
@@ -44,17 +26,29 @@ def filenames_speed_filter(
4426
Returns
4527
-------
4628
tuple[list[str], list[str], list[int]]
47-
Filtered RSRP filenames, SINR filenames, and speeds
29+
Filtered RSRP filenames, SINR filenames, and speeds.
4830
"""
49-
# Extract speed from filename "ue999kmh_...mat"
50-
speeds = [extract_speed(f) for f in rsrp_filenames]
51-
52-
# Filter the dataset based on the speed
53-
idxs = [i for i, speed in enumerate(speeds) if speed in use_speed_list]
54-
rsrp_filenames = [rsrp_filenames[i] for i in idxs]
55-
sinr_filenames = [sinr_filenames[i] for i in idxs]
56-
speeds = [speeds[i] for i in idxs]
57-
return rsrp_filenames, sinr_filenames, speeds
31+
32+
def _parse_key(fname: str) -> tuple[int, int]:
33+
m = re.search(r"_(\d+)kmh_(\d+)\.mat$", fname)
34+
if not m:
35+
raise ValueError(f"Filename does not match expected pattern: {fname}")
36+
speed, uid = map(int, m.groups())
37+
return speed, uid
38+
39+
rsrp_dict = {_parse_key(f): f for f in rsrp_filenames}
40+
sinr_dict = {_parse_key(f): f for f in sinr_filenames}
41+
42+
common_keys = [
43+
key for key in rsrp_dict.keys() & sinr_dict.keys() if key[0] in use_speed_list
44+
]
45+
common_keys.sort(key=lambda k: (k[0], k[1]))
46+
47+
rsrp_out = [rsrp_dict[k] for k in common_keys]
48+
sinr_out = [sinr_dict[k] for k in common_keys]
49+
speeds = [k[0] for k in common_keys]
50+
51+
return rsrp_out, sinr_out, speeds
5852

5953

6054
def get_sync_state(sinr_db: float, q_in_db: float, q_out_db: float):

0 commit comments

Comments
 (0)