@@ -55,7 +55,8 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
5555 data = weight_data , quantized_stats = quant_state_dict , device = device
5656 )
5757 self .weight = bnb_param
58-
58+ self ._bnb_quant_state_dict = quant_state_dict
59+
5960 for k in bnb_state_dict .keys ():
6061 state_dict .pop (k )
6162 if k in unexpected_keys : unexpected_keys .remove (k )
@@ -94,9 +95,10 @@ def forward(self, x):
9495 if getattr (self , "is_bnb_quantized" , lambda : False )():
9596 if not patches_for_this_layer :
9697 bias = self .bias .to (device = x .device , dtype = x .dtype ) if self .bias is not None else None
97- return bnb .matmul_4bit (
98- x , self .weight .t (), bias = bias , quant_state = getattr (self .weight , "quant_state" , None )
99- ).to (x .dtype )
98+ qs = getattr (self .weight , "quant_state" , None )
99+ if qs is None and hasattr (self , "_bnb_quant_state_dict" ):
100+ qs = bnb .functional .QuantState .from_dict (self ._bnb_quant_state_dict , device = x .device )
101+ return bnb .matmul_4bit (x , self .weight .t (), bias = bias , quant_state = qs ).to (x .dtype )
100102
101103 try :
102104 base_w = self .weight .to (x .device )
@@ -113,9 +115,10 @@ def forward(self, x):
113115
114116 if weight_final_fp32 is None :
115117 bias = self .bias .to (device = x .device , dtype = x .dtype ) if self .bias is not None else None
116- return bnb .matmul_4bit (
117- x , self .weight .t (), bias = bias , quant_state = getattr (self .weight , "quant_state" , None )
118- ).to (x .dtype )
118+ qs = getattr (self .weight , "quant_state" , None )
119+ if qs is None and hasattr (self , "_bnb_quant_state_dict" ):
120+ qs = bnb .functional .QuantState .from_dict (self ._bnb_quant_state_dict , device = x .device )
121+ return bnb .matmul_4bit (x , self .weight .t (), bias = bias , quant_state = qs ).to (x .dtype )
119122
120123 weight_final = comfy .float .stochastic_rounding (weight_final_fp32 , x .dtype )
121124 bias = self .bias .to (device = x .device , dtype = x .dtype ) if self .bias is not None else None
0 commit comments