Skip to content

Commit 4b56c07

Browse files
committed
Add methods to generate name mappings based on the outermost level of the tensors
1 parent 8f7dddd commit 4b56c07

1 file changed

Lines changed: 86 additions & 1 deletion

File tree

src/ninetoothed/generation.py

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import itertools
88
import math
99
import pathlib
10+
import re
1011
import subprocess
1112
import textwrap
1213

@@ -18,7 +19,7 @@
1819
from ninetoothed.cudaifier import Cudaifier
1920
from ninetoothed.language import attribute, call
2021
from ninetoothed.symbol import Symbol
21-
from ninetoothed.tensor import Tensor
22+
from ninetoothed.tensor import Tensor, _identifier_pattern_raw_string
2223
from ninetoothed.torchifier import Torchifier
2324

2425
CACHE_DIR = pathlib.Path.home() / ".ninetoothed"
@@ -798,6 +799,90 @@ def _name_for_seq_end(tensor):
798799
def _name_for_index(tensor, dim):
799800
return Symbol(f"{tensor.source.name}_index_{dim}")
800801

802+
@staticmethod
803+
def _generate_name_mapping_from_tensors(tensors):
804+
return CodeGenerator._generate_name_mapping_from_shapes(
805+
tuple(tensor.shape for tensor in tensors)
806+
)
807+
808+
@staticmethod
809+
def _generate_name_mapping_from_shapes(shapes):
810+
if len(shapes) < 2:
811+
return {}
812+
813+
all_mapping = collections.defaultdict(set)
814+
815+
for input_shape, other_shape in itertools.combinations(shapes, 2):
816+
mapping = CodeGenerator._generate_name_mapping_from_shape_pair(
817+
input_shape, other_shape
818+
)
819+
820+
for key, value in mapping.items():
821+
all_mapping[key] |= value
822+
823+
return all_mapping
824+
825+
@staticmethod
826+
def _generate_name_mapping_from_shape_pair(input, other):
827+
def _convert_shape_to_string_tuple(shape):
828+
return tuple(str(size) for size in shape)
829+
830+
return CodeGenerator._generate_name_mapping_from_tuple_pair(
831+
_convert_shape_to_string_tuple(input), _convert_shape_to_string_tuple(other)
832+
)
833+
834+
@staticmethod
835+
def _generate_name_mapping_from_tuple_pair(input, other):
836+
all_mapping = collections.defaultdict(set)
837+
838+
for input_string, other_string in zip(input, other):
839+
mapping = CodeGenerator._generate_name_mapping_from_string_pair(
840+
input_string, other_string
841+
)
842+
843+
for key, value in mapping.items():
844+
all_mapping[key].add(value)
845+
846+
return all_mapping
847+
848+
@staticmethod
849+
def _generate_name_mapping_from_string_pair(input, other):
850+
name_pattern = _identifier_pattern_raw_string()
851+
852+
input_names = set(re.findall(name_pattern, input))
853+
other_names = set(re.findall(name_pattern, other))
854+
names = input_names | other_names
855+
856+
local_dict = {name: sympy.Symbol(name) for name in names}
857+
858+
input_expr = sympy.simplify(eval(input, {}, local_dict))
859+
other_expr = sympy.simplify(eval(other, {}, local_dict))
860+
861+
input_free_symbols = input_expr.free_symbols
862+
other_free_symbols = other_expr.free_symbols
863+
864+
common_free_symbols = input_free_symbols & other_free_symbols
865+
866+
input_unique_symbols = list(input_free_symbols - common_free_symbols)
867+
other_unique_symbols = list(other_free_symbols - common_free_symbols)
868+
869+
mapping = {}
870+
871+
if len(input_unique_symbols) == 1 and len(other_unique_symbols) == 1:
872+
input_unique_symbol = input_unique_symbols[0]
873+
other_unique_symbol = other_unique_symbols[0]
874+
875+
substituted_expr = input_expr.subs(input_unique_symbol, other_unique_symbol)
876+
877+
if sympy.simplify(substituted_expr - other_expr) == 0:
878+
input_unique_string = str(input_unique_symbol)
879+
other_unique_string = str(other_unique_symbol)
880+
881+
mapping[input_unique_string] = other_unique_string
882+
mapping[other_unique_string] = input_unique_string
883+
884+
return mapping
885+
801886

802887
class Tritonizer(ast.NodeTransformer):
803888
def visit_Module(self, node):

0 commit comments

Comments
 (0)