diff --git a/backends/xnnpack/partition/config/__init__.py b/backends/xnnpack/partition/config/__init__.py index 26ac6275ef1..d0a3e94bbc9 100644 --- a/backends/xnnpack/partition/config/__init__.py +++ b/backends/xnnpack/partition/config/__init__.py @@ -55,7 +55,9 @@ SubConfig, TanhConfig, ToDimOrderCopyConfig, + UnsqueezeCopyConfig, UpsampleBilinear2dConfig, + ViewCopyConfig, ) from executorch.backends.xnnpack.partition.config.node_configs import ( BatchNormConfig, @@ -116,7 +118,9 @@ SoftmaxConfig, SquareRootConfig, SubConfig, + UnsqueezeCopyConfig, UpsampleBilinear2dConfig, + ViewCopyConfig, # Quant/Dequant Op Configs QuantizedPerTensorConfig, DeQuantizedPerTensorConfig, diff --git a/backends/xnnpack/partition/config/generic_node_configs.py b/backends/xnnpack/partition/config/generic_node_configs.py index 0e588af66cb..1d933faa902 100644 --- a/backends/xnnpack/partition/config/generic_node_configs.py +++ b/backends/xnnpack/partition/config/generic_node_configs.py @@ -384,7 +384,7 @@ class ViewCopyConfig(GenericNodePartitionerConfig): target_name = "view_copy.default" def supported_precision_types(self) -> List[ConfigPrecisionType]: - return [ConfigPrecisionType.FP32] + return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT] def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: """ @@ -722,3 +722,24 @@ class CosConfig(GenericNodePartitionerConfig): def supported_precision_types(self) -> List[ConfigPrecisionType]: return [ConfigPrecisionType.FP32] + + +class UnsqueezeCopyConfig(GenericNodePartitionerConfig): + target_name = "unsqueeze_copy.default" + + def supported_precision_types(self) -> List[ConfigPrecisionType]: + return [ConfigPrecisionType.FP32, ConfigPrecisionType.STATIC_QUANT] + + def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool: + if not self.check_common_constraints(node, ep): + return False + + # The XNNPACK UnsqueezeVisitor only supports unsqueeze on the trailing + # dimension. Mirrors the runtime check in op_squeeze.py. + dim = node.args[1] + input_rank = len(node.args[0].meta["val"].shape) + if dim != -1 and dim != input_rank: + why(node, reason="unsqueeze_copy only supported on the trailing dimension") + return False + + return True