99from executorch .backends .arm ._passes import ArmPass
1010from executorch .backends .arm .tosa .specification import get_context_spec
1111from executorch .exir .dialects ._ops import ops as exir_ops
12- from executorch .exir .pass_base import ExportPass
12+ from executorch .exir .pass_base import ExportPass , NodeMetadata
1313
1414
1515class InsertDataLayoutCastsPass (ArmPass ):
@@ -55,6 +55,7 @@ class InsertDataLayoutCastsPass(ArmPass):
5555 }
5656
5757 _int_to_fp_map = {
58+ torch .int8 : torch .float16 , # This doubles the size after casting, but is very unlikely to occur in practice since int8 is only ever used by LOGICAL_SHIFT and CAST/RESCALE ops in PRO-FP.
5859 torch .int16 : torch .float16 ,
5960 torch .int32 : torch .float32 ,
6061 }
@@ -63,9 +64,15 @@ def call_operator(self, op, args, kwargs, meta):
6364 if op not in self .targeted_ops :
6465 return super ().call_operator (op , args , kwargs , meta )
6566
66- dtype = args [0 ].data .dtype
67- spec = get_context_spec ()
67+ if op in self ._concat_ops :
68+ # Cast to largest dtype
69+ dtypes = [arg .data .dtype for arg in args [0 ]]
70+ dtype_sizes = [dtype .itemsize for dtype in dtypes ]
71+ dtype = dtypes [dtype_sizes .index (max (dtype_sizes ))]
72+ else :
73+ dtype = args [0 ].data .dtype
6874
75+ spec = get_context_spec ()
6976 dtype_is_integer = not dtype .is_floating_point and dtype != torch .bool
7077 if dtype_is_integer and not spec .support_integer ():
7178 supported_dtype = self ._int_to_fp_map .get (dtype , None )
@@ -93,16 +100,30 @@ def call_operator(self, op, args, kwargs, meta):
93100 for arg in args [0 ]:
94101 x_casted .append (
95102 super ().call_operator (
96- self ._cast_op , (arg ,), {"dtype" : supported_dtype }, meta
103+ self ._cast_op ,
104+ (arg ,),
105+ {"dtype" : supported_dtype },
106+ NodeMetadata (arg .node .meta ),
107+ updated = True ,
97108 )
98109 )
99- y_casted = super ().call_operator (op , (x_casted ,), kwargs , meta )
110+ y_casted = super ().call_operator (
111+ op , (x_casted , * args [1 :]), kwargs , meta , updated = True
112+ )
100113
101114 else :
102115 x_casted = super ().call_operator (
103- self ._cast_op , (args [0 ],), {"dtype" : supported_dtype }, meta
116+ self ._cast_op ,
117+ (args [0 ],),
118+ {"dtype" : supported_dtype },
119+ NodeMetadata (args [0 ].node .meta ),
120+ updated = True ,
121+ )
122+ y_casted = super ().call_operator (
123+ op , (x_casted , * args [1 :]), kwargs , meta , updated = True
104124 )
105- y_casted = super ().call_operator (op , (x_casted , * args [1 :]), kwargs , meta )
106125
107- y = super ().call_operator (self ._cast_op , (y_casted ,), {"dtype" : dtype }, meta )
126+ y = super ().call_operator (
127+ self ._cast_op , (y_casted ,), {"dtype" : dtype }, meta , updated = True
128+ )
108129 return y
0 commit comments