-
Notifications
You must be signed in to change notification settings - Fork 124
Expand file tree
/
Copy pathtest_models.py
More file actions
306 lines (213 loc) · 10.8 KB
/
test_models.py
File metadata and controls
306 lines (213 loc) · 10.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
import pathlib
from unittest.mock import patch
import typer.testing
from comfy_cli.command.models.models import app, check_civitai_url, check_huggingface_url, list_models
def _make_model_tree(tmp_path: pathlib.Path) -> pathlib.Path:
"""Create a realistic model directory tree and return its root."""
model_dir = tmp_path / "models"
(model_dir / "root_model.safetensors").parent.mkdir(parents=True, exist_ok=True)
(model_dir / "root_model.safetensors").write_bytes(b"x" * 100)
(model_dir / "checkpoints").mkdir()
(model_dir / "checkpoints" / "sd15.safetensors").write_bytes(b"x" * 200)
(model_dir / "loras" / "SD1.5").mkdir(parents=True)
(model_dir / "loras" / "SD1.5" / "detail.safetensors").write_bytes(b"x" * 300)
(model_dir / "empty_dir").mkdir()
return model_dir
def test_list_models_finds_files_in_subdirectories(tmp_path):
    """Files nested one or more directories below the models root are listed."""
    root = _make_model_tree(tmp_path)
    found = list_models(root)
    assert "sd15.safetensors" in {path.name for path in found}
    nested = [path for path in found if path.name == "detail.safetensors"]
    assert len(nested) == 1
    assert nested[0].relative_to(root) == pathlib.Path("loras", "SD1.5", "detail.safetensors")
def test_list_models_finds_root_level_files(tmp_path):
    """A model stored directly in the models directory is listed as well."""
    listed_names = {entry.name for entry in list_models(_make_model_tree(tmp_path))}
    assert "root_model.safetensors" in listed_names
def test_list_models_returns_empty_for_missing_directory(tmp_path):
    """A models directory that does not exist yields an empty list."""
    missing = tmp_path / "nonexistent"
    assert list_models(missing) == []
def test_list_models_ignores_directories(tmp_path):
    """Only regular files are returned; directories, empty or not, are skipped."""
    entries = list_models(_make_model_tree(tmp_path))
    assert all(entry.is_file() for entry in entries)
    entry_names = {entry.name for entry in entries}
    for dirname in ("empty_dir", "checkpoints"):
        assert dirname not in entry_names
# One in-process Typer runner, shared by the CLI-level tests in this module.
runner: typer.testing.CliRunner = typer.testing.CliRunner()
def test_list_command_shows_type_column(tmp_path):
    """``list`` prints a "Type" header plus the subdirectory and root-level entries."""
    _make_model_tree(tmp_path)
    with patch("comfy_cli.command.models.models.get_workspace", return_value=tmp_path):
        outcome = runner.invoke(app, ["list", "--relative-path", "models"])
    assert outcome.exit_code == 0
    for fragment in ("Type", "checkpoints", "loras/SD1.5", "root_model.safetensors"):
        assert fragment in outcome.output
def test_remove_with_path_traversal_is_rejected(tmp_path):
    """A model name escaping the models directory via ``..`` must be refused.

    ``secret.txt`` lives next to (not inside) ``models/``. Asking ``remove`` to
    delete ``../secret.txt`` must leave it untouched and report an invalid path.
    ``legit.bin`` inside ``models/`` was never requested for removal, so it
    must survive too.
    """
    model_dir = tmp_path / "models"
    model_dir.mkdir()
    legit = model_dir / "legit.bin"
    legit.write_bytes(b"x")
    secret = tmp_path / "secret.txt"
    secret.write_text("sensitive")
    with patch("comfy_cli.command.models.models.get_workspace", return_value=tmp_path):
        result = runner.invoke(
            app,
            ["remove", "--relative-path", "models", "--model-names", "../secret.txt", "--confirm"],
        )
    assert secret.exists()
    # Previously created but never checked: a rejected request must not
    # delete anything else in the models directory either.
    assert legit.exists()
    assert "Invalid model path" in result.output
