@@ -597,11 +597,38 @@ def parseNodeCtxt(self,
597597 return super ().parseNodeCtxt (ctxt , node , channels_first )
598598
599599
600- class PULPConvGradW2DParser (PULPFPConv2DParser ):
600+ class PULPConvGradW2DParser (Conv2DParser ):
601+ """Parser for standard ConvGradW (non-grouped)"""
601602
602603 def __init__ (self , noBiasHoisting = True ):
603604 super ().__init__ (noBiasHoisting )
604605
606+ def parseNode (self , node : gs .Node ) -> bool :
607+ """Parse ConvGradW node, rejecting grouped convolutions"""
608+ # Call Conv2DParser.parseNode directly (skip PULPFPConv2DParser's group==1 check)
609+ wellFormed = Conv2DParser .parseNode (self , node )
610+
611+ if not wellFormed :
612+ return False
613+
614+ # Reject if group > 1 (handled by DWConvGradW2DParser)
615+ if 'group' in self .operatorRepresentation :
616+ group = self .operatorRepresentation ['group' ]
617+ if group > 1 :
618+ return False
619+
620+ # ConvGradW has 2 inputs: output_grad and input_data
621+ if len (node .inputs ) != 2 :
622+ return False
623+
624+ # Extract padding attributes
625+ self .operatorRepresentation ['padding_y_top' ] = int (self .operatorRepresentation ['pads' ][0 ])
626+ self .operatorRepresentation ['padding_x_left' ] = int (self .operatorRepresentation ['pads' ][1 ])
627+ self .operatorRepresentation ['padding_y_bottom' ] = int (self .operatorRepresentation ['pads' ][2 ])
628+ self .operatorRepresentation ['padding_x_right' ] = int (self .operatorRepresentation ['pads' ][3 ])
629+
630+ return True
631+
605632 def parseNodeCtxt (self ,
606633 ctxt : NetworkContext ,
607634 node : gs .Node ,
@@ -712,4 +739,103 @@ def parseNodeCtxt(self,
712739 self .operatorRepresentation ['bias' ] = node .outputs [0 ].name
713740 self .operatorRepresentation ['bias_type' ] = output_grad_tensor ._type # Same type as grad_out
714741
715- return ctxt , True
742+ return ctxt , True
743+
744+
745+ class PULPDWConvGradW2DParser (Conv2DParser ):
746+ """Parser for depthwise ConvGradW (grouped convolution weight gradient)"""
747+
748+ def __init__ (self , noBiasHoisting = True ):
749+ super ().__init__ (noBiasHoisting )
750+
751+ def parseNode (self , node : gs .Node ) -> bool :
752+ """Parse grouped ConvGradW node"""
753+ # Call Conv2DParser.parseNode directly (skip PULPFPConv2DParser's group==1 check)
754+ wellFormed = Conv2DParser .parseNode (self , node )
755+
756+ if not wellFormed :
757+ return False
758+
759+ # Must have group attribute and group > 1
760+ if 'group' not in self .operatorRepresentation :
761+ return False
762+
763+ group = self .operatorRepresentation ['group' ]
764+ if group <= 1 :
765+ return False
766+
767+ # ConvGradW has 2 inputs: output_grad and input_data
768+ if len (node .inputs ) != 2 :
769+ return False
770+
771+ # Extract padding attributes
772+ self .operatorRepresentation ['padding_y_top' ] = int (self .operatorRepresentation ['pads' ][0 ])
773+ self .operatorRepresentation ['padding_x_left' ] = int (self .operatorRepresentation ['pads' ][1 ])
774+ self .operatorRepresentation ['padding_y_bottom' ] = int (self .operatorRepresentation ['pads' ][2 ])
775+ self .operatorRepresentation ['padding_x_right' ] = int (self .operatorRepresentation ['pads' ][3 ])
776+
777+ return True
778+
779+ def parseNodeCtxt (self ,
780+ ctxt : NetworkContext ,
781+ node : gs .Node ,
782+ channels_first : bool = True ) -> Tuple [NetworkContext , bool ]:
783+ """Parse DWConvGradW - depthwise/grouped weight gradient computation"""
784+
785+ if not self .parseNode (node ):
786+ return ctxt , False
787+
788+ # Get input tensors
789+ grad_out_tensor = ctxt .lookup (node .inputs [0 ].name )
790+ data_in_tensor = ctxt .lookup (node .inputs [1 ].name )
791+
792+ # Extract batch size
793+ batch = grad_out_tensor .shape [0 ]
794+
795+ # Extract dimensions (NCHW format)
796+ C_out , H_out , W_out = grad_out_tensor .shape [1 ], grad_out_tensor .shape [2 ], grad_out_tensor .shape [3 ]
797+ C_in , H_in , W_in = data_in_tensor .shape [1 ], data_in_tensor .shape [2 ], data_in_tensor .shape [3 ]
798+
799+ # Get group info
800+ group = self .operatorRepresentation ['group' ]
801+
802+ # Verify grouping constraints
803+ assert C_out % group == 0 , f"Output channels { C_out } not divisible by group { group } "
804+ assert C_in % group == 0 , f"Input channels { C_in } not divisible by group { group } "
805+
806+ # For depthwise: group == C_in == C_out
807+ # Weight shape is [C_out, C_in/group, kH, kW]
808+ C_in_per_group = C_in // group
809+
810+ # Store batch size
811+ self .operatorRepresentation ['batch' ] = batch
812+
813+ # Store dimensions
814+ self .operatorRepresentation ['ch_im_out' ] = C_out
815+ self .operatorRepresentation ['dim_im_out_x' ] = W_out
816+ self .operatorRepresentation ['dim_im_out_y' ] = H_out
817+ self .operatorRepresentation ['ch_im_in' ] = C_in
818+ self .operatorRepresentation ['dim_im_in_x' ] = W_in
819+ self .operatorRepresentation ['dim_im_in_y' ] = H_in
820+
821+ # Store kernel dimensions
822+ self .operatorRepresentation ['dim_kernel_y' ] = self .operatorRepresentation ['kernel_shape' ][0 ]
823+ self .operatorRepresentation ['dim_kernel_x' ] = self .operatorRepresentation ['kernel_shape' ][1 ]
824+
825+ # Store strides
826+ self .operatorRepresentation ['stride_y' ] = self .operatorRepresentation ['strides' ][0 ]
827+ self .operatorRepresentation ['stride_x' ] = self .operatorRepresentation ['strides' ][1 ]
828+
829+ # Set tensor names and types
830+ self .operatorRepresentation ['grad_out' ] = node .inputs [0 ].name
831+ self .operatorRepresentation ['grad_out_type' ] = grad_out_tensor ._type
832+ self .operatorRepresentation ['data_in' ] = node .inputs [1 ].name
833+ self .operatorRepresentation ['data_in_type' ] = data_in_tensor ._type
834+ self .operatorRepresentation ['weight' ] = node .outputs [0 ].name
835+ self .operatorRepresentation ['weight_type' ] = grad_out_tensor ._type
836+
837+ # No bias for ConvGradW
838+ self .operatorRepresentation ['has_bias' ] = 'false'
839+ self .operatorRepresentation ['bias' ] = 'NULL'
840+
841+ return ctxt , True
0 commit comments