@@ -227,7 +227,7 @@ def _NCHWtoNHWC_fun(graph: gs.Graph, match: Match, name: str, default_channels_f
227227 tensorIn = node .inputs [0 ]
228228 tensorOut = node .outputs [0 ]
229229
230- if node .op in ["RequantizedConv" , "Conv" ]:
230+ if node .op in ["RequantizedConv" , "Conv" , "ConvGradX" ]:
231231 spatialDims = len (node .inputs [1 ].shape ) - 2
232232 elif node .op in ["MaxPool" , "AveragePool" , "AveragePoolGrad" ]:
233233 spatialDims = len (node .attrs ["kernel_shape" ])
@@ -242,8 +242,9 @@ def _NCHWtoNHWC_fun(graph: gs.Graph, match: Match, name: str, default_channels_f
242242 permuteOut = _transformLayoutPermutation (len (tensorOut .shape ), spatialDims , channels_first )
243243 graph .nodes .append (_prependTranspose (tensorOut , node , permuteOut ))
244244
245- if node .op in ["Conv" , "RequantizedConv" ]:
245+ if node .op in ["Conv" , "RequantizedConv" , "ConvGradX" ]:
246246 # In the case of Conv: [weights, opt. bias], RequantizedConv: [weights, mul, add, opt. shift]
247+ # ConvGradX: [weight] (no bias)
247248 for tensor in node .inputs [1 :]:
248249 _transformLayoutConst (tensor , spatialDims , default_channels_first )
249250
@@ -279,6 +280,15 @@ def __init__(self, default_channels_first: bool = True):
279280 super ().__init__ (graph , partial (_NCHWtoNHWC_fun , default_channels_first = default_channels_first ), name )
280281
281282
283+ @contextagnostic
284+ class NCHWtoNHWCConvGradXPass (ReplaceSequentialPatternPass ):
285+
286+ def __init__ (self , default_channels_first : bool = True ):
287+ graph = _singleNodePattern (op = "ConvGradX" )
288+ name = "_NCHW_TO_NHWC_CONVGRADX_PASS"
289+ super ().__init__ (graph , partial (_NCHWtoNHWC_fun , default_channels_first = default_channels_first ), name )
290+
291+
282292@contextagnostic
283293class NCHWtoNHWCConvPass (ReplaceSequentialPatternPass ):
284294
@@ -383,6 +393,7 @@ def __init__(self, default_channels_first: bool = True):
383393 NCHWtoNHWCMaxPoolPass (default_channels_first ),
384394 NCHWtoNHWCAveragePoolPass (default_channels_first ),
385395 NCHWtoNHWCAveragePoolGradPass (default_channels_first ),
396+ NCHWtoNHWCConvGradXPass (default_channels_first ),
386397 NCHWtoNHWCDwConvPass (default_channels_first ),
387398 NCHWtoNHWCConvPass (default_channels_first ),
388399 ]
@@ -398,6 +409,7 @@ def __init__(self, default_channels_first: bool = True):
398409 NCHWtoNHWCMaxPoolPass (default_channels_first ),
399410 NCHWtoNHWCAveragePoolPass (default_channels_first ),
400411 NCHWtoNHWCAveragePoolGradPass (default_channels_first ),
412+ NCHWtoNHWCConvGradXPass (default_channels_first ),
401413 PULPNCHWtoNHWCDwConvPass (default_channels_first ),
402414 NCHWtoNHWCConvPass (default_channels_first ),
403415 ]
0 commit comments