77from invokeai .backend .model_manager .load .model_cache .torch_module_autocast .custom_modules .utils import (
88 add_nullable_tensors ,
99)
10+ from invokeai .backend .quantization .gguf .ggml_tensor import GGMLTensor
1011
1112
1213class CustomConv2d (torch .nn .Conv2d , CustomModuleMixin ):
14+ def _cast_tensor_for_input (self , tensor : torch .Tensor | None , input : torch .Tensor ) -> torch .Tensor | None :
15+ tensor = cast_to_device (tensor , input .device )
16+ if (
17+ tensor is not None
18+ and input .is_floating_point ()
19+ and tensor .is_floating_point ()
20+ and not isinstance (tensor , GGMLTensor )
21+ and tensor .dtype != input .dtype
22+ ):
23+ tensor = tensor .to (dtype = input .dtype )
24+ return tensor
25+
1326 def _autocast_forward_with_patches (self , input : torch .Tensor ) -> torch .Tensor :
14- weight = cast_to_device (self .weight , input . device )
15- bias = cast_to_device (self .bias , input . device )
27+ weight = self . _cast_tensor_for_input (self .weight , input )
28+ bias = self . _cast_tensor_for_input (self .bias , input )
1629
1730 # Prepare the original parameters for the patch aggregation.
1831 orig_params = {"weight" : weight , "bias" : bias }
@@ -25,19 +38,40 @@ def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
2538 device = input .device ,
2639 )
2740
28- weight = add_nullable_tensors (weight , aggregated_param_residuals .get ("weight" , None ))
29- bias = add_nullable_tensors (bias , aggregated_param_residuals .get ("bias" , None ))
41+ residual_weight = self ._cast_tensor_for_input (aggregated_param_residuals .get ("weight" , None ), input )
42+ residual_bias = self ._cast_tensor_for_input (aggregated_param_residuals .get ("bias" , None ), input )
43+ weight = add_nullable_tensors (weight , residual_weight )
44+ bias = add_nullable_tensors (bias , residual_bias )
3045 return self ._conv_forward (input , weight , bias )
3146
3247 def _autocast_forward (self , input : torch .Tensor ) -> torch .Tensor :
33- weight = cast_to_device (self .weight , input . device )
34- bias = cast_to_device (self .bias , input . device )
48+ weight = self . _cast_tensor_for_input (self .weight , input )
49+ bias = self . _cast_tensor_for_input (self .bias , input )
3550 return self ._conv_forward (input , weight , bias )
3651
3752 def forward (self , input : torch .Tensor ) -> torch .Tensor :
3853 if len (self ._patches_and_weights ) > 0 :
3954 return self ._autocast_forward_with_patches (input )
4055 elif self ._device_autocasting_enabled :
4156 return self ._autocast_forward (input )
57+ elif (
58+ input .is_floating_point ()
59+ and (
60+ (
61+ self .weight .is_floating_point ()
62+ and not isinstance (self .weight , GGMLTensor )
63+ and self .weight .dtype != input .dtype
64+ )
65+ or (
66+ self .bias is not None
67+ and self .bias .is_floating_point ()
68+ and not isinstance (self .bias , GGMLTensor )
69+ and self .bias .dtype != input .dtype
70+ )
71+ )
72+ ):
73+ weight = self ._cast_tensor_for_input (self .weight , input )
74+ bias = self ._cast_tensor_for_input (self .bias , input )
75+ return self ._conv_forward (input , weight , bias )
4276 else :
4377 return super ().forward (input )
0 commit comments