Skip to content

Commit efd0a54

Browse files
committed
refactor colrev_records_variable_naming_convention.py
1 parent 14fff19 commit efd0a54

1 file changed

Lines changed: 126 additions & 98 deletions

File tree

colrev/linter/colrev_records_variable_naming_convention.py

Lines changed: 126 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -44,129 +44,157 @@ class RecordsVariableNamingConventionChecker(checkers.BaseChecker):
4444
),
4545
}
4646

47-
# pylint: disable=too-many-branches
48-
@only_required_for_messages(
49-
"colrev-records-variable-naming-convention",
50-
"colrev-record-cannot-be-dict",
51-
"colrev-records-must-be-dict",
52-
)
53-
def visit_assign(self, node: nodes.Assign) -> None:
54-
"""Check variable naming and simple type rules for 'record' and 'records'."""
47+
def _is_supported_assignment(self, node: nodes.Assign) -> bool:
5548
if len(node.targets) != 1: # pragma: no cover
56-
return
49+
return False
5750

58-
target = node.targets[0]
59-
if not hasattr(target, "name"):
60-
return
51+
return hasattr(node.targets[0], "name")
6152

62-
assigned = node.value
53+
def _get_assignment_target_name(self, node: nodes.Assign) -> str | None:
54+
target = node.targets[0]
55+
return getattr(target, "name", None)
6356

