|
7 | 7 | import itertools |
8 | 8 | import math |
9 | 9 | import pathlib |
| 10 | +import re |
10 | 11 | import subprocess |
11 | 12 | import textwrap |
12 | 13 |
|
|
18 | 19 | from ninetoothed.cudaifier import Cudaifier |
19 | 20 | from ninetoothed.language import attribute, call |
20 | 21 | from ninetoothed.symbol import Symbol |
21 | | -from ninetoothed.tensor import Tensor |
| 22 | +from ninetoothed.tensor import Tensor, _identifier_pattern_raw_string |
22 | 23 | from ninetoothed.torchifier import Torchifier |
23 | 24 |
|
24 | 25 | CACHE_DIR = pathlib.Path.home() / ".ninetoothed" |
@@ -798,6 +799,90 @@ def _name_for_seq_end(tensor): |
798 | 799 | def _name_for_index(tensor, dim): |
799 | 800 | return Symbol(f"{tensor.source.name}_index_{dim}") |
800 | 801 |
|
| 802 | + @staticmethod |
| 803 | + def _generate_name_mapping_from_tensors(tensors): |
| 804 | + return CodeGenerator._generate_name_mapping_from_shapes( |
| 805 | + tuple(tensor.shape for tensor in tensors) |
| 806 | + ) |
| 807 | + |
| 808 | + @staticmethod |
| 809 | + def _generate_name_mapping_from_shapes(shapes): |
| 810 | + if len(shapes) < 2: |
| 811 | + return {} |
| 812 | + |
| 813 | + all_mapping = collections.defaultdict(set) |
| 814 | + |
| 815 | + for input_shape, other_shape in itertools.combinations(shapes, 2): |
| 816 | + mapping = CodeGenerator._generate_name_mapping_from_shape_pair( |
| 817 | + input_shape, other_shape |
| 818 | + ) |
| 819 | + |
| 820 | + for key, value in mapping.items(): |
| 821 | + all_mapping[key] |= value |
| 822 | + |
| 823 | + return all_mapping |
| 824 | + |
| 825 | + @staticmethod |
| 826 | + def _generate_name_mapping_from_shape_pair(input, other): |
| 827 | + def _convert_shape_to_string_tuple(shape): |
| 828 | + return tuple(str(size) for size in shape) |
| 829 | + |
| 830 | + return CodeGenerator._generate_name_mapping_from_tuple_pair( |
| 831 | + _convert_shape_to_string_tuple(input), _convert_shape_to_string_tuple(other) |
| 832 | + ) |
| 833 | + |
| 834 | + @staticmethod |
| 835 | + def _generate_name_mapping_from_tuple_pair(input, other): |
| 836 | + all_mapping = collections.defaultdict(set) |
| 837 | + |
| 838 | + for input_string, other_string in zip(input, other): |
| 839 | + mapping = CodeGenerator._generate_name_mapping_from_string_pair( |
| 840 | + input_string, other_string |
| 841 | + ) |
| 842 | + |
| 843 | + for key, value in mapping.items(): |
| 844 | + all_mapping[key].add(value) |
| 845 | + |
| 846 | + return all_mapping |
| 847 | + |
| 848 | + @staticmethod |
| 849 | + def _generate_name_mapping_from_string_pair(input, other): |
| 850 | + name_pattern = _identifier_pattern_raw_string() |
| 851 | + |
| 852 | + input_names = set(re.findall(name_pattern, input)) |
| 853 | + other_names = set(re.findall(name_pattern, other)) |
| 854 | + names = input_names | other_names |
| 855 | + |
| 856 | + local_dict = {name: sympy.Symbol(name) for name in names} |
| 857 | + |
| 858 | + input_expr = sympy.simplify(eval(input, {}, local_dict)) |
| 859 | + other_expr = sympy.simplify(eval(other, {}, local_dict)) |
| 860 | + |
| 861 | + input_free_symbols = input_expr.free_symbols |
| 862 | + other_free_symbols = other_expr.free_symbols |
| 863 | + |
| 864 | + common_free_symbols = input_free_symbols & other_free_symbols |
| 865 | + |
| 866 | + input_unique_symbols = list(input_free_symbols - common_free_symbols) |
| 867 | + other_unique_symbols = list(other_free_symbols - common_free_symbols) |
| 868 | + |
| 869 | + mapping = {} |
| 870 | + |
| 871 | + if len(input_unique_symbols) == 1 and len(other_unique_symbols) == 1: |
| 872 | + input_unique_symbol = input_unique_symbols[0] |
| 873 | + other_unique_symbol = other_unique_symbols[0] |
| 874 | + |
| 875 | + substituted_expr = input_expr.subs(input_unique_symbol, other_unique_symbol) |
| 876 | + |
| 877 | + if sympy.simplify(substituted_expr - other_expr) == 0: |
| 878 | + input_unique_string = str(input_unique_symbol) |
| 879 | + other_unique_string = str(other_unique_symbol) |
| 880 | + |
| 881 | + mapping[input_unique_string] = other_unique_string |
| 882 | + mapping[other_unique_string] = input_unique_string |
| 883 | + |
| 884 | + return mapping |
| 885 | + |
801 | 886 |
|
802 | 887 | class Tritonizer(ast.NodeTransformer): |
803 | 888 | def visit_Module(self, node): |
|
0 commit comments