@@ -466,4 +466,250 @@ def parseNodeCtxt(self,
466466class PULPConvTrans2DParser (PULPFPConv2DParser ):
467467
468468 def __init__ (self , noBiasHoisting = True ):
469- super ().__init__ (noBiasHoisting )
469+ super ().__init__ (noBiasHoisting )
470+
471+ def parseNode (self , node : gs .Node ) -> bool :
472+ """Override to recognize ConvGradX instead of Conv"""
473+ # Temporarily change op to Conv for parent parsing
474+ original_op = node .op
475+ if node .op == 'ConvGradX' :
476+ node .op = 'Conv'
477+
478+ # Call parent parseNode
479+ wellFormed = super ().parseNode (node )
480+
481+ # Restore original op
482+ node .op = original_op
483+
484+ # Additional validation for ConvGradX
485+ if wellFormed and original_op == 'ConvGradX' :
486+ # ConvGradX should have 2 inputs: output_grad and weight
487+ return len (node .inputs ) == 2
488+
489+ return wellFormed
490+
491+ def parseNodeCtxt (self ,
492+ ctxt : NetworkContext ,
493+ node : gs .Node ,
494+ channels_first : bool = True ) -> Tuple [NetworkContext , bool ]:
495+ """Override for ConvGradX - swap input/output semantics"""
496+
497+ if node .op == 'ConvGradX' :
498+ # For ConvGradX: inputs are [output_grad, weight], output is input_grad
499+ # But parent expects: inputs are [input, weight], output is output
500+ # So we need to swap the semantics
501+
502+ # Temporarily swap input/output for parent parsing
503+ output_grad_name = node .inputs [0 ].name
504+ input_grad_name = node .outputs [0 ].name
505+
506+ # Get tensors
507+ output_grad = ctxt .lookup (output_grad_name )
508+ weight = ctxt .lookup (node .inputs [1 ].name )
509+
510+ # Create a temporary input tensor with output_grad's info as if it's the output
511+ # and output tensor with input_grad's info as if it's the input
512+ temp_input = node .inputs [0 ]
513+ temp_output = node .outputs [0 ]
514+
515+ # Swap
516+ node .inputs [0 ] = temp_output
517+ node .outputs [0 ] = temp_input
518+
519+ # Call parent
520+ newCtxt , ret = super ().parseNodeCtxt (ctxt , node , channels_first )
521+
522+ # Restore
523+ node .inputs [0 ] = temp_input
524+ node .outputs [0 ] = temp_output
525+
526+ if ret :
527+ # Fix the tensor names for ConvGradX
528+ self .operatorRepresentation ['data_in' ] = output_grad_name
529+ self .operatorRepresentation ['data_out' ] = input_grad_name
530+ self .operatorRepresentation ["has_bias" ] = "false"
531+ self .operatorRepresentation ["bias" ] = "NULL"
532+
533+ return newCtxt , ret
534+ else :
535+ return super ().parseNodeCtxt (ctxt , node , channels_first )
536+
537+ class PULPDWConvTrans2DParser (PULPFPDWConv2DParser ):
538+
539+ def __init__ (self , noBiasHoisting = True ):
540+ super ().__init__ (noBiasHoisting )
541+
542+ def parseNode (self , node : gs .Node ) -> bool :
543+ """Override to recognize ConvGradX instead of Conv"""
544+ # Temporarily change op to Conv for parent parsing
545+ original_op = node .op
546+ if node .op == 'ConvGradX' :
547+ node .op = 'Conv'
548+
549+ # Call parent parseNode
550+ wellFormed = super ().parseNode (node )
551+
552+ # Restore original op
553+ node .op = original_op
554+
555+ # Additional validation for ConvGradX
556+ if wellFormed and original_op == 'ConvGradX' :
557+ # ConvGradX should have 2 inputs: output_grad and weight
558+ return len (node .inputs ) == 2
559+
560+ return wellFormed
561+
562+ def parseNodeCtxt (self ,
563+ ctxt : NetworkContext ,
564+ node : gs .Node ,
565+ channels_first : bool = True ) -> Tuple [NetworkContext , bool ]:
566+ """Override for ConvGradX - swap input/output semantics"""
567+
568+ if node .op == 'ConvGradX' :
569+ # For ConvGradX: inputs are [output_grad, weight], output is input_grad
570+ # Temporarily swap input/output for parent parsing
571+ output_grad_name = node .inputs [0 ].name
572+ input_grad_name = node .outputs [0 ].name
573+
574+ # Swap
575+ temp_input = node .inputs [0 ]
576+ temp_output = node .outputs [0 ]
577+ node .inputs [0 ] = temp_output
578+ node .outputs [0 ] = temp_input
579+
580+ # Call parent
581+ newCtxt , ret = super ().parseNodeCtxt (ctxt , node , channels_first )
582+
583+ # Restore
584+ node .inputs [0 ] = temp_input
585+ node .outputs [0 ] = temp_output
586+
587+ if ret :
588+ # Fix the tensor names for ConvGradX
589+ self .operatorRepresentation ['data_in' ] = output_grad_name
590+ self .operatorRepresentation ['data_out' ] = input_grad_name
591+ self .operatorRepresentation ["weight" ] = ctxt .lookup (node .inputs [1 ].name ).name
592+ self .operatorRepresentation ["has_bias" ] = "false"
593+ self .operatorRepresentation ["bias" ] = "NULL"
594+
595+ return newCtxt , ret
596+ else :
597+ return super ().parseNodeCtxt (ctxt , node , channels_first )
598+
599+
600+ class PULPConvGradW2DParser (PULPFPConv2DParser ):
601+
602+ def __init__ (self , noBiasHoisting = True ):
603+ super ().__init__ (noBiasHoisting )
604+
605+ def parseNodeCtxt (self ,
606+ ctxt : NetworkContext ,
607+ node : gs .Node ,
608+ channels_first : bool = True ) -> Tuple [NetworkContext , bool ]:
609+ """Parse ConvGradW - need custom logic for input dimensions"""
610+
611+ if not self .parseNode (node ):
612+ return ctxt , False
613+
614+ # Get input tensors
615+ grad_out_tensor = ctxt .lookup (node .inputs [0 ].name )
616+ data_in_tensor = ctxt .lookup (node .inputs [1 ].name )
617+
618+ # Extract batch size
619+ batch = grad_out_tensor .shape [0 ]
620+
621+ # Extract dimensions
622+ C_out , H_out , W_out = grad_out_tensor .shape [1 ], grad_out_tensor .shape [2 ], grad_out_tensor .shape [3 ]
623+ C_in , H_in , W_in = data_in_tensor .shape [1 ], data_in_tensor .shape [2 ], data_in_tensor .shape [3 ]
624+
625+ # Store batch size
626+ self .operatorRepresentation ['batch' ] = batch
627+
628+ # Store dimensions
629+ self .operatorRepresentation ['ch_im_out' ] = C_out
630+ self .operatorRepresentation ['dim_im_out_x' ] = W_out
631+ self .operatorRepresentation ['dim_im_out_y' ] = H_out
632+ self .operatorRepresentation ['ch_im_in' ] = C_in
633+ self .operatorRepresentation ['dim_im_in_x' ] = W_in
634+ self .operatorRepresentation ['dim_im_in_y' ] = H_in
635+
636+ # Store kernel dimensions
637+ self .operatorRepresentation ['dim_kernel_y' ] = self .operatorRepresentation ['kernel_shape' ][0 ]
638+ self .operatorRepresentation ['dim_kernel_x' ] = self .operatorRepresentation ['kernel_shape' ][1 ]
639+
640+ # Store strides
641+ self .operatorRepresentation ['stride_y' ] = self .operatorRepresentation ['strides' ][0 ]
642+ self .operatorRepresentation ['stride_x' ] = self .operatorRepresentation ['strides' ][1 ]
643+
644+ # Set tensor names and types
645+ self .operatorRepresentation ['grad_out' ] = node .inputs [0 ].name
646+ self .operatorRepresentation ['grad_out_type' ] = grad_out_tensor ._type
647+ self .operatorRepresentation ['data_in' ] = node .inputs [1 ].name
648+ self .operatorRepresentation ['data_in_type' ] = data_in_tensor ._type
649+ self .operatorRepresentation ['weight' ] = node .outputs [0 ].name
650+ self .operatorRepresentation ['weight_type' ] = grad_out_tensor ._type # Same as grad_out
651+
652+ # No bias for ConvGradW
653+ self .operatorRepresentation ['has_bias' ] = 'false'
654+ self .operatorRepresentation ['bias' ] = 'NULL'
655+
656+ return ctxt , True
657+
658+ class PULPConvGradB2DParser (PULPFPConv2DParser ):
659+
660+ def __init__ (self ):
661+ self .operatorRepresentation = {}
662+
663+ def parseNode (self , node : gs .Node ) -> bool :
664+ """Parse ConvGradB node attributes"""
665+
666+ # Check basic structure
667+ if node .op != 'ConvGradB' :
668+ return False
669+
670+ if len (node .inputs ) != 1 : # only output_grad
671+ return False
672+
673+ if len (node .outputs ) != 1 : # bias_grad
674+ return False
675+
676+ return True
677+
678+ def parseNodeCtxt (self ,
679+ ctxt : NetworkContext ,
680+ node : gs .Node ,
681+ channels_first : bool = True ) -> Tuple [NetworkContext , bool ]:
682+ """Parse ConvGradB node context"""
683+
684+ # For ConvGradB, the inputs are:
685+ # inputs[0]: output_grad [N, C_out, H_out, W_out] (NCHW)
686+ # output: bias_grad [C_out]
687+
688+ # Get tensors from context
689+ output_grad_tensor = ctxt .lookup (node .inputs [0 ].name )
690+
691+ # Extract batch size and dimensions (NCHW)
692+ batch = output_grad_tensor .shape [0 ]
693+ C_out = output_grad_tensor .shape [1 ]
694+ H_out = output_grad_tensor .shape [2 ]
695+ W_out = output_grad_tensor .shape [3 ]
696+
697+ # Store batch size
698+ self .operatorRepresentation ['batch' ] = batch
699+
700+ # Store dimensions
701+ self .operatorRepresentation ['ch_im_out' ] = C_out
702+ self .operatorRepresentation ['dim_im_out_x' ] = W_out
703+ self .operatorRepresentation ['dim_im_out_y' ] = H_out
704+
705+ # Dummy kernel_shape for computeOps (ConvGradB doesn't use kernels)
706+ self .operatorRepresentation ['kernel_shape' ] = [1 , 1 ]
707+ self .operatorRepresentation ['ch_im_in' ] = 1 # Dummy value
708+
709+ # Set tensor names and types
710+ self .operatorRepresentation ['grad_out' ] = node .inputs [0 ].name
711+ self .operatorRepresentation ['grad_out_type' ] = output_grad_tensor ._type
712+ self .operatorRepresentation ['bias' ] = node .outputs [0 ].name
713+ self .operatorRepresentation ['bias_type' ] = output_grad_tensor ._type # Same type as grad_out
714+
715+ return ctxt , True
0 commit comments