64-
# --- 1) ORIGINAL NAMING RULE FOR load_records_dict ---
65-
if (
57+
def _check_load_records_dict_naming(
58+
self, node: nodes.Assign, target_name: str, assigned: nodes.NodeNG
59+
) -> None:
60+
if not (
6661
hasattr(assigned, "func")
6762
and isinstance(assigned.func, nodes.Attribute)
6863
and assigned.func.attrname == "load_records_dict"
6964
):
70-
expected_name = "records"
71-
if hasattr(assigned, "keywords"):
72-
for keyword in assigned.keywords:
73-
if keyword.arg == "header_only":
74-
if getattr(keyword.value, "value", None) is True:
75-
expected_name = "records_headers"
76-
else:
77-
expected_name = "records"
78-
break
79-
80-
if target.name != expected_name:
81-
# Use the symbolic message id
82-
self.add_message("colrev-records-variable-naming-convention", node=node)
83-
84-
# --- 2) SIMPLE TYPE RULES FOR 'record' and 'records' ---
65+
return
66+
67+
expected_name = "records"
68+
if hasattr(assigned, "keywords"):
69+
for keyword in assigned.keywords:
70+
if keyword.arg == "header_only":
71+
if getattr(keyword.value, "value", None) is True:
72+
expected_name = "records_headers"
73+
else:
74+
expected_name = "records"
75+
break
76+
77+
if target_name != expected_name:
78+
self.add_message("colrev-records-variable-naming-convention", node=node)
79+
80+
def _infer_assigned_nodes(self, assigned: nodes.NodeNG) -> list[nodes.NodeNG]:
8581
try:
8682
inferred = list(assigned.infer())
8783
except (InferenceError, StopIteration): # pragma: no cover
8884
inferred = []
8985

90-
# Filter out astroid's Uninferable sentinel and ignore if nothing remains.
91-
inferred = [n for n in inferred if n is not Uninferable]
92-
if not inferred:
93-
return
86+
return [n for n in inferred if n is not Uninferable]
9487

95-
def _safe_pytype(n: nodes.NodeNG) -> str | None:
96-
pytype = getattr(n, "pytype", None)
97-
if not callable(pytype):
98-
return None
99-
100-
try:
101-
return pytype()
102-
except (AttributeError, InferenceError, TypeError):
103-
LOGGER.debug(
104-
"Could not infer pytype for astroid node %s",
105-
n.__class__.__name__,
106-
exc_info=True,
107-
)
108-
return None
109-
110-
def _safe_qname(n: nodes.NodeNG) -> str | None:
111-
qname = getattr(n, "qname", None)
112-
if not callable(qname):
113-
return None
114-
115-
try:
116-
return qname()
117-
except (AttributeError, InferenceError, TypeError):
118-
LOGGER.debug(
119-
"Could not infer qname for astroid node %s",
120-
n.__class__.__name__,
121-
exc_info=True,
122-
)
123-
return None
124-
125-
def _is_dict(n: nodes.NodeNG) -> bool:
126-
if isinstance(n, nodes.Dict):
127-
return True
128-
129-
return (
130-
_safe_pytype(n) == "builtins.dict" or _safe_qname(n) == "builtins.dict"
131-
)
88+
@staticmethod
89+
def _safe_pytype(n: nodes.NodeNG) -> str | None:
90+
pytype = getattr(n, "pytype", None)
91+
if not callable(pytype):
92+
return None
13293

133-
def _type_str(n: nodes.NodeNG) -> str:
134-
return _safe_pytype(n) or _safe_qname(n) or n.__class__.__name__
94+
try:
95+
return pytype()
96+
except (AttributeError, InferenceError, TypeError):
97+
LOGGER.debug(
98+
"Could not infer pytype for astroid node %s",
99+
n.__class__.__name__,
100+
exc_info=True,
101+
)
102+
return None
135103

136-
any_dict = any(_is_dict(n) for n in inferred)
104+
@staticmethod
105+
def _safe_qname(n: nodes.NodeNG) -> str | None:
106+
qname = getattr(n, "qname", None)
107+
if not callable(qname):
108+
return None
137109

138-
# Helper: is this a call to load_records_dict?
139-
def _is_call_to_load_records_dict(n: nodes.NodeNG) -> bool:
140-
return (
141-
isinstance(n, nodes.Call)
142-
and isinstance(n.func, nodes.Attribute)
143-
and n.func.attrname == "load_records_dict"
110+
try:
111+
return qname()
112+
except (AttributeError, InferenceError, TypeError):
113+
LOGGER.debug(
114+
"Could not infer qname for astroid node %s",
115+
n.__class__.__name__,
116+
exc_info=True,
117+
)
118+
return None
119+
120+
def _is_dict(self, n: nodes.NodeNG) -> bool:
121+
if isinstance(n, nodes.Dict):
122+
return True
123+
124+
return (
125+
self._safe_pytype(n) == "builtins.dict"
126+
or self._safe_qname(n) == "builtins.dict"
127+
)
128+
129+
def _type_str(self, n: nodes.NodeNG) -> str:
130+
return self._safe_pytype(n) or self._safe_qname(n) or n.__class__.__name__
131+
132+
@staticmethod
133+
def _is_call_to_load_records_dict(n: nodes.NodeNG) -> bool:
134+
return (
135+
isinstance(n, nodes.Call)
136+
and isinstance(n.func, nodes.Attribute)
137+
and n.func.attrname == "load_records_dict"
138+
)
139+
140+
def _check_record_assignment(
141+
self, node: nodes.Assign, inferred: list[nodes.NodeNG]
142+
) -> None:
143+
any_dict = any(self._is_dict(n) for n in inferred)
144+
if any_dict:
145+
dict_node = next((n for n in inferred if self._is_dict(n)), None)
146+
self.add_message(
147+
"colrev-record-cannot-be-dict",
148+
node=node,
149+
args=(self._type_str(dict_node) if dict_node else "builtins.dict",),
144150
)
145151

146-
# BEFORE running inference, treat direct calls to load_records_dict as dict.
147-
if target.name == "records" and _is_call_to_load_records_dict(assigned):
152+
def _check_records_assignment(
153+
self, node: nodes.Assign, assigned: nodes.NodeNG, inferred: list[nodes.NodeNG]
154+
) -> None:
155+
if self._is_call_to_load_records_dict(assigned):
148156
return
149157

150-
if target.name == "record" and any_dict:
151-
dict_node = next((n for n in inferred if _is_dict(n)), None)
158+
any_dict = any(self._is_dict(n) for n in inferred)
159+
if any_dict:
160+
return
161+
162+
rep = next(iter(inferred), None)
163+
if rep is None:
164+
return
165+
rep_str = self._type_str(rep)
166+
if rep_str not in ("unknown", "Uninferable", "astroid.util.Uninferable"):
152167
self.add_message(
153-
"colrev-record-cannot-be-dict",
168+
"colrev-records-must-be-dict",
154169
node=node,
155-
args=(_type_str(dict_node) if dict_node else "builtins.dict",),
170+
args=(rep_str,),
156171
)
157172

158-
if target.name == "records" and not any_dict:
159-
rep = next(iter(inferred), None)
160-
if rep is None:
161-
return
162-
rep_str = _type_str(rep)
163-
# Ignore unknown or Uninferable-like representations
164-
if rep_str not in ("unknown", "Uninferable", "astroid.util.Uninferable"):
165-
self.add_message(
166-
"colrev-records-must-be-dict",
167-
node=node,
168-
args=(rep_str,),
169-
)
173+
@only_required_for_messages(
174+
"colrev-records-variable-naming-convention",
175+
"colrev-record-cannot-be-dict",
176+
"colrev-records-must-be-dict",
177+
)
178+
def visit_assign(self, node: nodes.Assign) -> None:
179+
"""Check variable naming and simple type rules for 'record' and 'records'."""
180+
if not self._is_supported_assignment(node):
181+
return
182+
183+
target_name = self._get_assignment_target_name(node)
184+
if target_name is None:
185+
return
186+
187+
assigned = node.value
188+
self._check_load_records_dict_naming(node, target_name, assigned)
189+
190+
inferred = self._infer_assigned_nodes(assigned)
191+
if not inferred:
192+
return
193+
194+
if target_name == "record":
195+
self._check_record_assignment(node, inferred)
196+
elif target_name == "records":
197+
self._check_records_assignment(node, assigned, inferred)
170198

171199

172200
def register(linter: PyLinter) -> None: # pragma: no cover

0 commit comments

Comments
 (0)