File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -162,14 +162,31 @@ def targets(self) -> list[EdgeOpOverload]:
162162
163163 def maybe_remove_or_replace (self , node : torch .fx .Node ) -> bool :
164164 ns = exir_ops .edge if isinstance (node .target , EdgeOpOverload ) else torch .ops
165+ out_dtype = node .kwargs .get ("out_dtype" )
166+ kwargs = {k : v for k , v in node .kwargs .items () if k != "out_dtype" }
165167 with node .graph .inserting_before (node ):
166168 new_node = node .graph .call_function (
167169 ns .cadence .dequantize_per_tensor .default ,
168170 args = node .args ,
169- kwargs = node . kwargs ,
171+ kwargs = kwargs ,
170172 )
171- new_node .meta = node .meta
172- node .replace_all_uses_with (new_node )
173+ new_node .meta = node .meta .copy ()
174+ if (
175+ out_dtype is not None
176+ and out_dtype != torch .float32
177+ and "val" in new_node .meta
178+ ):
179+ new_node .meta ["val" ] = new_node .meta ["val" ].to (torch .float32 )
180+ if out_dtype is not None and out_dtype != torch .float32 :
181+ with node .graph .inserting_after (new_node ):
182+ cast_node = node .graph .call_function (
183+ ns .aten .to .dtype ,
184+ args = (new_node , out_dtype ),
185+ )
186+ cast_node .meta = node .meta .copy ()
187+ node .replace_all_uses_with (cast_node )
188+ else :
189+ node .replace_all_uses_with (new_node )
173190 return True
174191
175192
You can’t perform that action at this time.
0 commit comments