Skip to content

Commit 4cef223

Browse files
author
Gerit Wagner
committed
load: revise entrytype setting
1 parent 39adf5c commit 4cef223

2 files changed

Lines changed: 20 additions & 9 deletions

File tree

colrev/loader/load_utils.py

Lines changed: 12 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -160,10 +160,15 @@ def bib_entrytype_setter(entrytype: dict) -> None:
160160
entrytype[Fields.ENTRYTYPE] = ENTRYTYPES.MISC
161161

162162

163+
def _noop_entrytype_setter(entrytype: dict) -> None:
164+
"""Default for non-BibTeX loaders: do nothing."""
165+
return
166+
167+
163168
def load( # type: ignore
164169
filename: Path,
165170
*,
166-
entrytype_setter: typing.Callable = lambda x: x,
171+
entrytype_setter: typing.Callable[[dict], None] | None = None,
167172
field_mapper: typing.Callable = lambda x: x,
168173
id_labeler: typing.Callable = lambda x: x,
169174
unique_id_field: str = "",
@@ -183,6 +188,8 @@ def load( # type: ignore
183188

184189
if filename.suffix == ".bib":
185190
parser = colrev.loader.bib.BIBLoader # type: ignore
191+
if entrytype_setter is None:
192+
entrytype_setter = bib_entrytype_setter
186193
elif filename.suffix in [".csv", ".xls", ".xlsx"]:
187194
parser = colrev.loader.table.TableLoader # type: ignore
188195
elif filename.suffix == ".ris":
@@ -198,6 +205,10 @@ def load( # type: ignore
198205
else:
199206
raise NotImplementedError(f"Unsupported file type: {filename.suffix}")
200207

208+
# For non-bib files, if still None, use a no-op setter
209+
if entrytype_setter is None:
210+
entrytype_setter = _noop_entrytype_setter
211+
201212
return parser(
202213
filename=filename,
203214
entrytype_setter=entrytype_setter,

colrev/loader/loader.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -63,14 +63,14 @@ def _set_entrytypes(self, records_dict: dict) -> None:
6363
assert all(
6464
Fields.ENTRYTYPE in r for r in records_dict.values()
6565
), "ENTRYTYPE not set in all records"
66-
invalid_entrytypes = [
67-
r[Fields.ENTRYTYPE]
68-
for r in records_dict.values()
69-
if r[Fields.ENTRYTYPE] not in ENTRYTYPES.get_all()
70-
]
71-
assert (
72-
len(invalid_entrytypes) == 0
73-
), f"Invalid ENTRYTYPE in some records: {invalid_entrytypes}"
66+
67+
for r in records_dict.values():
68+
if r[Fields.ENTRYTYPE] in ENTRYTYPES.get_all():
69+
continue
70+
self.logger.warning(
71+
f"Invalid ENTRYTYPE in {r[Fields.ID]}: {r[Fields.ENTRYTYPE]}, setting to misc"
72+
)
73+
r[Fields.ENTRYTYPE] = ENTRYTYPES.MISC
7474

7575
def _set_fields(self, records_dict: dict) -> None:
7676
for record_dict in records_dict.values():

0 commit comments

Comments
 (0)