Skip to content

Commit 5c11f5d

Browse files
author
Talmaj Marinc
committed
Polish imports and modify asserts to raise proper errors with messages.
1 parent 713b557 commit 5c11f5d

1 file changed

Lines changed: 68 additions & 19 deletions

File tree

comfy_extras/void_noise_warp.py

Lines changed: 68 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import torch
2020
import torch.nn.functional as F
2121
from einops import rearrange
22+
from torchvision.models.optical_flow import raft_large
2223

2324
import comfy.model_management
2425

@@ -35,7 +36,10 @@ def _torch_resize_chw(image, size, interp, copy=True):
3536
the requested size matches the input, returns the input tensor as is
3637
(faster but callers must not mutate the result).
3738
"""
38-
assert image.ndim == 3, image.shape
39+
if image.ndim != 3:
40+
raise ValueError(
41+
f"_torch_resize_chw expects a 3D CHW tensor, got shape {tuple(image.shape)}"
42+
)
3943
_, in_h, in_w = image.shape
4044
if isinstance(size, (int, float)) and not isinstance(size, bool):
4145
new_h = max(1, int(in_h * size))
@@ -59,8 +63,14 @@ def _torch_remap_relative(image, dx, dy, interp="bilinear"):
5963
Equivalent to ``rp.torch_remap_image(image, dx, dy, relative=True, interp=interp)``
6064
for ``interp`` in {"bilinear", "nearest"}. Out-of-bounds samples are 0.
6165
"""
62-
assert image.ndim == 3
63-
assert dx.shape == dy.shape
66+
if image.ndim != 3:
67+
raise ValueError(
68+
f"_torch_remap_relative expects a 3D CHW tensor, got shape {tuple(image.shape)}"
69+
)
70+
if dx.shape != dy.shape:
71+
raise ValueError(
72+
f"_torch_remap_relative: dx and dy must match, got {tuple(dx.shape)} vs {tuple(dy.shape)}"
73+
)
6474
_, h, w = image.shape
6575

6676
x_abs = dx + torch.arange(w, device=dx.device, dtype=dx.dtype)
@@ -82,9 +92,16 @@ def _torch_scatter_add_relative(image, dx, dy):
8292
Equivalent to ``rp.torch_scatter_add_image(image, dx, dy, relative=True,
8393
interp='floor')``. Out-of-bounds targets are dropped.
8494
"""
85-
assert image.ndim == 3
95+
if image.ndim != 3:
96+
raise ValueError(
97+
f"_torch_scatter_add_relative expects a 3D CHW tensor, got shape {tuple(image.shape)}"
98+
)
8699
in_c, in_h, in_w = image.shape
87-
assert dx.shape == dy.shape == (in_h, in_w)
100+
if dx.shape != (in_h, in_w) or dy.shape != (in_h, in_w):
101+
raise ValueError(
102+
f"_torch_scatter_add_relative: dx/dy must be ({in_h}, {in_w}), "
103+
f"got dx={tuple(dx.shape)} dy={tuple(dy.shape)}"
104+
)
88105

89106
x = dx.long() + torch.arange(in_w, device=dx.device, dtype=torch.long)
90107
y = dy.long() + torch.arange(in_h, device=dy.device, dtype=torch.long)[:, None]
@@ -185,11 +202,20 @@ def warp_state(state, flow):
185202
``state`` has shape ``(3+c, h, w)`` (= dx, dy, ω, c noise channels).
186203
``flow`` has shape ``(2, h, w)`` (= dx, dy).
187204
"""
188-
assert flow.device == state.device
189-
assert flow.ndim == 3 and flow.shape[0] == 2
190-
assert state.ndim == 3
205+
if flow.device != state.device:
206+
raise ValueError(
207+
f"warp_state: flow and state must be on the same device, "
208+
f"got flow={flow.device} state={state.device}"
209+
)
210+
if state.ndim != 3:
211+
raise ValueError(
212+
f"warp_state: state must be 3D (3+C, H, W), got shape {tuple(state.shape)}"
213+
)
191214
xyoc, h, w = state.shape
192-
assert flow.shape == (2, h, w)
215+
if flow.shape != (2, h, w):
216+
raise ValueError(
217+
f"warp_state: flow must have shape (2, {h}, {w}), got {tuple(flow.shape)}"
218+
)
193219
device = state.device
194220

