Skip to content

Commit 07552db

Browse files
committed
Add SeedVR2 integration coverage
1 parent 764d7aa commit 07552db

2 files changed

Lines changed: 229 additions & 3 deletions

File tree

.github/workflows/test-unit.yml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,16 @@ name: Unit Tests
22

33
on:
44
push:
5-
branches: [ main, master, release/** ]
5+
branches: [ main, master, develop, release/** ]
66
pull_request:
7-
branches: [ main, master, release/** ]
7+
branches: [ main, master, develop, release/** ]
88

99
jobs:
1010
test:
1111
strategy:
1212
matrix:
1313
os: [ubuntu-latest, windows-2022, macos-latest]
1414
runs-on: ${{ matrix.os }}
15-
continue-on-error: true
1615
steps:
1716
- uses: actions/checkout@v4
1817
- name: Set up Python
Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
import torch
2+
3+
from comfy.cli_args import args as cli_args
4+
5+
if not torch.cuda.is_available():
6+
cli_args.cpu = True
7+
8+
import comfy_extras.nodes_seedvr as nodes_seedvr
9+
import nodes
10+
11+
12+
def test_seedvr2_postprocessing_restores_flat_decoded_batch_time():
13+
decoded = torch.arange(6 * 4 * 6 * 1, dtype=torch.float32).reshape(6, 4, 6, 1)
14+
original = torch.ones((2, 3, 4, 6, 1), dtype=torch.float32)
15+
16+
output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 4, "none").result[0]
17+
18+
assert output.shape == (6, 4, 6, 1)
19+
torch.testing.assert_close(output, decoded)
20+
21+
22+
def test_seedvr2_postprocessing_crops_to_resized_original_size():
23+
decoded = torch.ones((1, 128, 176, 3), dtype=torch.float32)
24+
original = torch.full((1, 1, 120, 169, 3), 0.25, dtype=torch.float32)
25+
26+
output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 120, "none").result[0]
27+
28+
assert output.shape == (1, 120, 168, 3)
29+
30+
31+
def test_seedvr2_postprocessing_uses_decoded_size_when_resized_original_is_larger():
32+
decoded = torch.ones((1, 128, 160, 3), dtype=torch.float32)
33+
original = torch.full((1, 1, 480, 640, 3), 0.25, dtype=torch.float32)
34+
35+
output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 480, "none").result[0]
36+
37+
assert output.shape == (1, 128, 160, 3)
38+
39+
40+
def test_seedvr2_postprocessing_does_not_trim_real_black_original_edges():
41+
decoded = torch.ones((1, 128, 176, 3), dtype=torch.float32)
42+
original = torch.zeros((1, 1, 128, 176, 3), dtype=torch.float32)
43+
44+
output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 128, "none").result[0]
45+
46+
assert output.shape == (1, 128, 176, 3)
47+
48+
49+
def test_seedvr2_postprocessing_crops_height_only_to_resized_original_size():
50+
decoded = torch.ones((1, 128, 176, 3), dtype=torch.float32)
51+
original = torch.full((1, 1, 120, 176, 3), 0.25, dtype=torch.float32)
52+
53+
output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 120, "none").result[0]
54+
55+
assert output.shape == (1, 120, 176, 3)
56+
57+
58+
def test_seedvr2_postprocessing_lab_uses_resized_original_size(monkeypatch):
59+
decoded = torch.ones((1, 128, 176, 3), dtype=torch.float32)
60+
original = torch.full((1, 1, 120, 169, 3), 0.25, dtype=torch.float32)
61+
calls = []
62+
63+
def fake_lab_color_transfer(decoded_flat, reference_flat):
64+
calls.append((tuple(decoded_flat.shape), tuple(reference_flat.shape)))
65+
return decoded_flat
66+
67+
monkeypatch.setattr(nodes_seedvr, "lab_color_transfer", fake_lab_color_transfer)
68+
69+
output = nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 120, "lab").result[0]
70+
71+
assert calls == [((1, 3, 120, 169), (1, 3, 120, 169))]
72+
assert output.shape == (1, 120, 168, 3)
73+
74+
75+
def test_seedvr2_tiled_decode_node_ignores_seedvr2_sideband_metadata():
76+
class FakeVAE:
77+
def __init__(self):
78+
self.decode_call = None
79+
80+
def temporal_compression_decode(self):
81+
return 4
82+
83+
def spacial_compression_decode(self):
84+
return 8
85+
86+
def decode_tiled(self, samples, **kwargs):
87+
self.decode_call = kwargs
88+
return torch.zeros((1, 1, 2, 2, 3), dtype=torch.float32)
89+
90+
vae = FakeVAE()
91+
samples = {
92+
"samples": torch.zeros((1, 16, 4, 4, 16), dtype=torch.float32),
93+
"seedvr2_channel_last": True,
94+
}
95+
96+
nodes.VAEDecodeTiled().decode(
97+
vae,
98+
samples,
99+
tile_size=64,
100+
overlap=0,
101+
temporal_size=64,
102+
temporal_overlap=8,
103+
)
104+
105+
assert "seedvr2_channel_last" not in vae.decode_call
106+
107+
108+
def test_seedvr2_decode_node_ignores_seedvr2_sideband_metadata():
109+
class FakeVAE:
110+
def __init__(self):
111+
self.decode_call = None
112+
113+
def decode(self, samples, **kwargs):
114+
self.decode_call = kwargs
115+
return torch.zeros((1, 1, 2, 2, 3), dtype=torch.float32)
116+
117+
vae = FakeVAE()
118+
samples = {
119+
"samples": torch.zeros((1, 16, 4, 4, 16), dtype=torch.float32),
120+
"seedvr2_channel_last": True,
121+
}
122+
123+
nodes.VAEDecode().decode(vae, samples)
124+
125+
assert "seedvr2_channel_last" not in vae.decode_call
126+
127+
128+
def test_seedvr2_decode_node_leaves_unmarked_ambiguous_latent_unforced():
129+
class FakeVAE:
130+
def __init__(self):
131+
self.decode_call = None
132+
133+
def decode(self, samples, **kwargs):
134+
self.decode_call = kwargs
135+
return torch.zeros((1, 1, 2, 2, 3), dtype=torch.float32)
136+
137+
vae = FakeVAE()
138+
samples = {"samples": torch.zeros((1, 16, 4, 4, 16), dtype=torch.float32)}
139+
140+
nodes.VAEDecode().decode(vae, samples)
141+
142+
assert "seedvr2_channel_last" not in vae.decode_call
143+
144+
145+
def test_seedvr2_encode_node_does_not_mark_model_specific_layout_metadata():
146+
class FakeVAE:
147+
def encode(self, pixels):
148+
return torch.zeros((1, 16, 2, 3, 4), dtype=torch.float32)
149+
150+
output = nodes.VAEEncode().encode(FakeVAE(), torch.zeros((1, 8, 8, 3)))[0]
151+
152+
assert set(output) == {"samples"}
153+
154+
155+
def test_seedvr2_tiled_encode_node_does_not_mark_model_specific_layout_metadata():
156+
class FakeVAE:
157+
def encode_tiled(self, pixels, **kwargs):
158+
return torch.zeros((1, 16, 2, 3, 4), dtype=torch.float32)
159+
160+
output = nodes.VAEEncodeTiled().encode(FakeVAE(), torch.zeros((1, 8, 8, 3)), 64, 0)[0]
161+
162+
assert set(output) == {"samples"}
163+
164+
165+
def test_seedvr2_saved_latent_does_not_persist_model_specific_layout_metadata(monkeypatch):
166+
saved = {}
167+
168+
def fake_save_image_path(filename_prefix, output_dir):
169+
return output_dir, filename_prefix, 1, "", filename_prefix
170+
171+
def fake_save_torch_file(output, file, metadata=None):
172+
saved.update(output)
173+
174+
monkeypatch.setattr(nodes.folder_paths, "get_save_image_path", fake_save_image_path)
175+
monkeypatch.setattr(nodes.comfy.utils, "save_torch_file", fake_save_torch_file)
176+
monkeypatch.setattr(nodes.folder_paths, "get_annotated_filepath", lambda latent: latent)
177+
monkeypatch.setattr(nodes.safetensors.torch, "load_file", lambda latent_path, device="cpu": saved)
178+
179+
original = torch.zeros((1, 16, 4, 4, 16), dtype=torch.float32)
180+
nodes.SaveLatent().save({"samples": original, "seedvr2_channel_last": True}, "seedvr2_latent")
181+
loaded = nodes.LoadLatent().load("seedvr2_latent")[0]
182+
183+
assert "seedvr2_channel_last" not in saved
184+
assert "seedvr2_channel_last" not in loaded
185+
torch.testing.assert_close(loaded["samples"], original)
186+
187+
188+
def test_seedvr2_tiled_decode_node_preserves_legacy_decode_tiled_signature():
189+
class FakeVAE:
190+
def __init__(self):
191+
self.decode_call = None
192+
193+
def temporal_compression_decode(self):
194+
return 4
195+
196+
def spacial_compression_decode(self):
197+
return 8
198+
199+
def decode_tiled(self, samples, tile_x, tile_y, overlap, tile_t, overlap_t):
200+
self.decode_call = {
201+
"tile_x": tile_x,
202+
"tile_y": tile_y,
203+
"overlap": overlap,
204+
"tile_t": tile_t,
205+
"overlap_t": overlap_t,
206+
}
207+
return torch.zeros((1, 1, 2, 2, 3), dtype=torch.float32)
208+
209+
vae = FakeVAE()
210+
samples = {"samples": torch.zeros((1, 16, 4, 4, 16), dtype=torch.float32)}
211+
212+
nodes.VAEDecodeTiled().decode(
213+
vae,
214+
samples,
215+
tile_size=64,
216+
overlap=0,
217+
temporal_size=64,
218+
temporal_overlap=8,
219+
)
220+
221+
assert vae.decode_call == {
222+
"tile_x": 8,
223+
"tile_y": 8,
224+
"overlap": 0,
225+
"tile_t": 16,
226+
"overlap_t": 2,
227+
}

0 commit comments

Comments
 (0)