1919import torch
2020import torch .nn .functional as F
2121from einops import rearrange
22+ from torchvision .models .optical_flow import raft_large
2223
2324import comfy .model_management
2425
@@ -35,7 +36,10 @@ def _torch_resize_chw(image, size, interp, copy=True):
3536 the requested size matches the input, returns the input tensor as is
3637 (faster but callers must not mutate the result).
3738 """
38- assert image .ndim == 3 , image .shape
39+ if image .ndim != 3 :
40+ raise ValueError (
41+ f"_torch_resize_chw expects a 3D CHW tensor, got shape { tuple (image .shape )} "
42+ )
3943 _ , in_h , in_w = image .shape
4044 if isinstance (size , (int , float )) and not isinstance (size , bool ):
4145 new_h = max (1 , int (in_h * size ))
@@ -59,8 +63,14 @@ def _torch_remap_relative(image, dx, dy, interp="bilinear"):
5963 Equivalent to ``rp.torch_remap_image(image, dx, dy, relative=True, interp=interp)``
6064 for ``interp`` in {"bilinear", "nearest"}. Out-of-bounds samples are 0.
6165 """
62- assert image .ndim == 3
63- assert dx .shape == dy .shape
66+ if image .ndim != 3 :
67+ raise ValueError (
68+ f"_torch_remap_relative expects a 3D CHW tensor, got shape { tuple (image .shape )} "
69+ )
70+ if dx .shape != dy .shape :
71+ raise ValueError (
72+ f"_torch_remap_relative: dx and dy must match, got { tuple (dx .shape )} vs { tuple (dy .shape )} "
73+ )
6474 _ , h , w = image .shape
6575
6676 x_abs = dx + torch .arange (w , device = dx .device , dtype = dx .dtype )
@@ -82,9 +92,16 @@ def _torch_scatter_add_relative(image, dx, dy):
8292 Equivalent to ``rp.torch_scatter_add_image(image, dx, dy, relative=True,
8393 interp='floor')``. Out-of-bounds targets are dropped.
8494 """
85- assert image .ndim == 3
95+ if image .ndim != 3 :
96+ raise ValueError (
97+ f"_torch_scatter_add_relative expects a 3D CHW tensor, got shape { tuple (image .shape )} "
98+ )
8699 in_c , in_h , in_w = image .shape
87- assert dx .shape == dy .shape == (in_h , in_w )
100+ if dx .shape != (in_h , in_w ) or dy .shape != (in_h , in_w ):
101+ raise ValueError (
102+ f"_torch_scatter_add_relative: dx/dy must be ({ in_h } , { in_w } ), "
103+ f"got dx={ tuple (dx .shape )} dy={ tuple (dy .shape )} "
104+ )
88105
89106 x = dx .long () + torch .arange (in_w , device = dx .device , dtype = torch .long )
90107 y = dy .long () + torch .arange (in_h , device = dy .device , dtype = torch .long )[:, None ]
@@ -185,11 +202,20 @@ def warp_state(state, flow):
185202 ``state`` has shape ``(3+c, h, w)`` (= dx, dy, ω, c noise channels).
186203 ``flow`` has shape ``(2, h, w)`` (= dx, dy).
187204 """
188- assert flow .device == state .device
189- assert flow .ndim == 3 and flow .shape [0 ] == 2
190- assert state .ndim == 3
205+ if flow .device != state .device :
206+ raise ValueError (
207+ f"warp_state: flow and state must be on the same device, "
208+ f"got flow={ flow .device } state={ state .device } "
209+ )
210+ if state .ndim != 3 :
211+ raise ValueError (
212+ f"warp_state: state must be 3D (3+C, H, W), got shape { tuple (state .shape )} "
213+ )
191214 xyoc , h , w = state .shape
192- assert flow .shape == (2 , h , w )
215+ if flow .shape != (2 , h , w ):
216+ raise ValueError (
217+ f"warp_state: flow must have shape (2, { h } , { w } ), got { tuple (flow .shape )} "
218+ )
193219 device = state .device
194220
195221 x_ch , y_ch = 0 , 1
@@ -198,8 +224,12 @@ def warp_state(state, flow):
198224 w_ch = 2 # state[w_ch] = ω
199225 c = xyoc - xyw
200226 oc = xyoc - xy
201- assert c > 0 , "state has no noise channels"
202- assert (state [w_ch ] > 0 ).all (), "all weights must be > 0"
227+ if c <= 0 :
228+ raise ValueError (
229+ f"warp_state: state has no noise channels (expected 3+C with C>0, got { xyoc } channels)"
230+ )
231+ if not (state [w_ch ] > 0 ).all ():
232+ raise ValueError ("warp_state: all weights in state[2] must be > 0" )
203233
204234 grid = xy_meshgrid_like_image (state )
205235
@@ -267,7 +297,10 @@ class NoiseWarper:
267297 """
268298
269299 def __init__ (self , c , h , w , device , dtype = torch .float32 ):
270- assert c > 0 and h > 0 and w > 0
300+ if c <= 0 or h <= 0 or w <= 0 :
301+ raise ValueError (
302+ f"NoiseWarper: c/h/w must all be positive, got c={ c } h={ h } w={ w } "
303+ )
271304 self .c = c
272305 self .h = h
273306 self .w = w
@@ -287,7 +320,10 @@ def noise(self):
287320 return n * weights / (weights ** 2 ).sqrt ()
288321
289322 def __call__ (self , dx , dy ):
290- assert dx .shape == dy .shape
323+ if dx .shape != dy .shape :
324+ raise ValueError (
325+ f"NoiseWarper: dx and dy must match, got { tuple (dx .shape )} vs { tuple (dy .shape )} "
326+ )
291327 flow = torch .stack ([dx , dy ]).to (self .device , self .dtype )
292328 _ , oflowh , ofloww = flow .shape
293329
@@ -312,8 +348,6 @@ class RaftOpticalFlow:
312348 """Torchvision RAFT-large wrapper. ``__call__`` returns a (2, H, W) flow."""
313349
314350 def __init__ (self , device = None ):
315- from torchvision .models .optical_flow import raft_large
316-
317351 if device is None :
318352 device = comfy .model_management .get_torch_device ()
319353 device = torch .device (device ) if not isinstance (device , torch .device ) else device
@@ -334,7 +368,11 @@ def _preprocess(self, image_chw):
334368
335369 def __call__ (self , from_image , to_image ):
336370 """``from_image``, ``to_image``: CHW float tensors in [0, 1]."""
337- assert from_image .shape == to_image .shape
371+ if from_image .shape != to_image .shape :
372+ raise ValueError (
373+ f"RaftOpticalFlow: from_image and to_image must match, "
374+ f"got { tuple (from_image .shape )} vs { tuple (to_image .shape )} "
375+ )
338376 _ , h , w = from_image .shape
339377 with torch .no_grad ():
340378 img1 = self ._preprocess (from_image )
@@ -385,9 +423,20 @@ def get_noise_from_video(
385423 Returns:
386424 ``(T, H', W', noise_channels)`` float32 noise tensor on ``device``.
387425 """
388- assert isinstance (resize_flow , int ) and resize_flow >= 1 , resize_flow
389- assert video_frames .ndim == 4 and video_frames .shape [- 1 ] == 3 , video_frames .shape
390- assert video_frames .dtype == torch .uint8 , video_frames .dtype
426+ if not isinstance (resize_flow , int ) or resize_flow < 1 :
427+ raise ValueError (
428+ f"get_noise_from_video: resize_flow must be a positive int, got { resize_flow !r} "
429+ )
430+ if video_frames .ndim != 4 or video_frames .shape [- 1 ] != 3 :
431+ raise ValueError (
432+ "get_noise_from_video: video_frames must have shape (T, H, W, 3), "
433+ f"got { tuple (video_frames .shape )} "
434+ )
435+ if video_frames .dtype != torch .uint8 :
436+ raise TypeError (
437+ "get_noise_from_video: video_frames must be uint8 in [0, 255], "
438+ f"got dtype { video_frames .dtype } "
439+ )
391440
392441 if device is None :
393442 device = comfy .model_management .get_torch_device ()
0 commit comments