Skip to content

Commit 034b044

Browse files
authored
Handle out_dtype in ReplacePT2DequantWithCadenceDequantPass (#19743)
Differential Revision: D105630451
Pull Request resolved: #19743
1 parent b4d62ed commit 034b044

1 file changed

Lines changed: 20 additions & 3 deletions

File tree

backends/cadence/aot/replace_ops.py

Lines changed: 20 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -162,14 +162,31 @@ def targets(self) -> list[EdgeOpOverload]:
162162

163163
def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
164164
ns = exir_ops.edge if isinstance(node.target, EdgeOpOverload) else torch.ops
165+
out_dtype = node.kwargs.get("out_dtype")
166+
kwargs = {k: v for k, v in node.kwargs.items() if k != "out_dtype"}
165167
with node.graph.inserting_before(node):
166168
new_node = node.graph.call_function(
167169
ns.cadence.dequantize_per_tensor.default,
168170
args=node.args,
169-
kwargs=node.kwargs,
171+
kwargs=kwargs,
170172
)
171-
new_node.meta = node.meta
172-
node.replace_all_uses_with(new_node)
173+
new_node.meta = node.meta.copy()
174+
if (
175+
out_dtype is not None
176+
and out_dtype != torch.float32
177+
and "val" in new_node.meta
178+
):
179+
new_node.meta["val"] = new_node.meta["val"].to(torch.float32)
180+
if out_dtype is not None and out_dtype != torch.float32:
181+
with node.graph.inserting_after(new_node):
182+
cast_node = node.graph.call_function(
183+
ns.aten.to.dtype,
184+
args=(new_node, out_dtype),
185+
)
186+
cast_node.meta = node.meta.copy()
187+
node.replace_all_uses_with(cast_node)
188+
else:
189+
node.replace_all_uses_with(new_node)
173190
return True
174191

175192

0 commit comments

Comments
 (0)