Skip to content

Commit b966247

Browse files
csv_overrides: Add stronger type hints (#1851)
## Checklist

- [ ] I have added [tests](https://www.cursorless.org/docs/contributing/test-case-recorder/)
- [ ] I have updated the [docs](https://github.com/cursorless-dev/cursorless/tree/main/docs) and [cheatsheet](https://github.com/cursorless-dev/cursorless/tree/main/cursorless-talon/src/cheatsheet)
- [ ] I have not broken the cheatsheet

---------

Co-authored-by: Andreas Arvidsson <andreas.arvidsson87@gmail.com>
1 parent 52e868c commit b966247

2 files changed

Lines changed: 85 additions & 71 deletions

File tree

src/csv_overrides.py

Lines changed: 73 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
import csv
22
import typing
33
from collections import defaultdict
4-
from collections.abc import Container
54
from dataclasses import dataclass
65
from datetime import datetime
76
from pathlib import Path
8-
from typing import Callable, Iterable, Optional, TypedDict
7+
from typing import Callable, Container, Iterable, Optional, Sequence, TypedDict
98

109
from talon import Context, Module, actions, app, fs, settings
1110

@@ -49,6 +48,12 @@ class SpokenFormEntry:
4948
spoken_forms: list[str]
5049

5150

51+
class ResultsListEntry(TypedDict):
52+
spoken: str
53+
id: str
54+
list: str
55+
56+
5257
def csv_get_ctx():
5358
return ctx
5459

@@ -60,17 +65,17 @@ def csv_get_normalized_ctx():
6065
def init_csv_and_watch_changes(
6166
filename: str,
6267
default_values: ListToSpokenForms,
63-
handle_new_values: Optional[Callable[[list[SpokenFormEntry]], None]] = None,
68+
handle_new_values: Optional[Callable[[Sequence[SpokenFormEntry]], None]] = None,
6469
*,
65-
extra_ignored_values: Optional[list[str]] = None,
66-
extra_allowed_values: Optional[list[str]] = None,
70+
extra_ignored_values: Optional[Sequence[str]] = None,
71+
extra_allowed_values: Optional[Sequence[str]] = None,
6772
allow_unknown_values: bool = False,
6873
deprecated: bool = False,
6974
default_list_name: Optional[str] = None,
70-
headers: list[str] = [SPOKEN_FORM_HEADER, CURSORLESS_IDENTIFIER_HEADER],
75+
headers: Optional[Sequence[str]] = None,
7176
no_update_file: bool = False,
72-
pluralize_lists: Optional[list[str]] = None,
73-
):
77+
pluralize_lists: Optional[Sequence[str]] = None,
78+
) -> Callable[[], None]:
7479
"""
7580
Initialize a cursorless settings csv, creating it if necessary, and watch
7681
for changes to the csv. Talon lists will be generated based on the keys of
@@ -91,21 +96,21 @@ def init_csv_and_watch_changes(
9196
`cursorless-settings` dir
9297
default_values (ListToSpokenForms): The default values for the lists to
9398
be customized in the given csv
94-
handle_new_values (Optional[Callable[[list[SpokenFormEntry]], None]]): A
99+
handle_new_values (Optional[Callable[[Sequence[SpokenFormEntry]], None]]): A
95100
callback to be called when the lists are updated
96-
extra_ignored_values (Optional[list[str]]): Don't throw an exception if
101+
extra_ignored_values (Optional[Sequence[str]]): Don't throw an exception if
97102
any of these appear as values; just ignore them and don't add them
98103
to any list
99104
allow_unknown_values (bool): If unknown values appear, just put them in
100105
the list
101106
default_list_name (Optional[str]): If unknown values are
102107
allowed, put any unknown values in this list
103-
headers (list[str]): The headers to use for the csv
108+
headers (Optional[Sequence[str]]): The headers to use for the csv
104109
no_update_file (bool): Set this to `True` to indicate that we should not
105110
update the csv. This is used generally in case there was an issue
106111
coming up with the default set of values so we don't want to persist
107112
those to disk
108-
pluralize_lists (list[str]): Create plural version of given lists
113+
pluralize_lists (Optional[Sequence[str]]): Create plural version of given lists
109114
"""
110115
# Don't allow both `extra_allowed_values` and `allow_unknown_values`
111116
assert not (extra_allowed_values and allow_unknown_values)
@@ -116,6 +121,8 @@ def init_csv_and_watch_changes(
116121
(extra_allowed_values or allow_unknown_values) and not default_list_name
117122
)
118123

124+
if headers is None:
125+
headers = (SPOKEN_FORM_HEADER, CURSORLESS_IDENTIFIER_HEADER)
119126
if extra_ignored_values is None:
120127
extra_ignored_values = []
121128
if extra_allowed_values is None:
@@ -137,7 +144,7 @@ def init_csv_and_watch_changes(
137144
check_for_duplicates(filename, default_values)
138145
create_default_vocabulary_dicts(default_values, pluralize_lists)
139146

140-
def on_watch(path, flags):
147+
def on_watch(path: str, _flags) -> None:
141148
if file_path.match(path):
142149
current_values, has_errors = read_file(
143150
path=file_path,
@@ -194,16 +201,16 @@ def on_watch(path, flags):
194201
handle_new_values=handle_new_values,
195202
)
196203

197-
def unsubscribe():
204+
def unsubscribe() -> None:
198205
fs.unwatch(file_path.parent, on_watch)
199206

200207
return unsubscribe
201208

202209

203-
def check_for_duplicates(filename, default_values):
210+
def check_for_duplicates(filename: str, default_values: ListToSpokenForms):
204211
results_map = {}
205-
for list_name, dict in default_values.items():
206-
for key, value in dict.items():
212+
for list_name, values in default_values.items():
213+
for key, value in values.items():
207214
if value in results_map:
208215
existing_list_name = results_map[value]
209216
warning = f"WARNING ({filename}): Value `{value}` duplicated between lists '{existing_list_name}' and '{list_name}'"
@@ -213,16 +220,17 @@ def check_for_duplicates(filename, default_values):
213220
results_map[value] = list_name
214221

215222

216-
def is_removed(value: str):
223+
def is_removed(value: str) -> bool:
217224
return value.startswith("-")
218225

219226

220227
def create_default_vocabulary_dicts(
221-
default_values: dict[str, dict], pluralize_lists: list[str]
228+
default_values: ListToSpokenForms,
229+
pluralize_lists: Sequence[str],
222230
):
223231
default_values_updated = {}
224232
for key, value in default_values.items():
225-
updated_dict = {}
233+
updated_dict: dict[str, str] = {}
226234
for key2, value2 in value.items():
227235
# Enable deactivated (prefixed with a `-`) items
228236
active_key = key2[1:] if key2.startswith("-") else key2
@@ -235,17 +243,17 @@ def create_default_vocabulary_dicts(
235243
def update_dicts(
236244
default_values: ListToSpokenForms,
237245
current_values: dict[str, str],
238-
extra_ignored_values: list[str],
239-
extra_allowed_values: list[str],
246+
extra_ignored_values: Sequence[str],
247+
extra_allowed_values: Sequence[str],
240248
allow_unknown_values: bool,
241249
default_list_name: str | None,
242-
pluralize_lists: list[str],
243-
handle_new_values: Callable[[list[SpokenFormEntry]], None] | None,
244-
):
250+
pluralize_lists: Sequence[str],
251+
handle_new_values: Callable[[Sequence[SpokenFormEntry]], None] | None,
252+
) -> None:
245253
# Create map with all default values
246254
results_map: dict[str, ResultsListEntry] = {}
247-
for list_name, obj in default_values.items():
248-
for spoken, id in obj.items():
255+
for list_name, values in default_values.items():
256+
for spoken, id in values.items():
249257
results_map[id] = {"spoken": spoken, "id": id, "list": list_name}
250258

251259
# Update result with current values
@@ -281,13 +289,9 @@ def update_dicts(
281289
handle_new_values(spoken_form_entries)
282290

283291

284-
class ResultsListEntry(TypedDict):
285-
spoken: str
286-
id: str
287-
list: str
288-
289-
290-
def generate_spoken_forms(results_list: Iterable[ResultsListEntry]):
292+
def generate_spoken_forms(
293+
results_list: Iterable[ResultsListEntry],
294+
) -> Iterable[SpokenFormEntry]:
291295
for obj in results_list:
292296
id = obj["id"]
293297
spoken = obj["spoken"]
@@ -315,25 +319,25 @@ def generate_spoken_forms(results_list: Iterable[ResultsListEntry]):
315319
def assign_lists_to_context(
316320
ctx: Context,
317321
lists: ListToSpokenForms,
318-
pluralize_lists: list[str],
319-
):
320-
for list_name, dict in lists.items():
322+
pluralize_lists: Sequence[str],
323+
) -> None:
324+
for list_name, values in lists.items():
321325
list_singular_name = get_cursorless_list_name(list_name)
322-
ctx.lists[list_singular_name] = dict
326+
ctx.lists[list_singular_name] = values
323327
if list_name in pluralize_lists:
324328
list_plural_name = f"{list_singular_name}_plural"
325-
ctx.lists[list_plural_name] = {pluralize(k): v for k, v in dict.items()}
329+
ctx.lists[list_plural_name] = {pluralize(k): v for k, v in values.items()}
326330

327331

328332
def update_file(
329333
path: Path,
330-
headers: list[str],
334+
headers: Sequence[str],
331335
default_values: dict[str, str],
332-
extra_ignored_values: list[str],
333-
extra_allowed_values: list[str],
336+
extra_ignored_values: Sequence[str],
337+
extra_allowed_values: Sequence[str],
334338
allow_unknown_values: bool,
335339
no_update_file: bool,
336-
):
340+
) -> dict[str, str]:
337341
current_values, has_errors = read_file(
338342
path=path,
339343
headers=headers,
@@ -344,7 +348,7 @@ def update_file(
344348
)
345349
current_identifiers = current_values.values()
346350

347-
missing = {}
351+
missing: dict[str, str] = {}
348352
for key, value in default_values.items():
349353
if value not in current_identifiers:
350354
missing[key] = value
@@ -357,16 +361,17 @@ def update_file(
357361
)
358362
else:
359363
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
364+
missing_items = sorted(missing.items())
360365
lines = [
361366
f"# {timestamp} - New entries automatically added by cursorless",
362-
*[create_line(key, missing[key]) for key in sorted(missing)],
367+
*[create_line(key, value) for key, value in missing_items],
363368
]
364369
with open(path, "a") as f:
365370
f.write("\n\n" + "\n".join(lines))
366371

367372
print(f"New cursorless features added to {path.name}")
368-
for key in sorted(missing):
369-
print(f"{key}: {missing[key]}")
373+
for key, value in missing_items:
374+
print(f"{key}: {value}")
370375
print(
371376
"See release notes for more info: "
372377
"https://github.com/cursorless-dev/cursorless/blob/main/CHANGELOG.md"
@@ -376,18 +381,22 @@ def update_file(
376381
return current_values
377382

378383

379-
def create_line(*cells: str):
380-
return ", ".join(cells)
381-
382-
383-
def create_file(path: Path, headers: list[str], default_values: dict):
384-
lines = [create_line(key, default_values[key]) for key in sorted(default_values)]
384+
def create_file(
385+
path: Path,
386+
headers: Sequence[str],
387+
default_values: dict[str, str],
388+
) -> None:
389+
lines = [create_line(key, value) for key, value in sorted(default_values.items())]
385390
lines.insert(0, create_line(*headers))
386391
lines.append("")
387392
path.write_text("\n".join(lines))
388393

389394

390-
def csv_error(path: Path, index: int, message: str, value: str):
395+
def create_line(*cells: str) -> str:
396+
return ", ".join(cells)
397+
398+
399+
def csv_error(path: Path, index: int, message: str, value: str) -> None:
391400
"""Check that an expected condition is true
392401
393402
Note that we try to continue reading in this case so cursorless doesn't get bricked
@@ -402,19 +411,19 @@ def csv_error(path: Path, index: int, message: str, value: str):
402411

403412
def read_file(
404413
path: Path,
405-
headers: list[str],
414+
headers: Sequence[str],
406415
default_identifiers: Container[str],
407-
extra_ignored_values: list[str],
408-
extra_allowed_values: list[str],
416+
extra_ignored_values: Sequence[str],
417+
extra_allowed_values: Sequence[str],
409418
allow_unknown_values: bool,
410-
):
419+
) -> tuple[dict[str, str], bool]:
411420
with open(path) as csv_file:
412421
# Use `skipinitialspace` to allow spaces before quote. `, "a,b"`
413422
csv_reader = csv.reader(csv_file, skipinitialspace=True)
414423
rows = list(csv_reader)
415424

416-
result = {}
417-
used_identifiers = []
425+
result: dict[str, str] = {}
426+
used_identifiers: set[str] = set()
418427
has_errors = False
419428
seen_headers = False
420429

@@ -427,7 +436,7 @@ def read_file(
427436

428437
if not seen_headers:
429438
seen_headers = True
430-
if row != headers:
439+
if row != list(headers):
431440
has_errors = True
432441
csv_error(path, i, "Malformed header", create_line(*row))
433442
print(f"Expected '{create_line(*headers)}'")
@@ -461,15 +470,15 @@ def read_file(
461470
continue
462471

463472
result[key] = value
464-
used_identifiers.append(value)
473+
used_identifiers.add(value)
465474

466475
if has_errors:
467476
app.notify("Cursorless settings error; see log")
468477

469478
return result, has_errors
470479

471480

472-
def get_full_path(filename: str):
481+
def get_full_path(filename: str) -> Path:
473482
if not filename.endswith(".csv"):
474483
filename = f"{filename}.csv"
475484

@@ -484,7 +493,7 @@ def get_full_path(filename: str):
484493
return (settings_directory / filename).resolve()
485494

486495

487-
def get_super_values(values: ListToSpokenForms):
496+
def get_super_values(values: ListToSpokenForms) -> dict[str, str]:
488497
result: dict[str, str] = {}
489498
for value_dict in values.values():
490499
result.update(value_dict)

src/spoken_forms.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22
from pathlib import Path
3-
from typing import Callable, Concatenate
3+
from typing import Callable, Concatenate, Sequence
44

55
from talon import app, cron, fs, registry
66

@@ -26,9 +26,14 @@
2626

2727
def auto_construct_defaults[**P, R](
2828
spoken_forms: dict[str, ListToSpokenForms],
29-
handle_new_values: Callable[[str, list[SpokenFormEntry]], None],
29+
handle_new_values: Callable[[str, Sequence[SpokenFormEntry]], None],
3030
f: Callable[
31-
Concatenate[str, ListToSpokenForms, Callable[[list[SpokenFormEntry]], None], P],
31+
Concatenate[
32+
str,
33+
ListToSpokenForms,
34+
Callable[[Sequence[SpokenFormEntry]], None],
35+
P,
36+
],
3237
R,
3338
],
3439
):
@@ -94,7 +99,7 @@ def update():
9499
initialized = False
95100

96101
# Maps from csv name to list of SpokenFormEntry
97-
custom_spoken_forms: dict[str, list[SpokenFormEntry]] = {}
102+
custom_spoken_forms: dict[str, Sequence[SpokenFormEntry]] = {}
98103
spoken_forms_output = SpokenFormsOutput()
99104
spoken_forms_output.init()
100105
graphemes_talon_list = get_graphemes_talon_list()
@@ -116,7 +121,7 @@ def update_spoken_forms_output():
116121
]
117122
)
118123

119-
def handle_new_values(csv_name: str, values: list[SpokenFormEntry]):
124+
def handle_new_values(csv_name: str, values: Sequence[SpokenFormEntry]):
120125
custom_spoken_forms[csv_name] = values
121126
if initialized:
122127
# On first run, we just do one update at the end, so we suppress
@@ -163,13 +168,13 @@ def handle_new_values(csv_name: str, values: list[SpokenFormEntry]):
163168
),
164169
handle_csv(
165170
"experimental/actions_custom.csv",
166-
headers=[SPOKEN_FORM_HEADER, "VSCode command"],
171+
headers=(SPOKEN_FORM_HEADER, "VSCode command"),
167172
allow_unknown_values=True,
168173
default_list_name="custom_action",
169174
),
170175
handle_csv(
171176
"experimental/regex_scope_types.csv",
172-
headers=[SPOKEN_FORM_HEADER, "Regex"],
177+
headers=(SPOKEN_FORM_HEADER, "Regex"),
173178
allow_unknown_values=True,
174179
default_list_name="custom_regex_scope_type",
175180
pluralize_lists=["custom_regex_scope_type"],

0 commit comments

Comments
 (0)