
Commit fdf6be5

Handle CustomConv2d bias dtype mismatches
1 parent 06eff38 commit fdf6be5

2 files changed

Lines changed: 63 additions & 8 deletions


invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv2d.py

Lines changed: 40 additions & 6 deletions
@@ -7,12 +7,25 @@
 from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import (
     add_nullable_tensors,
 )
+from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
 
 
 class CustomConv2d(torch.nn.Conv2d, CustomModuleMixin):
+    def _cast_tensor_for_input(self, tensor: torch.Tensor | None, input: torch.Tensor) -> torch.Tensor | None:
+        tensor = cast_to_device(tensor, input.device)
+        if (
+            tensor is not None
+            and input.is_floating_point()
+            and tensor.is_floating_point()
+            and not isinstance(tensor, GGMLTensor)
+            and tensor.dtype != input.dtype
+        ):
+            tensor = tensor.to(dtype=input.dtype)
+        return tensor
+
     def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
-        weight = cast_to_device(self.weight, input.device)
-        bias = cast_to_device(self.bias, input.device)
+        weight = self._cast_tensor_for_input(self.weight, input)
+        bias = self._cast_tensor_for_input(self.bias, input)
 
         # Prepare the original parameters for the patch aggregation.
         orig_params = {"weight": weight, "bias": bias}
@@ -25,19 +38,40 @@ def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
             device=input.device,
         )
 
-        weight = add_nullable_tensors(weight, aggregated_param_residuals.get("weight", None))
-        bias = add_nullable_tensors(bias, aggregated_param_residuals.get("bias", None))
+        residual_weight = self._cast_tensor_for_input(aggregated_param_residuals.get("weight", None), input)
+        residual_bias = self._cast_tensor_for_input(aggregated_param_residuals.get("bias", None), input)
+        weight = add_nullable_tensors(weight, residual_weight)
+        bias = add_nullable_tensors(bias, residual_bias)
         return self._conv_forward(input, weight, bias)
 
     def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
-        weight = cast_to_device(self.weight, input.device)
-        bias = cast_to_device(self.bias, input.device)
+        weight = self._cast_tensor_for_input(self.weight, input)
+        bias = self._cast_tensor_for_input(self.bias, input)
         return self._conv_forward(input, weight, bias)
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if len(self._patches_and_weights) > 0:
             return self._autocast_forward_with_patches(input)
         elif self._device_autocasting_enabled:
             return self._autocast_forward(input)
+        elif (
+            input.is_floating_point()
+            and (
+                (
+                    self.weight.is_floating_point()
+                    and not isinstance(self.weight, GGMLTensor)
+                    and self.weight.dtype != input.dtype
+                )
+                or (
+                    self.bias is not None
+                    and self.bias.is_floating_point()
+                    and not isinstance(self.bias, GGMLTensor)
+                    and self.bias.dtype != input.dtype
+                )
+            )
+        ):
+            weight = self._cast_tensor_for_input(self.weight, input)
+            bias = self._cast_tensor_for_input(self.bias, input)
+            return self._conv_forward(input, weight, bias)
         else:
             return super().forward(input)
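
For context, a minimal sketch of the failure mode the new _cast_tensor_for_input() guards against, written in plain PyTorch rather than against the InvokeAI modules. The layer shape, the stray float16 bias, and the cast_for_input helper below are illustrative assumptions, and the GGMLTensor exclusion from the real method is omitted:

import torch

# Simulate a bias whose dtype drifted away from the input/weight dtype,
# e.g. after a partial model cast.
conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
conv.bias.data = conv.bias.data.to(dtype=torch.float16)

x = torch.randn(1, 3, 16, 16)  # float32 input

try:
    conv(x)  # typically raises a dtype-mismatch RuntimeError (exact wording varies by torch version)
except RuntimeError as err:
    print("mismatch:", err)

# Mirror of the commit's remedy: cast floating-point params to the input dtype
# before running the convolution.
def cast_for_input(tensor: torch.Tensor | None, inp: torch.Tensor) -> torch.Tensor | None:
    if (
        tensor is not None
        and inp.is_floating_point()
        and tensor.is_floating_point()
        and tensor.dtype != inp.dtype
    ):
        tensor = tensor.to(dtype=inp.dtype)
    return tensor

weight = cast_for_input(conv.weight, x)
bias = cast_for_input(conv.bias, x)
out = torch.nn.functional.conv2d(x, weight, bias, conv.stride, conv.padding, conv.dilation, conv.groups)
print(out.dtype)  # torch.float32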

invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py

Lines changed: 23 additions & 2 deletions
@@ -9,6 +9,7 @@
 from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
 from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
 from invokeai.backend.patches.layers.lora_layer import LoRALayer
+from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
 
 
 def linear_lora_forward(input: torch.Tensor, lora_layer: LoRALayer, lora_weight: float) -> torch.Tensor:
@@ -73,18 +74,38 @@ def autocast_linear_forward_sidecar_patches(
 
 
 class CustomLinear(torch.nn.Linear, CustomModuleMixin):
+    def _cast_weight_bias_for_input(self, input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]:
+        weight = cast_to_device(self.weight, input.device)
+        bias = cast_to_device(self.bias, input.device)
+        if (
+            input.is_floating_point()
+            and weight.is_floating_point()
+            and not isinstance(weight, GGMLTensor)
+            and weight.dtype != input.dtype
+        ):
+            weight = weight.to(dtype=input.dtype)
+            if bias is not None and not isinstance(bias, GGMLTensor):
+                bias = bias.to(dtype=input.dtype)
+        return weight, bias
+
     def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor:
         return autocast_linear_forward_sidecar_patches(self, input, self._patches_and_weights)
 
     def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor:
-        weight = cast_to_device(self.weight, input.device)
-        bias = cast_to_device(self.bias, input.device)
+        weight, bias = self._cast_weight_bias_for_input(input)
         return torch.nn.functional.linear(input, weight, bias)
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if len(self._patches_and_weights) > 0:
             return self._autocast_forward_with_patches(input)
         elif self._device_autocasting_enabled:
             return self._autocast_forward(input)
+        elif (
+            input.is_floating_point()
+            and self.weight.is_floating_point()
+            and self.weight.dtype != input.dtype
+        ):
+            weight, bias = self._cast_weight_bias_for_input(input)
+            return torch.nn.functional.linear(input, weight, bias)
         else:
             return super().forward(input)
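
Similarly, a small plain-PyTorch sketch of the combined weight/bias cast that the new CustomLinear._cast_weight_bias_for_input() performs; the layer size and the half-precision module are illustrative assumptions, and the GGMLTensor check is again left out:

import torch

linear = torch.nn.Linear(4, 2)
linear.half()  # simulate a layer left in float16 while activations are float32

x = torch.randn(3, 4)  # float32 input

try:
    torch.nn.functional.linear(x, linear.weight, linear.bias)  # dtype mismatch between input and weight raises a RuntimeError
except RuntimeError as err:
    print("mismatch:", err)

# Cast the weight (and the bias, when present) to the input dtype, as the new helper does.
weight, bias = linear.weight, linear.bias
if x.is_floating_point() and weight.is_floating_point() and weight.dtype != x.dtype:
    weight = weight.to(dtype=x.dtype)
    if bias is not None:
        bias = bias.to(dtype=x.dtype)

out = torch.nn.functional.linear(x, weight, bias)
print(out.dtype)  # torch.float32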
