Skip to content

Commit d9be21a

Browse files
committed
Switch to PEP 604 union annotations
1 parent 29a715b commit d9be21a

25 files changed

Lines changed: 126 additions & 169 deletions

pyproject.toml

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,6 @@ select = [
9191
"I", # isort
9292
"UP", # pyupgrade
9393
"G", # logging
94-
"FA", # future annotations
9594
"PIE", # misc
9695
"RUF", # misc
9796
]
@@ -101,12 +100,5 @@ ignore = [
101100
"G004", # f-strings in logging statements
102101
]
103102

104-
[tool.ruff.lint.pyupgrade]
105-
# Preserve Union types, despite alternate 'X | Y' syntax being available via __future__ annotations module.
106-
# This is necessary because FastAPI and pydantic parse type annotations at runtime, and since the new syntax is
107-
# a Python 3.10 feature, they don't expect it in Python 3.9.
108-
# This can be removed if/when we stop supporting Python 3.9.
109-
keep-runtime-typing = true
110-
111103
[tool.uv]
112104
cache-keys = [{ git = { commit = true, tags = true } }]

src/cnlpt/_cli/train.py

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from enum import Enum
2-
from typing import Annotated, Any, Final, Union
2+
from typing import Annotated, Any, Final
33

44
import typer
55
from click.core import ParameterSource
@@ -40,7 +40,7 @@ def callback(ctx: typer.Context, param: typer.CallbackParam, value: Any):
4040
def training_arg_option(
4141
field_name: str,
4242
*aliases,
43-
compatibility: Union[list[ModelType], None] = None,
43+
compatibility: list[ModelType] | None = None,
4444
**kwargs,
4545
):
4646
field = CnlpTrainingArguments.__dataclass_fields__[field_name]
@@ -59,7 +59,7 @@ def training_arg_option(
5959

6060
def model_arg_option(
6161
*args,
62-
compatibility: Union[list[ModelType], None] = None,
62+
compatibility: list[ModelType] | None = None,
6363
**kwargs,
6464
):
6565
if compatibility is not None:
@@ -69,7 +69,7 @@ def model_arg_option(
6969

7070
def data_arg_option(
7171
*args,
72-
compatibility: Union[list[ModelType], None] = None,
72+
compatibility: list[ModelType] | None = None,
7373
**kwargs,
7474
):
7575
if compatibility is not None:
@@ -251,15 +251,15 @@ def transformers_arg_option(field_name: str, *args, **kwargs):
251251
),
252252
]
253253
TaskNamesArg = Annotated[
254-
Union[list[str], None],
254+
list[str] | None,
255255
data_arg_option(
256256
"--task",
257257
"-t",
258258
help="The name of a task in the dataset to train on. Can be specified multiple times to target more than one task. Defaults to all tasks.",
259259
),
260260
]
261261
TokenizerArg = Annotated[
262-
Union[str, None],
262+
str | None,
263263
data_arg_option(
264264
"--tokenizer",
265265
help=f'Name or path to a model to use for tokenization. For projection and hierarchical models, this will default to the --encoder if left unspecified; otherwise defaults to "{DEFAULT_ENCODER}".',
@@ -288,15 +288,15 @@ def transformers_arg_option(field_name: str, *args, **kwargs):
288288
),
289289
]
290290
MaxTrainArg = Annotated[
291-
Union[int, None],
291+
int | None,
292292
data_arg_option("--max_train", help="Limit the number of training samples to use."),
293293
]
294294
MaxEvalArg = Annotated[
295-
Union[int, None],
295+
int | None,
296296
data_arg_option("--max_eval", help="Limit the number of eval samples to use."),
297297
]
298298
MaxTestArg = Annotated[
299-
Union[int, None],
299+
int | None,
300300
data_arg_option("--max_test", help="Limit the number of test samples to use."),
301301
]
302302
AllowDisjointLabelsArg = Annotated[
@@ -314,17 +314,17 @@ def transformers_arg_option(field_name: str, *args, **kwargs):
314314
),
315315
]
316316
HierChunkLenArg = Annotated[
317-
Union[int, None],
317+
int | None,
318318
data_arg_option("--hier_chunk_len", help="Chunk length for hierarchical models."),
319319
]
320320
HierNumChunksArg = Annotated[
321-
Union[int, None],
321+
int | None,
322322
data_arg_option(
323323
"--hier_num_chunks", help="Number of chunks for hierarchical models."
324324
),
325325
]
326326
HierPrependEmptyChunkArg = Annotated[
327-
Union[int, None],
327+
int | None,
328328
data_arg_option(
329329
"--hier_prepend_empty_chunk",
330330
help="Whether to prepend an empty chunk for hierarchical models.",
@@ -349,23 +349,19 @@ def transformers_arg_option(field_name: str, *args, **kwargs):
349349
"logging_first_step", "--logging_first_step/--no_logging_first_step"
350350
),
351351
]
352-
CacheDirArg = Annotated[Union[str, None], training_arg_option("cache_dir")]
352+
CacheDirArg = Annotated[str | None, training_arg_option("cache_dir")]
353353
MetricForBestModelArg = Annotated[str, training_arg_option("metric_for_best_model")]
354354