195221
x_ch, y_ch = 0, 1
@@ -198,8 +224,12 @@ def warp_state(state, flow):
198224
w_ch = 2 # state[w_ch] = ω
199225
c = xyoc - xyw
200226
oc = xyoc - xy
201-
assert c > 0, "state has no noise channels"
202-
assert (state[w_ch] > 0).all(), "all weights must be > 0"
227+
if c <= 0:
228+
raise ValueError(
229+
f"warp_state: state has no noise channels (expected 3+C with C>0, got {xyoc} channels)"
230+
)
231+
if not (state[w_ch] > 0).all():
232+
raise ValueError("warp_state: all weights in state[2] must be > 0")
203233

204234
grid = xy_meshgrid_like_image(state)
205235

@@ -267,7 +297,10 @@ class NoiseWarper:
267297
"""
268298

269299
def __init__(self, c, h, w, device, dtype=torch.float32):
270-
assert c > 0 and h > 0 and w > 0
300+
if c <= 0 or h <= 0 or w <= 0:
301+
raise ValueError(
302+
f"NoiseWarper: c/h/w must all be positive, got c={c} h={h} w={w}"
303+
)
271304
self.c = c
272305
self.h = h
273306
self.w = w
@@ -287,7 +320,10 @@ def noise(self):
287320
return n * weights / (weights ** 2).sqrt()
288321

289322
def __call__(self, dx, dy):
290-
assert dx.shape == dy.shape
323+
if dx.shape != dy.shape:
324+
raise ValueError(
325+
f"NoiseWarper: dx and dy must match, got {tuple(dx.shape)} vs {tuple(dy.shape)}"
326+
)
291327
flow = torch.stack([dx, dy]).to(self.device, self.dtype)
292328
_, oflowh, ofloww = flow.shape
293329

@@ -312,8 +348,6 @@ class RaftOpticalFlow:
312348
"""Torchvision RAFT-large wrapper. ``__call__`` returns a (2, H, W) flow."""
313349

314350
def __init__(self, device=None):
315-
from torchvision.models.optical_flow import raft_large
316-
317351
if device is None:
318352
device = comfy.model_management.get_torch_device()
319353
device = torch.device(device) if not isinstance(device, torch.device) else device
@@ -334,7 +368,11 @@ def _preprocess(self, image_chw):
334368

335369
def __call__(self, from_image, to_image):
336370
"""``from_image``, ``to_image``: CHW float tensors in [0, 1]."""
337-
assert from_image.shape == to_image.shape
371+
if from_image.shape != to_image.shape:
372+
raise ValueError(
373+
f"RaftOpticalFlow: from_image and to_image must match, "
374+
f"got {tuple(from_image.shape)} vs {tuple(to_image.shape)}"
375+
)
338376
_, h, w = from_image.shape
339377
with torch.no_grad():
340378
img1 = self._preprocess(from_image)
@@ -385,9 +423,20 @@ def get_noise_from_video(
385423
Returns:
386424
``(T, H', W', noise_channels)`` float32 noise tensor on ``device``.
387425
"""
388-
assert isinstance(resize_flow, int) and resize_flow >= 1, resize_flow
389-
assert video_frames.ndim == 4 and video_frames.shape[-1] == 3, video_frames.shape
390-
assert video_frames.dtype == torch.uint8, video_frames.dtype
426+
if not isinstance(resize_flow, int) or resize_flow < 1:
427+
raise ValueError(
428+
f"get_noise_from_video: resize_flow must be a positive int, got {resize_flow!r}"
429+
)
430+
if video_frames.ndim != 4 or video_frames.shape[-1] != 3:
431+
raise ValueError(
432+
"get_noise_from_video: video_frames must have shape (T, H, W, 3), "
433+
f"got {tuple(video_frames.shape)}"
434+
)
435+
if video_frames.dtype != torch.uint8:
436+
raise TypeError(
437+
"get_noise_from_video: video_frames must be uint8 in [0, 255], "
438+
f"got dtype {video_frames.dtype}"
439+
)
391440

392441
if device is None:
393442
device = comfy.model_management.get_torch_device()

0 commit comments

Comments
 (0)