Skip to content

Commit 4335679

Browse files
committed
update
1 parent 44f64d2 commit 4335679

2 files changed

Lines changed: 33 additions & 8 deletions

File tree

src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py

Lines changed: 10 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -110,8 +110,11 @@ def __init__(self, patch_size: int = 1, patch_method: str = "haar") -> None:
110110
self.patch_size = patch_size
111111
self.patch_method = patch_method
112112

113-
self.register_buffer("wavelets", _WAVELETS[patch_method], persistent=False)
114-
self.register_buffer("_arange", torch.arange(_WAVELETS[patch_method].shape[0]), persistent=False)
113+
wavelets = _WAVELETS.get(patch_method).clone()
114+
arange = torch.arange(wavelets.shape[0])
115+
116+
self.register_buffer("wavelets", wavelets, persistent=False)
117+
self.register_buffer("_arange", arange, persistent=False)
115118

116119
def _dwt(self, hidden_states: torch.Tensor, mode: str = "reflect", rescale=False) -> torch.Tensor:
117120
dtype = hidden_states.dtype
@@ -185,12 +188,11 @@ def __init__(self, patch_size: int = 1, patch_method: str = "haar"):
185188
self.patch_size = patch_size
186189
self.patch_method = patch_method
187190

188-
self.register_buffer("wavelets", _WAVELETS[patch_method], persistent=False)
189-
self.register_buffer(
190-
"_arange",
191-
torch.arange(_WAVELETS[patch_method].shape[0]),
192-
persistent=False,
193-
)
191+
wavelets = _WAVELETS.get(patch_method).clone()
192+
arange = torch.arange(wavelets.shape[0])
193+
194+
self.register_buffer("wavelets", wavelets, persistent=False)
195+
self.register_buffer("_arange", arange, persistent=False)
194196

195197
def _idwt(self, hidden_states: torch.Tensor, rescale: bool = False) -> torch.Tensor:
196198
device = hidden_states.device

tests/models/test_modeling_common.py

Lines changed: 23 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1532,6 +1532,28 @@ def test_fn(storage_dtype, compute_dtype):
15321532
def test_layerwise_casting_inference(self):
15331533
from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS
15341534

1535+
torch.manual_seed(0)
1536+
offload_type = "leaf_level"
1537+
record_stream = True
1538+
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
1539+
model = self.model_class(**init_dict)
1540+
model.eval()
1541+
additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1}
1542+
with tempfile.TemporaryDirectory() as tmpdir:
1543+
model.enable_group_offload(
1544+
torch_device,
1545+
offload_type=offload_type,
1546+
offload_to_disk_path=tmpdir,
1547+
use_stream=True,
1548+
record_stream=record_stream,
1549+
**additional_kwargs,
1550+
)
1551+
has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
1552+
self.assertTrue(len(has_safetensors) > 0, "No safetensors found in the offload directory.")
1553+
_ = model(**inputs_dict)[0]
1554+
1555+
del model, init_dict, inputs_dict
1556+
15351557
torch.manual_seed(0)
15361558
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
15371559
model = self.model_class(**config)
@@ -1575,6 +1597,7 @@ def test_layerwise_casting(storage_dtype, compute_dtype):
15751597
test_layerwise_casting(torch.float8_e4m3fn, torch.bfloat16)
15761598

15771599
@require_torch_accelerator
1600+
@torch.no_grad()
15781601
def test_layerwise_casting_memory(self):
15791602
MB_TOLERANCE = 0.2
15801603
LEAST_COMPUTE_CAPABILITY = 8.0

0 commit comments

Comments
 (0)