355355

356356
##### COMMON HF TRANSFORMERS ARGS #####
357-
NumTrainEpochsArg = Annotated[
358-
Union[float, None], transformers_arg_option("num_train_epochs")
359-
]
357+
NumTrainEpochsArg = Annotated[float | None, transformers_arg_option("num_train_epochs")]
360358
PerDeviceTrainBatchSizeArg = Annotated[
361-
Union[int, None], transformers_arg_option("per_device_train_batch_size")
359+
int | None, transformers_arg_option("per_device_train_batch_size")
362360
]
363361
GradientAccumulationStepsArg = Annotated[
364-
Union[int, None], transformers_arg_option("gradient_accumulation_steps")
365-
]
366-
LearningRateArg = Annotated[
367-
Union[float, None], transformers_arg_option("learning_rate")
362+
int | None, transformers_arg_option("gradient_accumulation_steps")
368363
]
364+
LearningRateArg = Annotated[float | None, transformers_arg_option("learning_rate")]
369365
DoTrainArg = Annotated[bool, transformers_arg_option("do_train", "--do_train")]
370366
DoEvalArg = Annotated[bool, transformers_arg_option("do_eval", "--do_eval")]
371367
DoPredictArg = Annotated[bool, transformers_arg_option("do_predict", "--do_predict")]
@@ -613,7 +609,7 @@ def train(
613609
if bias_fit:
614610
model_init_kwargs["bias_fit"] = True
615611

616-
model: Union[CnnModel, LstmModel, HierarchicalModel, ProjectionModel] = (
612+
model: CnnModel | LstmModel | HierarchicalModel | ProjectionModel = (
617613
AutoModel.from_config(config, **model_init_kwargs)
618614
)
619615
train_system = CnlpTrainSystem(model, dataset, training_args)

src/cnlpt/data/cnlp_dataset.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from collections import Counter
33
from dataclasses import dataclass
44
from enum import Enum
5-
from typing import Literal, Union
5+
from typing import Literal
66

77
import torch
88
from datasets import Dataset
@@ -22,7 +22,7 @@ class HierarchicalDataConfig:
2222

2323
def load_tokenizer(
2424
model_name_or_path: str,
25-
hf_cache_dir: Union[str, None] = None,
25+
hf_cache_dir: str | None = None,
2626
truncation_side: Literal["left", "right"] = "right",
2727
character_level: bool = False,
2828
) -> PreTrainedTokenizer:
@@ -50,19 +50,19 @@ class CnlpDataset:
5050

5151
def __init__(
5252
self,
53-
data_dir: Union[str, os.PathLike],
54-
tokenizer: Union[str, PreTrainedTokenizer] = "roberta-base",
55-
task_names: Union[list[str], None] = None,
56-
hier_config: Union[HierarchicalDataConfig, None] = None,
53+
data_dir: str | os.PathLike,
54+
tokenizer: str | PreTrainedTokenizer = "roberta-base",
55+
task_names: list[str] | None = None,
56+
hier_config: HierarchicalDataConfig | None = None,
5757
truncation_side: TruncationSide = TruncationSide.RIGHT,
5858
max_seq_length: int = 128,
5959
use_data_cache: bool = True,
60-
max_train: Union[int, None] = None,
61-
max_eval: Union[int, None] = None,
62-
max_test: Union[int, None] = None,
60+
max_train: int | None = None,
61+
max_eval: int | None = None,
62+
max_test: int | None = None,
6363
allow_disjoint_labels: bool = False,
6464
character_level: bool = False,
65-
hf_cache_dir: Union[str, None] = None,
65+
hf_cache_dir: str | None = None,
6666
):
6767
"""Create a new `CnlpDataset`.
6868

src/cnlpt/data/data_reader.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
import os
33
from collections.abc import Iterable
4-
from typing import Any, Final, Literal, Union, cast
4+
from typing import Any, Final, Literal, cast
55

66
from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset
77

@@ -21,7 +21,7 @@
2121
NONE_VALUE: Final = "__None__"
2222

2323

24-
def _infer_split(filepath: Union[str, os.PathLike]) -> DatasetSplit:
24+
def _infer_split(filepath: str | os.PathLike) -> DatasetSplit:
2525
_dir, filename = os.path.split(filepath)
2626
root, _ext = os.path.splitext(filename)
2727

@@ -106,7 +106,7 @@ def _get_task_by_name(self, task_name: str):
106106
return task
107107
raise ValueError(f'task with name "{task_name}" not found')
108108

109-
def get_tasks(self, task_names: Union[Iterable[str], None] = None):
109+
def get_tasks(self, task_names: Iterable[str] | None = None):
110110
"""Get all or some subset of the tasks in the data.
111111
112112
The `TaskInfo` objects returned by this method will have their `index` property
@@ -198,8 +198,8 @@ def _extend(self, new_dataset: DatasetDict, tasks: list[TaskInfo]):
198198

199199
def load_json(
200200
self,
201-
json_filepath: Union[str, os.PathLike],
202-
split: Union[DatasetSplit, None] = None,
201+
json_filepath: str | os.PathLike,
202+
split: DatasetSplit | None = None,
203203
):
204204
"""Update this reader with new data from a CNLP-formatted json file.
205205
@@ -274,8 +274,8 @@ def load_json(
274274

275275
def load_csv(
276276
self,
277-
csv_filepath: Union[str, os.PathLike],
278-
split: Union[DatasetSplit, None] = None,
277+
csv_filepath: str | os.PathLike,
278+
split: DatasetSplit | None = None,
279279
sep: str = ",",
280280
):
281281
"""Update this reader with new data from a CNLP-formatted csv (or tsv) file.
@@ -299,7 +299,7 @@ def load_csv(
299299
tasks = _infer_tasks(dataset[split])
300300
self._extend(dataset, tasks)
301301

302-
def load_dir(self, data_dir: Union[str, os.PathLike]):
302+
def load_dir(self, data_dir: str | os.PathLike):
303303
"""Update this reader with new data from a directory containing CNLP-formatted data.
304304
305305
This will search (non-recursively) for files named "train", "test", "validation", "valid", or "dev",

src/cnlpt/data/predictions.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
from collections.abc import Iterable
44
from dataclasses import asdict, dataclass
5-
from typing import Any, Union
5+
from typing import Any
66

77
import numpy as np
88
import numpy.typing as npt
@@ -19,7 +19,7 @@
1919
class TaskPredictions:
2020
task: TaskInfo
2121
logits: npt.NDArray
22-
labels: Union[npt.NDArray, None]
22+
labels: npt.NDArray | None
2323

2424
@property
2525
def probs(self) -> npt.NDArray:
@@ -34,7 +34,7 @@ def predicted_str_labels(self) -> npt.NDArray:
3434
return np.array(self.task.labels)[self.predicted_int_labels]
3535

3636
@property
37-
def target_str_labels(self) -> Union[npt.NDArray, None]:
37+
def target_str_labels(self) -> npt.NDArray | None:
3838
if self.labels is None:
3939
return None
4040
masked = self.labels.copy()
@@ -68,7 +68,7 @@ def __init__(
6868

6969
self.task_predictions: dict[str, TaskPredictions] = {}
7070

71-
task_labels: dict[str, Union[npt.NDArray, None]]
71+
task_labels: dict[str, npt.NDArray | None]
7272

7373
if self.raw.label_ids is None:
7474
task_labels = {t.name: None for t in tasks}
@@ -137,7 +137,7 @@ def arr_to_list(obj):
137137

138138
def save_json(
139139
self,
140-
json_filepath: Union[str, os.PathLike],
140+
json_filepath: str | os.PathLike,
141141
allow_overwrite: bool = False,
142142
):
143143
write_mode = "w" if allow_overwrite else "x"
@@ -169,7 +169,7 @@ def list_to_arr(obj, dtype):
169169
)
170170

171171
@classmethod
172-
def load_json(cls, filepath: Union[str, os.PathLike]):
172+
def load_json(cls, filepath: str | os.PathLike):
173173
with open(filepath) as f:
174174
return cls.from_dict(json.load(f))
175175

src/cnlpt/data/preprocess.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22
from collections.abc import Iterable
3-
from typing import TYPE_CHECKING, Any, Final, Union
3+
from typing import TYPE_CHECKING, Any, Final
44

55
import numpy as np
66
from transformers.tokenization_utils import PreTrainedTokenizer
@@ -17,13 +17,13 @@
1717

1818

1919
def preprocess_raw_data(
20-
batch: dict[str, Union[list[str], list[int], list[float]]],
20+
batch: dict[str, list[str] | list[int] | list[float]],
2121
tokenizer: PreTrainedTokenizer,
22-
tasks: Union[Iterable[TaskInfo], None],
23-
max_length: Union[int, None] = None,
22+
tasks: Iterable[TaskInfo] | None,
23+
max_length: int | None = None,
2424
inference_only: bool = False,
2525
character_level: bool = False,
26-
hier_config: Union["HierarchicalDataConfig", None] = None,
26+
hier_config: "HierarchicalDataConfig | None" = None,
2727
) -> BatchEncoding:
2828
"""Preprocess raw CNLP data for training/evaluation.
2929
@@ -248,7 +248,7 @@ def _get_word_ids(
248248
tokenizer: PreTrainedTokenizer,
249249
tokenized_input: BatchEncoding,
250250
character_level: bool,
251-
) -> list[list[Union[int, None]]]:
251+
) -> list[list[int | None]]:
252252
if tokenizer.is_fast:
253253
return [
254254
tokenized_input.word_ids(i) for i in range(len(tokenized_input.input_ids))
@@ -267,9 +267,9 @@ def _get_word_ids(
267267
]
268268
)
269269

270-
def get_word_ids(indices: Iterable[int]) -> list[Union[int, None]]:
270+
def get_word_ids(indices: Iterable[int]) -> list[int | None]:
271271
current = 0
272-
raw: list[Union[int, None]] = []
272+
raw: list[int | None] = []
273273
for index in indices:
274274
if index in special_token_ids:
275275
raw.append(None)
@@ -290,9 +290,9 @@ def get_word_ids(indices: Iterable[int]) -> list[Union[int, None]]:
290290

291291

292292
def _tokenize_batch(
293-
batch: dict[str, Union[list[str], list[int], list[float]]],
293+
batch: dict[str, list[str] | list[int] | list[float]],
294294
tokenizer: PreTrainedTokenizer,
295-
max_length: Union[int, None],
295+
max_length: int | None,
296296
hierarchical: bool,
297297
character_level: bool,
298298
) -> BatchEncoding:
@@ -339,9 +339,7 @@ def _tokenize_batch(
339339
return tokenized_batch
340340

341341

342-
def _preprocess_raw_labels(
343-
raw: Union[list[str], list[int], list[float]], task: TaskInfo
344-
):
342+
def _preprocess_raw_labels(raw: list[str] | list[int] | list[float], task: TaskInfo):
345343
mask_missing: Final = {MISSING_DATA_STR: MASK_VALUE}
346344
if task.type == CLASSIFICATION:
347345
# labels is just a list of one label for each instance
@@ -356,7 +354,7 @@ def _preprocess_raw_labels(
356354
for tags in raw
357355
]
358356
elif task.type == RELATIONS:
359-
preprocessed: list[Union[list[str], list[tuple[int, int, int]]]] = []
357+
preprocessed: list[list[str] | list[tuple[int, int, int]]] = []
360358
for relations in raw:
361359
if relations in (None, "None"):
362360
preprocessed.append(["None"])
@@ -483,7 +481,7 @@ def _build_labels_for_task(
483481
labels: list[tuple[Any, ...]],
484482
max_length: int,
485483
pad_classification: bool,
486-
) -> Union[np.ndarray, list[np.ndarray]]:
484+
) -> np.ndarray | list[np.ndarray]:
487485
if task.type == TAGGING:
488486
return _get_tagging_labels(task, tokenized_input, labels)
489487
elif task.type == RELATIONS:

0 commit comments

Comments
 (0)