From 263deeece9e1d75813b7c0292e4b5e77638605c8 Mon Sep 17 00:00:00 2001 From: cyy Date: Wed, 27 May 2026 16:31:51 -0700 Subject: [PATCH] Use Python 3.10+ typing in codegen scripts (#5787) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/5787 Reviewed By: henrylhtsang Differential Revision: D106570085 Pulled By: q10 --- .../genscript/generate_backward_split.py | 13 +- .../genscript/generate_forward_quantized.py | 5 +- .../genscript/generate_forward_split.py | 9 +- .../codegen/genscript/jinja_environment.py | 23 ++-- .../codegen/genscript/optimizer_args.py | 129 +++++++++--------- fbgemm_gpu/codegen/genscript/optimizers.py | 36 ++--- .../codegen/genscript/scripts_argsparse.py | 3 +- 7 files changed, 107 insertions(+), 111 deletions(-) diff --git a/fbgemm_gpu/codegen/genscript/generate_backward_split.py b/fbgemm_gpu/codegen/genscript/generate_backward_split.py index 50506decb1..257abe5a36 100644 --- a/fbgemm_gpu/codegen/genscript/generate_backward_split.py +++ b/fbgemm_gpu/codegen/genscript/generate_backward_split.py @@ -11,7 +11,6 @@ import itertools import sys from copy import deepcopy -from typing import List try: # pyre-fixme[21]: Could not find name `ArgType` in @@ -45,7 +44,7 @@ def render_backward_templates( template_filepath: str, optimizer: str, filename_format: str, - kwargs: Dict[str, Any], + kwargs: dict[str, Any], is_gwd: bool = False, ) -> None: if not kwargs.get("has_gpu_support"): @@ -352,14 +351,14 @@ def generate_rocm_backward_split(**kwargs: Any) -> None: @staticmethod def generate_backward_header( - aux_args: Dict[str, List[str]], aux_names: List[str], is_ssd: bool = False + aux_args: dict[str, list[str]], aux_names: list[str], is_ssd: bool = False ) -> None: """ Generate a header file that contains enum of argument order from the dict Parameters: - aux_args (Dict[str, List[str]]): a dict containing a list of arguments - aux_names (List[str]): names of the argument types (e.g. aux_tensor, aux_int, etc.) + aux_args (dict[str, list[str]]): a dict containing a list of arguments + aux_names (list[str]): names of the argument types (e.g. aux_tensor, aux_int, etc.) Return: None """ @@ -372,7 +371,7 @@ def generate_backward_header( @staticmethod def generate_python_sources( - all_optimizers: List[str], ssd_optimizers: List[str] + all_optimizers: list[str], ssd_optimizers: list[str] ) -> None: CodeTemplate.load("training/python/__init__.template").write( "__init__.py", all_optimizers=all_optimizers, ssd_optimizers=ssd_optimizers @@ -418,7 +417,7 @@ def generate() -> None: # This is a dict of auxilary arguments used in TBE PT2 interface where the aux # arguments of a type are packed into a list for that type. This dict maintains the # order of the arguments of each type. - aux_args: Dict[str, List[str]] = { + aux_args: dict[str, list[str]] = { "aux_tensor": [ "B_offsets", # 0 "vbe_output_offsets_feature_rank", # 1 diff --git a/fbgemm_gpu/codegen/genscript/generate_forward_quantized.py b/fbgemm_gpu/codegen/genscript/generate_forward_quantized.py index d37d25d1fb..16a7a4bda0 100644 --- a/fbgemm_gpu/codegen/genscript/generate_forward_quantized.py +++ b/fbgemm_gpu/codegen/genscript/generate_forward_quantized.py @@ -9,7 +9,6 @@ import sys from dataclasses import dataclass -from typing import Dict, List try: from .common import CodeTemplate @@ -31,7 +30,7 @@ class ElemType: cpp_type_name: str primitive_type: str bit_width: int - template_params: List[TemplateParams] + template_params: list[TemplateParams] @property def enum_name(self) -> str: @@ -122,7 +121,7 @@ def enum_name(self) -> str: ), ] -ELEM_TYPES_MAP: Dict[str, ElemType] = {etype.enum_name: etype for etype in ELEM_TYPES} +ELEM_TYPES_MAP: dict[str, ElemType] = {etype.enum_name: etype for etype in ELEM_TYPES} class ForwardQuantizedGenerator: diff --git a/fbgemm_gpu/codegen/genscript/generate_forward_split.py b/fbgemm_gpu/codegen/genscript/generate_forward_split.py index d99afa09a0..6f14422826 100644 --- a/fbgemm_gpu/codegen/genscript/generate_forward_split.py +++ b/fbgemm_gpu/codegen/genscript/generate_forward_split.py @@ -10,7 +10,6 @@ import itertools import sys -from typing import List try: from .common import CodeTemplate @@ -28,10 +27,10 @@ class ForwardSplitGenerator: def render_forward_templates( template_filepath: str, filename_format: str, - dense_options: List[bool], - nobag_options: List[bool], - vbe_options: List[bool], - ssd_options: List[bool], + dense_options: list[bool], + nobag_options: list[bool], + vbe_options: list[bool], + ssd_options: list[bool], is_gwd: bool = False, ) -> None: template = CodeTemplate.load(template_filepath) diff --git a/fbgemm_gpu/codegen/genscript/jinja_environment.py b/fbgemm_gpu/codegen/genscript/jinja_environment.py index 37e658f5a9..e3aaec473b 100644 --- a/fbgemm_gpu/codegen/genscript/jinja_environment.py +++ b/fbgemm_gpu/codegen/genscript/jinja_environment.py @@ -11,7 +11,6 @@ import argparse import os import re -from typing import Dict, List, Optional, Tuple import jinja2 @@ -81,7 +80,7 @@ ################################################################################ -def prepare_string_for_formatting(blob: str, format_keywords: List[str]) -> str: +def prepare_string_for_formatting(blob: str, format_keywords: list[str]) -> str: """ Replace curly brackets ('{' or '}') with escape characters ('{{' or '}}') to prepare the string to be formatted by `str.format()`. `str.format()` @@ -95,7 +94,7 @@ def prepare_string_for_formatting(blob: str, format_keywords: List[str]) -> str: def generate_optimized_grad_sum_loop_access( - blob: str, other_formats: Optional[Dict[str, str]] = None + blob: str, other_formats: dict[str, str] | None = None ) -> str: """ Generate an optimized code for grad_sum accessing @@ -148,13 +147,13 @@ def get_max_vecs_template_configs( fixed_max_vecs_per_thread: int, use_subwarp_shuffle: bool, use_vec_blocking: bool, -) -> List[Tuple[int, int, str]]: +) -> list[tuple[int, int, str]]: """ Generate the template configs for each kFixedMaxVecsPerThread, kThreadGroupSize, and kUseVecBlocking """ warp_size = items_per_warp // 4 - configs: List[Tuple[int, int, str]] = [] + configs: list[tuple[int, int, str]] = [] if use_vec_blocking: # kFixedMaxVecsPerThread = fixed_max_vecs_per_thread @@ -361,7 +360,7 @@ def compute_global_weight_decay(is_global_weight_decay_kernel: bool) -> str: # Format the macro call to generate pta::PackedTensorAccessors -def make_pta_acc_format(pta_str_list: List[str], func_name: str) -> List[str]: +def make_pta_acc_format(pta_str_list: list[str], func_name: str) -> list[str]: new_str_list = [] for pta_str in pta_str_list: if "packed_accessor" in pta_str: @@ -387,7 +386,7 @@ def make_pta_acc_format(pta_str_list: List[str], func_name: str) -> List[str]: return new_str_list -def make_pta_acc_builder_format(pta_str_list: List[str]) -> List[str]: +def make_pta_acc_builder_format(pta_str_list: list[str]) -> list[str]: new_str_list = [] for pta_str in pta_str_list: if "packed_accessor" in pta_str: @@ -411,7 +410,7 @@ def make_pta_acc_builder_format(pta_str_list: List[str]) -> List[str]: return new_str_list -def replace_pta_namespace(pta_str_list: List[str]) -> List[str]: +def replace_pta_namespace(pta_str_list: list[str]) -> list[str]: return [ pta_str.replace("at::PackedTensorAccessor", "pta::PackedTensorAccessor") for pta_str in pta_str_list @@ -420,10 +419,10 @@ def replace_pta_namespace(pta_str_list: List[str]) -> List[str]: def replace_placeholder_types( # pyre-fixme[11]: Annotation `TensorType` is not defined as a type. - arg_str_list: List[str], + arg_str_list: list[str], # pyre-fixme[11]: Annotation `TensorType` is not defined as a type. - type_combo: Optional[Dict[str, TensorType]], -) -> List[str]: + type_combo: dict[str, TensorType] | None, +) -> list[str]: """ Replace the placeholder types with the primitive types """ @@ -439,7 +438,7 @@ def replace_placeholder_types( return new_str_list -def to_upper_placeholder_types(arg_str_list: List[str]) -> List[str]: +def to_upper_placeholder_types(arg_str_list: list[str]) -> list[str]: """ Make the placeholder type names upper cases """ diff --git a/fbgemm_gpu/codegen/genscript/optimizer_args.py b/fbgemm_gpu/codegen/genscript/optimizer_args.py index 2c8441d1fe..7e1cba9464 100644 --- a/fbgemm_gpu/codegen/genscript/optimizer_args.py +++ b/fbgemm_gpu/codegen/genscript/optimizer_args.py @@ -12,8 +12,9 @@ import itertools +from collections.abc import Callable from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any try: @@ -34,8 +35,8 @@ class OptimizerArgsSetItem: # pyre-fixme[11]: Annotation `ArgType` is not defined as a type. ty: ArgType # type name: str - default: Union[float, ArgType] = 0 # DEFAULT_ARG_VAL - ph_tys: Optional[List[ArgType]] = None # placeholder types + default: float | ArgType = 0 # DEFAULT_ARG_VAL + ph_tys: list[ArgType] | None = None # placeholder types is_optional: bool = False # optional variable @@ -48,7 +49,7 @@ class OptimizerArgsSetItem: ###################################################################### # a dict of tensor name and annotation to mark whether the tensor is mutable. # this is use to annotate the tensor in the defintion schema. -annotation_dict: Dict[str, str] = { +annotation_dict: dict[str, str] = { "weights": "(a!)", "weights_host": "(a!)", "weights_dev": "(b!)", @@ -302,7 +303,7 @@ def schema_optional_tensorlist_arg(name: str) -> str: def make_kernel_arg( ty: ArgType, name: str, - default: Union[int, float, None], + default: int | float | None, pass_by_ref: bool = False, ) -> str: return { @@ -348,7 +349,7 @@ def make_kernel_arg_constructor(ty: ArgType, name: str) -> str: }[ty](name) -def make_cpu_kernel_arg(ty: ArgType, name: str, default: Union[int, float]) -> str: +def make_cpu_kernel_arg(ty: ArgType, name: str, default: int | float) -> str: return { ArgType.TENSOR: lambda x: acc_cache_tensor_arg(x, gpu=False), ArgType.INT_TENSOR: lambda x: int_tensor_arg(x, gpu=False), @@ -379,7 +380,7 @@ def make_cpu_kernel_arg_constructor(ty: ArgType, name: str) -> str: def make_function_arg( ty: ArgType, name: str, - default: Optional[Union[int, float]], + default: int | float | None, is_optional: bool = False, ) -> str: return { @@ -442,7 +443,7 @@ def make_function_arg( }[ty](name) -def make_function_schema_arg(ty: ArgType, name: str, default: Union[int, float]) -> str: +def make_function_schema_arg(ty: ArgType, name: str, default: int | float) -> str: return { ArgType.TENSOR: tensor_arg_annotate, ArgType.INT_TENSOR: tensor_arg, @@ -450,7 +451,7 @@ def make_function_schema_arg(ty: ArgType, name: str, default: Union[int, float]) ArgType.PLACEHOLDER_TENSOR: tensor_arg, ArgType.INT: lambda x: int_arg(x, default=int(default)), ArgType.FLOAT: lambda x: float_arg(x, default=default), - # pyre-fixme[6]: For 2nd argument expected `int` but got `Union[float, int]`. + # pyre-fixme[6]: For 2nd argument expected `int` but got `float | int`. ArgType.SYM_INT: lambda x: schema_sym_int_arg(x, default=default), ArgType.BOOL: lambda x: schema_bool_arg(x, default=bool(default)), }[ty](name) @@ -540,13 +541,13 @@ def make_ivalue_cast(ty: ArgType) -> str: }[ty] -def reorder_args(split_arg_spec: List[OptimItem]) -> List[OptimItem]: +def reorder_args(split_arg_spec: list[OptimItem]) -> list[OptimItem]: """ Reorder such that tensor arguments come first. This is used in backend, wrapper and kernels where tensors are no longer optional. We need to pass tensor arguments before other types which have default arguments. Parameters: - split_arg_spec (List[OptimItem]): List of argument items + split_arg_spec (list[OptimItem]): List of argument items Return: reordered of split_arg_spec @@ -569,20 +570,20 @@ def reorder_args(split_arg_spec: List[OptimItem]) -> List[OptimItem]: @dataclass class PT2ArgsSet: - split_function_args: List[str] - split_function_arg_names: List[str] - split_function_schemas: List[str] - split_saved_tensorlist: List[str] - split_saved_tensorlist_optional: List[str] - split_saved_data: List[dict[str, str]] - split_variables: List[str] - split_unpacked_arg_names: List[str] - split_args_dict: Dict[str, List[str]] + split_function_args: list[str] + split_function_arg_names: list[str] + split_function_schemas: list[str] + split_saved_tensorlist: list[str] + split_saved_tensorlist_optional: list[str] + split_saved_data: list[dict[str, str]] + split_variables: list[str] + split_unpacked_arg_names: list[str] + split_args_dict: dict[str, list[str]] @staticmethod # pyre-ignore[3] def create( - arg_spec: List[OptimItem], + arg_spec: list[OptimItem], ): """ PT2ArgsSet.create() is a method that creates different formats given the optimization arguments @@ -593,25 +594,25 @@ def create( e.g., instead of passing `momentum_host, `momentum_dev`, etc, we pass `momentum` Parameters: - arg_spec: List[OptimItem] - list of argument specs + arg_spec: list[OptimItem] - list of argument specs Returns: PT2ArgsSet object with the following attributes: - split_function_args: List[str] - List of function arguments used in unified lookup and autograd functions + split_function_args: list[str] - List of function arguments used in unified lookup and autograd functions Tensors will be packed and pass as TensorList. Auxillary arguments will be packed in dict. e.g., ['at::TensorList momentum1', 'at::Dict optim_int']. - split_function_arg_names: List[str] - List of argument names used in unified lookup and autograd functions + split_function_arg_names: list[str] - List of argument names used in unified lookup and autograd functions e.g., ['momentum1', 'optim_int', 'optim_float']. - split_function_schemas: List[str] - List of arguments used in unified lookup and autograd functions in the schema format + split_function_schemas: list[str] - List of arguments used in unified lookup and autograd functions in the schema format e.g., ['Tensor[] momentum1', 'float eps', 'float weight_decay']. - split_saved_tensorlist: List[str] - List of tensor names that are packed into tensorlist and will be unpacked in + split_saved_tensorlist: list[str] - List of tensor names that are packed into tensorlist and will be unpacked in PT2 autograd function. e.g., ['momentum1']. - split_saved_tensorlist_optional: List[str] - List of tensor names that are packed into tensorlist but are optional + split_saved_tensorlist_optional: list[str] - List of tensor names that are packed into tensorlist but are optional and will be unpacked in PT2 autograd function e.g., ['row_counter']. - split_saved_data: List[dict[str, str]] - List of non-tensor arguments that are saved for backward - split_unpacked_arg_names: List[str] - List of argument names, unrolled from list + split_saved_data: list[dict[str, str]] - List of non-tensor arguments that are saved for backward + split_unpacked_arg_names: list[str] - List of argument names, unrolled from list e.g., ['momentum1', 'eps', 'weight_decay', 'iter']. - split_args_dict: Dict[str, List[str]] - Dict of optim arguments' types containing the argument names of that type. + split_args_dict: dict[str, list[str]] - Dict of optim arguments' types containing the argument names of that type. e.g., if an optimizer only has an int argument called iter, the dict will look like: {'optim_tensor': [], 'optim_int': ['iter'], 'optim_float': [], 'optim_bool': []} """ @@ -635,7 +636,7 @@ def create( } # list of symint args to be appended after optim_xxx args # since they have default values - symint_list: List[OptimItem] = [] + symint_list: list[OptimItem] = [] for s in arg_spec: if s.name == "learning_rate_tensor": @@ -780,39 +781,39 @@ def append_lists(type_name: str) -> None: @dataclass class OptimizerArgs: - split_kernel_args: List[str] - split_kernel_args_no_defaults: List[str] - split_kernel_arg_constructors: List[str] - split_cpu_kernel_args: List[str] - split_cpu_kernel_arg_constructors: List[str] - split_function_args: List[str] - split_function_args_no_defaults: List[str] - split_saved_tensors: List[str] - split_tensors: List[str] - split_tensor_types: Dict[str, str] - saved_data: List[Tuple[str, str]] - split_function_arg_names: List[str] - split_function_schemas: List[str] - split_variables: List[str] - split_ref_kernel_args: List[str] - placeholder_tensor_names: List[str] + split_kernel_args: list[str] + split_kernel_args_no_defaults: list[str] + split_kernel_arg_constructors: list[str] + split_cpu_kernel_args: list[str] + split_cpu_kernel_arg_constructors: list[str] + split_function_args: list[str] + split_function_args_no_defaults: list[str] + split_saved_tensors: list[str] + split_tensors: list[str] + split_tensor_types: dict[str, str] + saved_data: list[tuple[str, str]] + split_function_arg_names: list[str] + split_function_schemas: list[str] + split_variables: list[str] + split_ref_kernel_args: list[str] + placeholder_tensor_names: list[str] # pyre-fixme[11]: Annotation `TensorType` is not defined as a type. - placeholder_type_combos: Union[List[Dict[str, TensorType]], List[None]] + placeholder_type_combos: list[dict[str, TensorType]] | list[None] unified_pt2: PT2ArgsSet - split_kernel_arg_names: List[str] - split_function_args_autograd: List[str] - split_function_arg_names_autograd: List[str] - split_saved_tensors_optional: List[str] - split_function_args_v1: Optional[str] = None - split_function_schemas_v1: Optional[str] = None + split_kernel_arg_names: list[str] + split_function_args_autograd: list[str] + split_function_arg_names_autograd: list[str] + split_saved_tensors_optional: list[str] + split_function_args_v1: str | None = None + split_function_schemas_v1: str | None = None @staticmethod # pyre-ignore[3] def create( - split_arg_spec: List[OptimItem], - arg_spec: List[OptimItem], + split_arg_spec: list[OptimItem], + arg_spec: list[OptimItem], gpu: bool, - additional_spec: Optional[dict[str, Any]] = None, + additional_spec: dict[str, Any] | None = None, ): # Keep the argument order for forward/backward compatibility # Arg order: non-optional tensors, learning_rate_tensor, non-tensors, optional tensors @@ -1014,10 +1015,10 @@ class OptimizerArgsSet: @staticmethod def create_optim_args( - arg_spec: List[OptimItem], - ext_fn: Callable[[OptimItem], List[OptimItem]], + arg_spec: list[OptimItem], + ext_fn: Callable[[OptimItem], list[OptimItem]], gpu: bool, - additional_spec: Optional[dict[str, Any]] = None, + additional_spec: dict[str, Any] | None = None, ) -> OptimizerArgs: split_arg_spec = [] for s in arg_spec: @@ -1035,7 +1036,7 @@ def create_optim_args( return OptimizerArgs.create(split_arg_spec, arg_spec, gpu, additional_spec) @staticmethod - def extend_for_cpu(spec: OptimItem) -> List[OptimItem]: + def extend_for_cpu(spec: OptimItem) -> list[OptimItem]: name = spec.name default = spec.default is_optional = spec.is_optional @@ -1056,7 +1057,7 @@ def extend_for_cpu(spec: OptimItem) -> List[OptimItem]: ] @staticmethod - def extend_for_cuda(spec: OptimItem) -> List[OptimItem]: + def extend_for_cuda(spec: OptimItem) -> list[OptimItem]: name = spec.name default = spec.default ty = spec.ty @@ -1081,7 +1082,7 @@ def extend_for_cuda(spec: OptimItem) -> List[OptimItem]: ] @staticmethod - def extend_for_any(spec: OptimItem) -> List[OptimItem]: + def extend_for_any(spec: OptimItem) -> list[OptimItem]: name = spec.name default = spec.default ty = spec.ty @@ -1110,7 +1111,7 @@ def extend_for_any(spec: OptimItem) -> List[OptimItem]: @staticmethod # pyre-ignore[3] def create( - arg_spec: List[OptimItem], additional_spec: Optional[dict[str, Any]] = None + arg_spec: list[OptimItem], additional_spec: dict[str, Any] | None = None ): return OptimizerArgsSet( *( diff --git a/fbgemm_gpu/codegen/genscript/optimizers.py b/fbgemm_gpu/codegen/genscript/optimizers.py index 0a3f370d38..d041bb8007 100644 --- a/fbgemm_gpu/codegen/genscript/optimizers.py +++ b/fbgemm_gpu/codegen/genscript/optimizers.py @@ -8,7 +8,7 @@ # pyre-strict # flake8: noqa F401 -from typing import Any, Dict +from typing import Any try: from .jinja_environment import generate_optimized_grad_sum_loop_access @@ -29,7 +29,7 @@ ###################################################################### -def dense() -> Dict[str, Any]: +def dense() -> dict[str, Any]: return { "optimizer": "dense", "dense": True, @@ -49,7 +49,7 @@ def dense() -> Dict[str, Any]: } -def adagrad() -> Dict[str, Any]: +def adagrad() -> dict[str, Any]: split_weight_update = """ Vec4T m_t(&momentum1[idx * D + d]); m_t.acc.x += grad.acc.x * grad.acc.x; @@ -117,7 +117,7 @@ def table_info_precomputation(momentum_prefix: str = "momentum1") -> str: return template.replace("{momentum_prefix}", momentum_prefix) -def rowwise_adagrad() -> Dict[str, Any]: +def rowwise_adagrad() -> dict[str, Any]: split_weight_update = """ weight_new.acc.x = correction * weight_new.acc.x - multiplier * grad.acc.x; weight_new.acc.y = correction * weight_new.acc.y - multiplier * grad.acc.y; @@ -286,7 +286,7 @@ def rowwise_adagrad() -> Dict[str, Any]: } -def approx_rowwise_adagrad() -> Dict[str, Any]: +def approx_rowwise_adagrad() -> dict[str, Any]: rowwise_adagrad_args = rowwise_adagrad() approx_split_weight_update = """ @@ -322,7 +322,7 @@ def approx_rowwise_adagrad() -> Dict[str, Any]: # Deprecated, to be cleaned up -def rowwise_adagrad_with_weight_decay() -> Dict[str, Any]: +def rowwise_adagrad_with_weight_decay() -> dict[str, Any]: split_weight_update = """ weight_new.acc.x = correction * weight_new.acc.x - multiplier * grad.acc.x; weight_new.acc.y = correction * weight_new.acc.y - multiplier * grad.acc.y; @@ -432,7 +432,7 @@ def rowwise_adagrad_with_weight_decay() -> Dict[str, Any]: # Deprecated, to be cleaned up -def approx_rowwise_adagrad_with_weight_decay() -> Dict[str, Any]: +def approx_rowwise_adagrad_with_weight_decay() -> dict[str, Any]: rowwise_adagrad_with_weight_decay_args = rowwise_adagrad_with_weight_decay() approx_split_weight_update = """ @@ -471,7 +471,7 @@ def approx_rowwise_adagrad_with_weight_decay() -> Dict[str, Any]: } -def rowwise_adagrad_with_counter() -> Dict[str, Any]: +def rowwise_adagrad_with_counter() -> dict[str, Any]: split_weight_update = """ weight_new.acc.x = (exp_reg_correction * weight_new.acc.x - adjusted_multiplier * grad.acc.x); weight_new.acc.y = (exp_reg_correction * weight_new.acc.y - adjusted_multiplier * grad.acc.y); @@ -735,7 +735,7 @@ def rowwise_adagrad_with_counter() -> Dict[str, Any]: } -def approx_rowwise_adagrad_with_counter() -> Dict[str, Any]: +def approx_rowwise_adagrad_with_counter() -> dict[str, Any]: rowwise_adagrad_with_counter_args = rowwise_adagrad_with_counter() approx_split_weight_update = """ @@ -789,7 +789,7 @@ def approx_rowwise_adagrad_with_counter() -> Dict[str, Any]: # Deprecated, to be cleaned up -def rowwise_weighted_adagrad() -> Dict[str, Any]: +def rowwise_weighted_adagrad() -> dict[str, Any]: split_weight_update = """ weight_new.acc.x = correction * weight_new.acc.x - multiplier * grad.acc.x; weight_new.acc.y = correction * weight_new.acc.y - multiplier * grad.acc.y; @@ -870,7 +870,7 @@ def rowwise_weighted_adagrad() -> Dict[str, Any]: } -def sgd() -> Dict[str, Any]: +def sgd() -> dict[str, Any]: split_weight_update = """ weight_new.fma_(grad, -learning_rate); """ @@ -898,7 +898,7 @@ def sgd() -> Dict[str, Any]: } -def approx_sgd() -> Dict[str, Any]: +def approx_sgd() -> dict[str, Any]: sgd_args = sgd() approx_split_weight_update = """ @@ -926,7 +926,7 @@ def approx_sgd() -> Dict[str, Any]: } -def lamb() -> Dict[str, Any]: +def lamb() -> dict[str, Any]: split_precomputation = """ at::acc_type weight_sum_sq = 0.0; at::acc_type rtw_sum_sq = 0.0; @@ -1008,7 +1008,7 @@ def lamb() -> Dict[str, Any]: } -def partial_rowwise_lamb() -> Dict[str, Any]: +def partial_rowwise_lamb() -> dict[str, Any]: split_precomputation = """ at::acc_type g_local_sum_square = 0.0; """ @@ -1105,7 +1105,7 @@ def partial_rowwise_lamb() -> Dict[str, Any]: } -def adam() -> Dict[str, Any]: +def adam() -> dict[str, Any]: split_precomputation = """ // Define the optimizer state (for use with optimizer offloading) struct OptimizerState { @@ -1237,7 +1237,7 @@ def adam() -> Dict[str, Any]: } -def partial_rowwise_adam() -> Dict[str, Any]: +def partial_rowwise_adam() -> dict[str, Any]: split_precomputation = """ at::acc_type g_local_sum_square = 0.0; """ @@ -1371,7 +1371,7 @@ def partial_rowwise_adam() -> Dict[str, Any]: } -def lars_sgd() -> Dict[str, Any]: +def lars_sgd() -> dict[str, Any]: split_precomputation = """ at::acc_type weight_sum_sq = 0.0; at::acc_type grad_sum_sq = 0.0; @@ -1440,7 +1440,7 @@ def lars_sgd() -> Dict[str, Any]: } -def none_optimizer() -> Dict[str, Any]: +def none_optimizer() -> dict[str, Any]: return { "optimizer": "none", "dense": False, diff --git a/fbgemm_gpu/codegen/genscript/scripts_argsparse.py b/fbgemm_gpu/codegen/genscript/scripts_argsparse.py index 13f6fec050..171a96b926 100644 --- a/fbgemm_gpu/codegen/genscript/scripts_argsparse.py +++ b/fbgemm_gpu/codegen/genscript/scripts_argsparse.py @@ -8,7 +8,6 @@ # flake8: noqa F401 import argparse -from typing import List ################################################################################ # Parse Codegen Scripts' Arguments @@ -24,7 +23,7 @@ parser.add_argument("--is_rocm", action="store_true") args: argparse.Namespace -_: List[str] +_: list[str] args, _ = parser.parse_known_args() print(f"[ARGS PARSE] Parsed arguments: {args}")