Skip to content

Commit 32c9d8f

Browse files
committed
bids tsv typing
1 parent eb62f84 commit 32c9d8f

6 files changed

Lines changed: 59 additions & 42 deletions

File tree

python/lib/physio/hed.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def build_hed_tag_groups(hed_union: Sequence[DbHedSchemaNode], hed_string: str)
121121
return tag_groups
122122

123123

124-
def standardize_row_columns(row: dict[str, str | None]) -> dict[str, str | None]:
124+
def standardize_row_columns(row: dict[str, str | None]) -> dict[str, str]:
125125
"""
126126
Standardizes LORIS-recognized events.tsv columns to their DB column name
127127
@@ -130,7 +130,7 @@ def standardize_row_columns(row: dict[str, str | None]) -> dict[str, str | None]
130130
:return: Standardized row
131131
"""
132132

133-
standardized_row: dict[str, Any] = {}
133+
standardized_row: dict[str, str] = {}
134134
recognized_event_fields = [
135135
'Onset', 'Duration', 'TrialType',
136136
'ResponseTime', 'EventCode',

python/loris_bids_importer/src/loris_bids_importer/events.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,9 @@ def insert_bids_events_file(
116116
for row in events_file.rows:
117117
# has additional fields?
118118
additional_fields: dict[str, str] = {}
119-
for field in row.data:
120-
if field not in known_fields and str(row.data[field]).lower() != 'nan':
121-
additional_fields[field] = row.data[field]
119+
for field, value in row.data.items():
120+
if field not in known_fields and value is not None and value.lower() != 'nan':
121+
additional_fields[field] = value
122122

123123
# insert one event and get its db id
124124
task_event = insert_physio_task_event(
@@ -137,8 +137,9 @@ def insert_bids_events_file(
137137

138138
# Insert HED tags after filtering out inherited tags from events.json, so that they are
139139
# not "duplicated"
140-
if row.data.get('HED') is not None and len(row.data['HED']) > 0 and row.data['HED'] != 'n/a':
141-
tag_groups = build_hed_tag_groups(hed_union, row.data['HED'])
140+
hed = row.data.get('HED')
141+
if hed is not None and len(hed) > 0 and hed != 'n/a':
142+
tag_groups = build_hed_tag_groups(hed_union, hed)
142143
tag_groups_without_inherited = filter_inherited_tags(
143144
row.data, tag_groups, dataset_tag_dict, file_tag_dict
144145
)

python/loris_bids_importer/src/loris_bids_importer/validation/subjects.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -131,24 +131,24 @@ def get_bids_participant_row_sex(env: Env, participant: BidsParticipantTsvRow) -
131131
Raise an exception if a sex is specified but does not exist in LORIS.
132132
"""
133133

134-
if 'sex' not in participant.data:
134+
if participant.sex is None:
135135
return None
136136

137-
tsv_participant_sex = participant.data['sex'].lower()
137+
participant_sex = participant.sex.lower()
138138

139-
if tsv_participant_sex in ['m', 'male']:
139+
if participant_sex in ['m', 'male']:
140140
sex_name = 'Male'
141-
elif tsv_participant_sex in ['f', 'female']:
141+
elif participant_sex in ['f', 'female']:
142142
sex_name = 'Female'
143-
elif tsv_participant_sex in ['o', 'other']:
143+
elif participant_sex in ['o', 'other']:
144144
sex_name = 'Other'
145145
else:
146-
sex_name = participant.data['sex']
146+
sex_name = participant.sex
147147

148148
sex = try_get_sex_with_name(env.db, sex_name)
149149
if sex is None:
150150
raise Exception(
151-
f"No LORIS sex found for the BIDS participants.tsv sex name or alias '{participant.data['sex']}'."
151+
f"No LORIS sex found for the BIDS participants.tsv sex name or alias '{participant.sex}'."
152152
)
153153

154154
return sex.name
@@ -160,22 +160,22 @@ def get_bids_participant_row_site(env: Env, participant: BidsParticipantTsvRow)
160160
specified or does not exist in LORIS.
161161
"""
162162

163-
if 'site' not in participant.data:
163+
if participant.site is None:
164164
raise Exception(
165165
"No 'site' column found in the BIDS participants.tsv file, this field is required to create candidates or"
166166
" sessions. "
167167
)
168168

169-
site = try_get_site_with_name(env.db, participant.data['site'])
169+
site = try_get_site_with_name(env.db, participant.site)
170170
if site is not None:
171171
return site
172172

173-
site = try_get_site_with_alias(env.db, participant.data['site'])
173+
site = try_get_site_with_alias(env.db, participant.site)
174174
if site is not None:
175175
return site
176176

177177
raise Exception(
178-
f"No site found for the BIDS participants.tsv site name or alias '{participant.data['site']}'."
178+
f"No site found for the BIDS participants.tsv site name or alias '{participant.site}'."
179179
)
180180

181181

@@ -185,20 +185,20 @@ def get_bids_participant_row_project(env: Env, participant: BidsParticipantTsvRo
185185
specified or does not exist in LORIS.
186186
"""
187187

188-
if 'project' not in participant.data:
188+
if participant.project is None:
189189
raise Exception(
190190
"No 'project' column found in the BIDS participants.tsv file, this field is required to create candidates"
191191
" or sessions. "
192192
)
193193

194-
project = try_get_project_with_name(env.db, participant.data['project'])
194+
project = try_get_project_with_name(env.db, participant.project)
195195
if project is not None:
196196
return project
197197

198-
project = try_get_project_with_alias(env.db, participant.data['project'])
198+
project = try_get_project_with_alias(env.db, participant.project)
199199
if project is not None:
200200
return project
201201

202202
raise Exception(
203-
f"No project found for the BIDS participants.tsv project name or alias '{participant.data['project']}'."
203+
f"No project found for the BIDS participants.tsv project name or alias '{participant.project}'."
204204
)

python/loris_bids_reader/src/loris_bids_reader/files/participants.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,24 +16,35 @@ class BidsParticipantTsvRow(BidsTsvRow):
1616
"""
1717

1818
participant_id: str
19-
birth_date: date | None
19+
project: str | None
20+
site: str | None
2021
cohort: str | None
22+
birth_date: date | None
23+
sex: str | None
2124

22-
def __init__(self, data: dict[str, str]):
25+
def __init__(self, data: dict[str, str | None]):
2326
super().__init__(data)
24-
self.participant_id = data['participant_id'].removeprefix('sub-')
25-
self.birth_date = self._read_birth_date()
27+
participant_id = self.data.get('participant_id')
28+
if participant_id is None:
29+
raise Exception("Missing participant_id field in `participants.tsv` file.")
30+
31+
self.participant_id = participant_id.removeprefix('sub-')
32+
self.project = self.data.get('project')
33+
self.site = self.data.get('site')
2634
self.cohort = self._read_cohort()
35+
self.birth_date = self._read_birth_date()
36+
self.sex = self.data.get('sex')
2737

2838
def _read_birth_date(self) -> date | None:
2939
"""
3040
Read the date of birth field from this row data.
3141
"""
3242

3343
for birth_date_field_name in ['date_of_birth', 'birth_date', 'dob']:
34-
if birth_date_field_name in self.data:
44+
birth_date_string = self.data.get(birth_date_field_name)
45+
if birth_date_string is not None:
3546
try:
36-
return dateutil.parser.parse(self.data[birth_date_field_name]).date()
47+
return dateutil.parser.parse(birth_date_string).date()
3748
except ParserError:
3849
pass
3950

python/loris_bids_reader/src/loris_bids_reader/files/scans.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,16 @@ def get_acquisition_time(self) -> datetime | None:
1919
Get the acquisition time of the acquisition file.
2020
"""
2121

22-
if 'acq_time' in self.data:
23-
# the variable name could be mri_acq_time, but is eeg originally.
24-
eeg_acq_time = self.data['acq_time']
25-
26-
if eeg_acq_time == 'n/a':
22+
acq_time_string = self.data.get('acq_time')
23+
if acq_time_string is not None:
24+
if acq_time_string == 'n/a':
2725
return None
2826

2927
try:
30-
eeg_acq_time = dateutil.parser.parse(eeg_acq_time)
28+
acq_time = dateutil.parser.parse(acq_time_string)
3129
except ValueError as e:
32-
raise Exception(f"Could not convert acquisition time {eeg_acq_time}' to datetime: {e}")
33-
return eeg_acq_time
30+
raise Exception(f"Could not convert acquisition time {acq_time_string}' to datetime: {e}")
31+
return acq_time
3432

3533
return None
3634

@@ -43,8 +41,9 @@ def get_age_at_scan(self) -> str | None:
4341
age_header_list = ['age', 'age_at_scan', 'age_acq_time']
4442

4543
for header_name in age_header_list:
46-
if header_name in self.data:
47-
return self.data[header_name].strip()
44+
age_string = self.data.get(header_name)
45+
if age_string is not None:
46+
return age_string.strip()
4847

4948
return None
5049

@@ -64,7 +63,7 @@ def get_row(self, file_path: Path) -> BidsScanTsvRow | None:
6463
Get the row corresponding to the given file path.
6564
"""
6665

67-
return find(self.rows, lambda row: file_path.name in row.data['filename'])
66+
return find(self.rows, lambda row: file_path.name == row.data['filename'])
6867

6968
def set_row(self, scan: BidsScanTsvRow):
7069
"""

python/loris_bids_reader/src/loris_bids_reader/tsv.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import csv
22
from pathlib import Path
3-
from typing import Any, Generic, TypeVar
3+
from typing import Generic, TypeVar
44

55
from loris_utils.parse import nullify_empty_string
66

@@ -11,9 +11,9 @@ class BidsTsvRow:
1111
Documentation: https://bids-specification.readthedocs.io/en/stable/common-principles.html#tabular-files
1212
"""
1313

14-
data: dict[str, Any]
14+
data: dict[str, str | None]
1515

16-
def __init__(self, data: dict[str, Any]):
16+
def __init__(self, data: dict[str, str | None]):
1717
self.data = data
1818

1919

@@ -33,9 +33,15 @@ def __init__(self, model: type[T], path: Path):
3333
self.path = path
3434
self.rows = []
3535

36+
# The 'utf-8-sig' encoding is used to support some datasets where metadata files may contain
37+
# a byte-order mark (BOM).
3638
with open(self.path, encoding='utf-8-sig') as file:
3739
reader = csv.DictReader(file, delimiter='\t')
3840
for row in reader:
41+
# Skip empty lines (such as trailing newlines).
42+
if row == {}:
43+
continue
44+
3945
row = {key: nullify_empty_string(value) for key, value in row.items()}
4046
self.rows.append(model(row))
4147

0 commit comments

Comments
 (0)