File tree Expand file tree Collapse file tree
modelopt/torch/quantization/plugins Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -120,6 +120,13 @@ def _functionals_to_replace(self, value):
120120 self ._functionals_to_replace = value
121121
122122 def _setup (self ):
123+ if getattr (self , "fuse_wgrad_accumulation" , False ):
124+ warnings .warn (
125+ "fuse_wgrad_accumulation is not supported with ModelOpt quantization. "
126+ "Setting fuse_wgrad_accumulation to False."
127+ )
128+ self .fuse_wgrad_accumulation = False
129+
123130 # GroupedMLP stores the weights as weight0, weight1, etc. To run setup in order to
124131 # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning
125132 # self.weight0 to self.weight to run the quantizer states initialization.
@@ -131,6 +138,9 @@ def _setup(self):
131138 # Remove self.weight after setup.
132139 delattr (self , "weight" )
133140
141+ # TODO: GroupedLinear stores weights split by `num_gemms`; to support quantization
142+ # with static parameters beyond per-tensor, we need a unique quantizer for each gemm.
143+
134144 def modelopt_post_restore (self , prefix : str = "" ):
135145 # GroupedMLP stores the weights as weight0, weight1, etc. To run post_restore in order to
136146 # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning
You can’t perform that action at this time.
0 commit comments