def test_remove_deletes_model_in_subdirectory(tmp_path):
    """``remove`` accepts a name containing a subdirectory and deletes that file."""
    victim = _make_model_tree(tmp_path) / "checkpoints" / "sd15.safetensors"
    assert victim.exists()
    argv = ["remove", "--relative-path", "models", "--model-names", "checkpoints/sd15.safetensors", "--confirm"]
    with patch("comfy_cli.command.models.models.get_workspace", return_value=tmp_path):
        result = runner.invoke(app, argv)
    assert result.exit_code == 0
    assert not victim.exists()
def test_remove_rejects_directory_name(tmp_path):
    """Naming a directory reports "not found" and leaves that directory in place."""
    models_root = _make_model_tree(tmp_path)
    argv = ["remove", "--relative-path", "models", "--model-names", "checkpoints", "--confirm"]
    with patch("comfy_cli.command.models.models.get_workspace", return_value=tmp_path):
        outcome = runner.invoke(app, argv)
    assert (models_root / "checkpoints").is_dir()
    assert "not found" in outcome.output
def test_remove_deletes_root_level_model(tmp_path):
    """A model stored directly in the models directory is removable by bare filename."""
    victim = _make_model_tree(tmp_path) / "root_model.safetensors"
    assert victim.exists()
    argv = ["remove", "--relative-path", "models", "--model-names", "root_model.safetensors", "--confirm"]
    with patch("comfy_cli.command.models.models.get_workspace", return_value=tmp_path):
        result = runner.invoke(app, argv)
    assert result.exit_code == 0
    assert not victim.exists()
def test_remove_interactive_shows_relative_paths(tmp_path):
    """Interactive ``remove`` offers models as paths relative to the models dir.

    With no ``--model-names``, the command is expected to call
    ``ui.prompt_multi_select`` with the candidate list as its second positional
    argument. The test selects one entry, confirms, and checks that the file
    is deleted.
    """
    _make_model_tree(tmp_path)
    with (
        patch("comfy_cli.command.models.models.get_workspace", return_value=tmp_path),
        patch("comfy_cli.command.models.models.ui") as mock_ui,
    ):
        mock_ui.prompt_multi_select.return_value = ["checkpoints/sd15.safetensors"]
        mock_ui.prompt_confirm_action.return_value = True
        result = runner.invoke(app, ["remove", "--relative-path", "models"])
    # CliRunner catches exceptions raised inside the command by default. Check
    # the exit code first, so a crash is reported with its output rather than
    # surfacing later as an opaque failure on ``call_args`` being None.
    assert result.exit_code == 0, result.output
    choices = mock_ui.prompt_multi_select.call_args[0][1]
    assert "checkpoints/sd15.safetensors" in choices
    assert "loras/SD1.5/detail.safetensors" in choices
    assert "root_model.safetensors" in choices
    assert not (tmp_path / "models" / "checkpoints" / "sd15.safetensors").exists()
def test_valid_model_url():
    """A bare ``/models/<id>`` page is a model URL with id 43331 and no version."""
    expected = (True, False, 43331, None)
    assert check_civitai_url("https://civitai.com/models/43331") == expected
def test_valid_model_url_with_version():
    """A model URL followed by a name slug still parses.

    Despite the test name, the trailing segment is a slug rather than a
    version, so the version slot stays ``None``.
    """
    expected = (True, False, 43331, None)
    assert check_civitai_url("https://civitai.com/models/43331/majicmix-realistic") == expected
def test_valid_model_url_with_version_and_additional_segments():
    """Extra path segments after the slug do not affect the parsed model id."""
    expected = (True, False, 43331, None)
    assert check_civitai_url("https://civitai.com/models/43331/majicmix-realistic/extra") == expected
def test_valid_model_url_with_query():
    """A ``version`` query parameter supplies the version id."""
    expected = (True, False, 43331, 12345)
    assert check_civitai_url("https://civitai.com/models/43331?version=12345") == expected
