11from enum import Enum
2- from typing import Annotated , Any , Final , Union
2+ from typing import Annotated , Any , Final
33
44import typer
55from click .core import ParameterSource
@@ -40,7 +40,7 @@ def callback(ctx: typer.Context, param: typer.CallbackParam, value: Any):
4040def training_arg_option (
4141 field_name : str ,
4242 * aliases ,
43- compatibility : Union [ list [ModelType ], None ] = None ,
43+ compatibility : list [ModelType ] | None = None ,
4444 ** kwargs ,
4545):
4646 field = CnlpTrainingArguments .__dataclass_fields__ [field_name ]
@@ -59,7 +59,7 @@ def training_arg_option(
5959
6060def model_arg_option (
6161 * args ,
62- compatibility : Union [ list [ModelType ], None ] = None ,
62+ compatibility : list [ModelType ] | None = None ,
6363 ** kwargs ,
6464):
6565 if compatibility is not None :
@@ -69,7 +69,7 @@ def model_arg_option(
6969
7070def data_arg_option (
7171 * args ,
72- compatibility : Union [ list [ModelType ], None ] = None ,
72+ compatibility : list [ModelType ] | None = None ,
7373 ** kwargs ,
7474):
7575 if compatibility is not None :
@@ -251,15 +251,15 @@ def transformers_arg_option(field_name: str, *args, **kwargs):
251251 ),
252252]
253253TaskNamesArg = Annotated [
254- Union [ list [str ], None ] ,
254+ list [str ] | None ,
255255 data_arg_option (
256256 "--task" ,
257257 "-t" ,
258258 help = "The name of a task in the dataset to train on. Can be specified multiple times to target more than one task. Defaults to all tasks." ,
259259 ),
260260]
261261TokenizerArg = Annotated [
262- Union [ str , None ] ,
262+ str | None ,
263263 data_arg_option (
264264 "--tokenizer" ,
265265 help = f'Name or path to a model to use for tokenization. For projection and hierarchical models, this will default to the --encoder if left unspecified; otherwise defaults to "{ DEFAULT_ENCODER } ".' ,
@@ -288,15 +288,15 @@ def transformers_arg_option(field_name: str, *args, **kwargs):
288288 ),
289289]
290290MaxTrainArg = Annotated [
291- Union [ int , None ] ,
291+ int | None ,
292292 data_arg_option ("--max_train" , help = "Limit the number of training samples to use." ),
293293]
294294MaxEvalArg = Annotated [
295- Union [ int , None ] ,
295+ int | None ,
296296 data_arg_option ("--max_eval" , help = "Limit the number of eval samples to use." ),
297297]
298298MaxTestArg = Annotated [
299- Union [ int , None ] ,
299+ int | None ,
300300 data_arg_option ("--max_test" , help = "Limit the number of test samples to use." ),
301301]
302302AllowDisjointLabelsArg = Annotated [
@@ -314,17 +314,17 @@ def transformers_arg_option(field_name: str, *args, **kwargs):
314314 ),
315315]
316316HierChunkLenArg = Annotated [
317- Union [ int , None ] ,
317+ int | None ,
318318 data_arg_option ("--hier_chunk_len" , help = "Chunk length for hierarchical models." ),
319319]
320320HierNumChunksArg = Annotated [
321- Union [ int , None ] ,
321+ int | None ,
322322 data_arg_option (
323323 "--hier_num_chunks" , help = "Number of chunks for hierarchical models."
324324 ),
325325]
326326HierPrependEmptyChunkArg = Annotated [
327- Union [ int , None ] ,
327+ int | None ,
328328 data_arg_option (
329329 "--hier_prepend_empty_chunk" ,
330330 help = "Whether to prepend an empty chunk for hierarchical models." ,
@@ -349,23 +349,19 @@ def transformers_arg_option(field_name: str, *args, **kwargs):
349349 "logging_first_step" , "--logging_first_step/--no_logging_first_step"
350350 ),
351351]
352- CacheDirArg = Annotated [Union [ str , None ] , training_arg_option ("cache_dir" )]
352+ CacheDirArg = Annotated [str | None , training_arg_option ("cache_dir" )]
353353MetricForBestModelArg = Annotated [str , training_arg_option ("metric_for_best_model" )]
354354
355355
356356##### COMMON HF TRANSFORMERS ARGS #####
357- NumTrainEpochsArg = Annotated [
358- Union [float , None ], transformers_arg_option ("num_train_epochs" )
359- ]
357+ NumTrainEpochsArg = Annotated [float | None , transformers_arg_option ("num_train_epochs" )]
360358PerDeviceTrainBatchSizeArg = Annotated [
361- Union [ int , None ] , transformers_arg_option ("per_device_train_batch_size" )
359+ int | None , transformers_arg_option ("per_device_train_batch_size" )
362360]
363361GradientAccumulationStepsArg = Annotated [
364- Union [int , None ], transformers_arg_option ("gradient_accumulation_steps" )
365- ]
366- LearningRateArg = Annotated [
367- Union [float , None ], transformers_arg_option ("learning_rate" )
362+ int | None , transformers_arg_option ("gradient_accumulation_steps" )
368363]
364+ LearningRateArg = Annotated [float | None , transformers_arg_option ("learning_rate" )]
369365DoTrainArg = Annotated [bool , transformers_arg_option ("do_train" , "--do_train" )]
370366DoEvalArg = Annotated [bool , transformers_arg_option ("do_eval" , "--do_eval" )]
371367DoPredictArg = Annotated [bool , transformers_arg_option ("do_predict" , "--do_predict" )]
@@ -613,7 +609,7 @@ def train(
613609 if bias_fit :
614610 model_init_kwargs ["bias_fit" ] = True
615611
616- model : Union [ CnnModel , LstmModel , HierarchicalModel , ProjectionModel ] = (
612+ model : CnnModel | LstmModel | HierarchicalModel | ProjectionModel = (
617613 AutoModel .from_config (config , ** model_init_kwargs )
618614 )
619615 train_system = CnlpTrainSystem (model , dataset , training_args )
0 commit comments