 )
 from invokeai.backend.patches.layer_patcher import LayerPatcher
 from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
+from invokeai.backend.patches.layers.dora_layer import DoRALayer
 from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
 from invokeai.backend.patches.layers.lokr_layer import LoKRLayer
 from invokeai.backend.patches.layers.lora_layer import LoRALayer
@@ -346,6 +347,7 @@ def test_inference_autocast_from_cpu_to_device(device: str, layer_under_test: La
         "concatenated_lora",
         "flux_control_lora",
         "single_lokr",
+        "single_dora",
     ]
 )
 def patch_under_test(request: pytest.FixtureRequest) -> PatchUnderTest:
@@ -432,6 +434,20 @@ def patch_under_test(request: pytest.FixtureRequest) -> PatchUnderTest:
         )
         input = torch.randn(1, in_features)
         return ([(lokr_layer, 0.7)], input)
+    elif layer_type == "single_dora":
+        # Regression coverage for #8624: DoRA + partial-loading + CPU->device autocast.
+        # Scaled down so the patched weight stays well-conditioned for allclose comparisons.
+        # dora_scale has shape (1, in_features) to broadcast against direction_norm in
+        # DoRALayer.get_weight — see dora_layer.py:74-82.
+        dora_layer = DoRALayer(
+            up=torch.randn(out_features, rank) * 0.01,
+            down=torch.randn(rank, in_features) * 0.01,
+            dora_scale=torch.ones(1, in_features),
+            alpha=1.0,
+            bias=torch.randn(out_features) * 0.01,
+        )
+        input = torch.randn(1, in_features)
+        return ([(dora_layer, 0.7)], input)
     else:
         raise ValueError(f"Unsupported layer_type: {layer_type}")
 
@@ -676,3 +692,45 @@ def test_conv2d_mixed_dtype_sidecar_parameter_patch(dtype: torch.dtype):
 
     assert output.dtype == input.dtype
     assert output.shape == (2, 16, 3, 3)
+
+
+@torch.no_grad()
+def test_aggregate_patch_parameters_preserves_plain_tensor_with_dora():
+    """Regression test for #8624: when partial-loading autocasts a CPU Parameter onto the
+    compute device, cast_to_device returns a plain torch.Tensor (not a Parameter). The
+    aggregator must treat that as a real tensor and not substitute a meta-device dummy —
+    otherwise DoRA's quantization guard falsely triggers on non-quantized base models.
+
+    This test is CPU-only and simulates the hand-off by constructing a plain torch.Tensor
+    directly; the equivalent CUDA/MPS E2E flow is exercised by the "single_dora" variant
+    of test_linear_sidecar_patches_with_autocast_from_cpu_to_device.
+    """
+    layer = wrap_single_custom_layer(torch.nn.Linear(32, 64))
+
+    rank = 4
+    dora_patch = DoRALayer(
+        up=torch.randn(64, rank) * 0.01,
+        down=torch.randn(rank, 32) * 0.01,
+        dora_scale=torch.ones(1, 32),
+        alpha=1.0,
+        bias=None,
+    )
+
+    # Plain torch.Tensor: the same kind of value that _cast_weight_bias_for_input hands to
+    # _aggregate_patch_parameters after autocasting a Parameter across devices.
+    plain_weight = torch.randn(64, 32)
+    assert type(plain_weight) is torch.Tensor
+
+    orig_params = {"weight": plain_weight}
+    params = layer._aggregate_patch_parameters(
+        patches_and_weights=[(dora_patch, 1.0)],
+        orig_params=orig_params,
+        device=torch.device("cpu"),
+    )
+
+    # Pre-fix, orig_params["weight"] would have been replaced by a meta-device dummy,
+    # causing DoRALayer.get_parameters to raise "not compatible with DoRA patches".
+    assert orig_params["weight"].device.type == "cpu"
+    assert params["weight"].shape == (64, 32)
+    assert params["weight"].device.type == "cpu"
+    assert not torch.isnan(params["weight"]).any()
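
For context on the hand-off described in the docstring above: moving or converting an nn.Parameter in PyTorch generally yields a plain torch.Tensor rather than a Parameter, so any check in the aggregation path that keys on the Parameter type will miss weights that were autocast to another device. The snippet below is a minimal standalone sketch of that underlying torch behavior only; it is not part of this diff and does not use the InvokeAI cast_to_device or _aggregate_patch_parameters code paths.

import torch

p = torch.nn.Parameter(torch.randn(4, 4))

# A dtype (or device) conversion makes a copy, and the copy comes back as a
# plain torch.Tensor, not as an nn.Parameter.
converted = p.to(torch.float16)

assert isinstance(p, torch.nn.Parameter)
assert not isinstance(converted, torch.nn.Parameter)
print(type(p), type(converted))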