def test_valid_api_url():
    """An ``/api/download/models/<id>`` URL is an API URL; its number fills the version slot."""
    expected = (False, True, None, 67890)
    assert check_civitai_url("https://civitai.com/api/download/models/67890") == expected
def test_invalid_url():
    """A non-Civitai host is rejected even when its path looks like a model page."""
    expected = (False, False, None, None)
    assert check_civitai_url("https://example.com/models/43331") == expected
def test_malformed_url():
    """``/models/`` without an id is not recognised."""
    expected = (False, False, None, None)
    assert check_civitai_url("https://civitai.com/models/") == expected
def test_invalid_model_id_url():
    """A non-numeric model id is not recognised."""
    expected = (False, False, None, None)
    assert check_civitai_url("https://civitai.com/models/invalid_id") == expected
def test_malformed_query_url():
    """An empty ``version`` value is ignored while the model id still parses."""
    expected = (True, False, 43331, None)
    assert check_civitai_url("https://civitai.com/models/43331?version=") == expected
def test_model_url_with_model_version_id_query():
    """``modelVersionId`` is accepted as the version query parameter."""
    expected = (True, False, 43331, 485088)
    assert check_civitai_url("https://civitai.com/models/43331?modelVersionId=485088") == expected
def test_model_url_with_model_version_id_invalid():
    """A non-numeric ``modelVersionId`` leaves the version slot as ``None``."""
    expected = (True, False, 43331, None)
    assert check_civitai_url("https://civitai.com/models/43331?modelVersionId=abc") == expected
def test_valid_api_v1_model_versions_url():
    """The ``/api/v1/model-versions/<id>`` endpoint is recognised as an API URL."""
    expected = (False, True, None, 1617665)
    assert check_civitai_url("https://civitai.com/api/v1/model-versions/1617665") == expected
def test_valid_api_v1_model_versions_camelcase_segment():
    """The camel-case ``modelVersions`` spelling of the endpoint is accepted too."""
    expected = (False, True, None, 1617665)
    assert check_civitai_url("https://civitai.com/api/v1/modelVersions/1617665") == expected
def test_valid_api_download_with_query_params():
    """Query parameters on an API download URL do not disturb the version id."""
    expected = (False, True, None, 1617665)
    assert check_civitai_url("https://civitai.com/api/download/models/1617665?type=Model&format=SafeTensor") == expected
def test_api_download_trailing_slash_is_ok():
    """A trailing slash after the version id is tolerated."""
    expected = (False, True, None, 1617665)
    assert check_civitai_url("https://civitai.com/api/download/models/1617665/") == expected
def test_api_download_non_numeric_id_models_version():
    """A ``modelVersions`` URL with a non-numeric id is still an API URL, but has no version."""
    expected = (False, True, None, None)
    assert check_civitai_url("https://civitai.com/api/v1/modelVersions/notanumber") == expected
def test_api_download_non_numeric_id():
    """A download URL with a non-numeric id is still an API URL, but has no version."""
    expected = (False, True, None, None)
    assert check_civitai_url("https://civitai.com/api/download/models/notanumber") == expected
def test_model_url_with_slug_and_query():
    """A name slug and a ``modelVersionId`` query can be combined."""
    expected = (True, False, 43331, 485088)
    assert check_civitai_url("https://civitai.com/models/43331/majicmix-realistic?modelVersionId=485088") == expected
def test_www_subdomain_is_accepted():
    """The ``www.`` subdomain is treated like the bare domain."""
    expected = (True, False, 43331, 12345)
    assert check_civitai_url("https://www.civitai.com/models/43331?version=12345") == expected
def test_completly_mailformed_civitai_url():
    """The site root with no path is neither a model URL nor an API URL."""
    expected = (False, False, None, None)
    assert check_civitai_url("https://civitai.com/") == expected
def test_non_evil_civitai_url():
    """A look-alike host whose name merely ends in ``civitai.com`` is rejected."""
    expected = (False, False, None, None)
    assert check_civitai_url("https://evilcivitai.com/models/43331?version=12345") == expected
