Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions fbgemm_gpu/codegen/genscript/generate_backward_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import itertools
import sys
from copy import deepcopy
from typing import List

try:
# pyre-fixme[21]: Could not find name `ArgType` in
Expand Down Expand Up @@ -45,7 +44,7 @@ def render_backward_templates(
template_filepath: str,
optimizer: str,
filename_format: str,
kwargs: Dict[str, Any],
kwargs: dict[str, Any],
is_gwd: bool = False,
) -> None:
if not kwargs.get("has_gpu_support"):
Expand Down Expand Up @@ -352,14 +351,14 @@ def generate_rocm_backward_split(**kwargs: Any) -> None:

@staticmethod
def generate_backward_header(
aux_args: Dict[str, List[str]], aux_names: List[str], is_ssd: bool = False
aux_args: dict[str, list[str]], aux_names: list[str], is_ssd: bool = False
) -> None:
"""
Generate a header file that contains enum of argument order from the dict

Parameters:
aux_args (Dict[str, List[str]]): a dict containing a list of arguments
aux_names (List[str]): names of the argument types (e.g. aux_tensor, aux_int, etc.)
aux_args (dict[str, list[str]]): a dict containing a list of arguments
aux_names (list[str]): names of the argument types (e.g. aux_tensor, aux_int, etc.)
Return:
None
"""
Expand All @@ -372,7 +371,7 @@ def generate_backward_header(

@staticmethod
def generate_python_sources(
all_optimizers: List[str], ssd_optimizers: List[str]
all_optimizers: list[str], ssd_optimizers: list[str]
) -> None:
CodeTemplate.load("training/python/__init__.template").write(
"__init__.py", all_optimizers=all_optimizers, ssd_optimizers=ssd_optimizers
Expand Down Expand Up @@ -418,7 +417,7 @@ def generate() -> None:
# This is a dict of auxilary arguments used in TBE PT2 interface where the aux
# arguments of a type are packed into a list for that type. This dict maintains the
# order of the arguments of each type.
aux_args: Dict[str, List[str]] = {
aux_args: dict[str, list[str]] = {
"aux_tensor": [
"B_offsets", # 0
"vbe_output_offsets_feature_rank", # 1
Expand Down
5 changes: 2 additions & 3 deletions fbgemm_gpu/codegen/genscript/generate_forward_quantized.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import sys
from dataclasses import dataclass
from typing import Dict, List

try:
from .common import CodeTemplate
Expand All @@ -31,7 +30,7 @@ class ElemType:
cpp_type_name: str
primitive_type: str
bit_width: int
template_params: List[TemplateParams]
template_params: list[TemplateParams]

@property
def enum_name(self) -> str:
Expand Down Expand Up @@ -122,7 +121,7 @@ def enum_name(self) -> str:
),
]

ELEM_TYPES_MAP: Dict[str, ElemType] = {etype.enum_name: etype for etype in ELEM_TYPES}
ELEM_TYPES_MAP: dict[str, ElemType] = {etype.enum_name: etype for etype in ELEM_TYPES}


class ForwardQuantizedGenerator:
Expand Down
9 changes: 4 additions & 5 deletions fbgemm_gpu/codegen/genscript/generate_forward_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import itertools
import sys
from typing import List

try:
from .common import CodeTemplate
Expand All @@ -28,10 +27,10 @@ class ForwardSplitGenerator:
def render_forward_templates(
template_filepath: str,
filename_format: str,
dense_options: List[bool],
nobag_options: List[bool],
vbe_options: List[bool],
ssd_options: List[bool],
dense_options: list[bool],
nobag_options: list[bool],
vbe_options: list[bool],
ssd_options: list[bool],
is_gwd: bool = False,
) -> None:
template = CodeTemplate.load(template_filepath)
Expand Down
23 changes: 11 additions & 12 deletions fbgemm_gpu/codegen/genscript/jinja_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import argparse
import os
import re
from typing import Dict, List, Optional, Tuple

import jinja2

Expand Down Expand Up @@ -81,7 +80,7 @@
################################################################################


def prepare_string_for_formatting(blob: str, format_keywords: List[str]) -> str:
def prepare_string_for_formatting(blob: str, format_keywords: list[str]) -> str:
"""
Replace curly brackets ('{' or '}') with escape characters ('{{' or '}}')
to prepare the string to be formatted by `str.format()`. `str.format()`
Expand All @@ -95,7 +94,7 @@ def prepare_string_for_formatting(blob: str, format_keywords: List[str]) -> str:


def generate_optimized_grad_sum_loop_access(
blob: str, other_formats: Optional[Dict[str, str]] = None
blob: str, other_formats: dict[str, str] | None = None
) -> str:
"""
Generate an optimized code for grad_sum accessing
Expand Down Expand Up @@ -148,13 +147,13 @@ def get_max_vecs_template_configs(
fixed_max_vecs_per_thread: int,
use_subwarp_shuffle: bool,
use_vec_blocking: bool,
) -> List[Tuple[int, int, str]]:
) -> list[tuple[int, int, str]]:
"""
Generate the template configs for each kFixedMaxVecsPerThread,
kThreadGroupSize, and kUseVecBlocking
"""
warp_size = items_per_warp // 4
configs: List[Tuple[int, int, str]] = []
configs: list[tuple[int, int, str]] = []

if use_vec_blocking:
# kFixedMaxVecsPerThread = fixed_max_vecs_per_thread
Expand Down Expand Up @@ -361,7 +360,7 @@ def compute_global_weight_decay(is_global_weight_decay_kernel: bool) -> str:


# Format the macro call to generate pta::PackedTensorAccessors
def make_pta_acc_format(pta_str_list: List[str], func_name: str) -> List[str]:
def make_pta_acc_format(pta_str_list: list[str], func_name: str) -> list[str]:
new_str_list = []
for pta_str in pta_str_list:
if "packed_accessor" in pta_str:
Expand All @@ -387,7 +386,7 @@ def make_pta_acc_format(pta_str_list: List[str], func_name: str) -> List[str]:
return new_str_list


def make_pta_acc_builder_format(pta_str_list: List[str]) -> List[str]:
def make_pta_acc_builder_format(pta_str_list: list[str]) -> list[str]:
new_str_list = []
for pta_str in pta_str_list:
if "packed_accessor" in pta_str:
Expand All @@ -411,7 +410,7 @@ def make_pta_acc_builder_format(pta_str_list: List[str]) -> List[str]:
return new_str_list


def replace_pta_namespace(pta_str_list: List[str]) -> List[str]:
def replace_pta_namespace(pta_str_list: list[str]) -> list[str]:
return [
pta_str.replace("at::PackedTensorAccessor", "pta::PackedTensorAccessor")
for pta_str in pta_str_list
Expand All @@ -420,10 +419,10 @@ def replace_pta_namespace(pta_str_list: List[str]) -> List[str]:

def replace_placeholder_types(
# pyre-fixme[11]: Annotation `TensorType` is not defined as a type.
arg_str_list: List[str],
arg_str_list: list[str],
# pyre-fixme[11]: Annotation `TensorType` is not defined as a type.
type_combo: Optional[Dict[str, TensorType]],
) -> List[str]:
type_combo: dict[str, TensorType] | None,
) -> list[str]:
"""
Replace the placeholder types with the primitive types
"""
Expand All @@ -439,7 +438,7 @@ def replace_placeholder_types(
return new_str_list


def to_upper_placeholder_types(arg_str_list: List[str]) -> List[str]:
def to_upper_placeholder_types(arg_str_list: list[str]) -> list[str]:
"""
Make the placeholder type names upper cases
"""
Expand Down
Loading
Loading