def test_valid_huggingface_url():
    """A ``resolve`` download link yields repo id, filename and branch."""
    link = "https://huggingface.co/CompVis/stable-diffusion-v1-4/resolve/main/sd-v1-4.ckpt"
    expected = (True, "CompVis/stable-diffusion-v1-4", "sd-v1-4.ckpt", None, "main")
    assert check_huggingface_url(link) == expected
def test_valid_huggingface_url_sd_audio():
    """``blob`` links are accepted in addition to ``resolve`` links."""
    link = "https://huggingface.co/stabilityai/stable-audio-open-1.0/blob/main/model.safetensors"
    expected = (True, "stabilityai/stable-audio-open-1.0", "model.safetensors", None, "main")
    assert check_huggingface_url(link) == expected
def test_valid_huggingface_url_with_folder():
    """A root-level checkpoint in a different repository parses correctly.

    Despite the test name, the file sits directly under the branch, so the
    fourth element of the result is ``None``.
    """
    link = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt"
    expected = (True, "runwayml/stable-diffusion-v1-5", "v1-5-pruned-emaonly.ckpt", None, "main")
    assert check_huggingface_url(link) == expected
def test_valid_huggingface_url_with_subfolder():
    """Files inside a repository subfolder report that folder separately.

    The first URL keeps the original root-level case, where the fourth element
    is ``None``. The second places the file under ``vae/``, so the folder
    component the test name promises is actually exercised. It is expected to
    appear as the fourth element, with only the basename as the filename.
    """
    url = "https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.ckpt"
    assert check_huggingface_url(url) == (
        True,
        "stabilityai/stable-diffusion-2-1",
        "v2-1_768-ema-pruned.ckpt",
        None,
        "main",
    )
    nested_url = (
        "https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/vae/diffusion_pytorch_model.safetensors"
    )
    assert check_huggingface_url(nested_url) == (
        True,
        "stabilityai/stable-diffusion-2-1",
        "diffusion_pytorch_model.safetensors",
        "vae",
        "main",
    )
def test_valid_huggingface_url_with_encoded_filename():
    """Percent-encoded characters in the filename come back decoded."""
    link = "https://huggingface.co/CompVis/stable-diffusion-v1-4/resolve/main/sd-v1-4%20(1).ckpt"
    expected = (True, "CompVis/stable-diffusion-v1-4", "sd-v1-4 (1).ckpt", None, "main")
    assert check_huggingface_url(link) == expected
def test_invalid_huggingface_url():
    """A non-Hugging-Face host is rejected even with a valid-looking path."""
    link = "https://example.com/CompVis/stable-diffusion-v1-4/resolve/main/sd-v1-4.ckpt"
    expected = (False, None, None, None, None)
    assert check_huggingface_url(link) == expected
def test_invalid_huggingface_url_structure():
    """A path missing the ``resolve``/``blob`` segment is rejected."""
    link = "https://huggingface.co/CompVis/stable-diffusion-v1-4/main/sd-v1-4.ckpt"
    expected = (False, None, None, None, None)
    assert check_huggingface_url(link) == expected
def test_huggingface_url_with_com_domain():
    """``huggingface.com`` is accepted as an alias of ``huggingface.co``."""
    link = "https://huggingface.com/CompVis/stable-diffusion-v1-4/resolve/main/sd-v1-4.ckpt"
    expected = (True, "CompVis/stable-diffusion-v1-4", "sd-v1-4.ckpt", None, "main")
    assert check_huggingface_url(link) == expected
def test_huggingface_url_with_folder_structure():
    """Dots in the repo and file names are preserved verbatim.

    Despite the test name, the file sits directly under the branch, so the
    fourth element of the result is ``None``.
    """
    link = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors"
    expected = (
        True,
        "stabilityai/stable-diffusion-xl-base-1.0",
        "sd_xl_base_1.0.safetensors",
        None,
        "main",
    )
    assert check_huggingface_url(link) == expected