Skip to content

Commit 620e370

Browse files
committed
chore(tests): fix dreambooth tests
1 parent ce3827f commit 620e370

File tree

1 file changed

+64
-52
lines changed

1 file changed

+64
-52
lines changed
Lines changed: 64 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,67 @@
1-
from .lib import getMinio, getDDA, AWS_S3_DEFAULT_BUCKET
1+
from .lib import getMinio, getDDA
22
from test import runTest
33

44

5-
def test_training_s3():
6-
minio = getMinio("global")
7-
dda = getDDA(
8-
minio=minio,
9-
stream_logs=True,
10-
)
11-
print(dda)
12-
13-
# fp32 model is obviously bigger
14-
result = runTest(
15-
"dreambooth",
16-
{"test_url": dda.url},
17-
{
18-
"MODEL_ID": "stabilityai/stable-diffusion-2-1-base",
19-
"MODEL_REVISION": "",
20-
"MODEL_PRECISION": "",
21-
"MODEL_URL": "s3://",
22-
"train": "dreambooth",
23-
"dest_url": f"s3:///{AWS_S3_DEFAULT_BUCKET}/model.tar.zst",
24-
},
25-
{"max_train_steps": 1},
26-
)
27-
28-
dda.stop()
29-
minio.stop()
30-
timings = result["$timings"]
31-
assert timings["training"] > 0
32-
assert timings["upload"] > 0
33-
34-
35-
def test_inference():
36-
dda = getDDA(
37-
stream_logs=True,
38-
root_cache=False,
39-
)
40-
print(dda)
41-
42-
# fp32 model is obviously bigger
43-
result = runTest(
44-
"txt2img",
45-
{"test_url": dda.url},
46-
{
47-
"MODEL_ID": "model",
48-
"MODEL_PRECISION": "fp16",
49-
"MODEL_URL": f"s3:///{AWS_S3_DEFAULT_BUCKET}/model.tar.zst",
50-
},
51-
{"num_inference_steps": 1},
52-
)
53-
54-
dda.stop()
55-
assert result["image_base64"]
5+
class TestDreamBoothS3:
6+
"""
7+
Train/Infer via S3 model save.
8+
"""
9+
10+
def setup_class(self):
11+
print("setup_class")
12+
self.minio = getMinio("global")
13+
14+
def teardown_class(self):
15+
print("teardown_class")
16+
self.minio.stop()
17+
18+
def test_training_s3(self):
19+
dda = getDDA(
20+
minio=self.minio,
21+
stream_logs=True,
22+
)
23+
print(dda)
24+
25+
# fp32 model is obviously bigger
26+
result = runTest(
27+
"dreambooth",
28+
{"test_url": dda.url},
29+
{
30+
"MODEL_ID": "stabilityai/stable-diffusion-2-1-base",
31+
"MODEL_REVISION": "",
32+
"MODEL_PRECISION": "",
33+
"MODEL_URL": "s3://",
34+
"train": "dreambooth",
35+
"dest_url": f"s3:///{self.minio.aws_s3_default_bucket}/model.tar.zst",
36+
},
37+
{"max_train_steps": 1},
38+
)
39+
40+
dda.stop()
41+
timings = result["$timings"]
42+
assert timings["training"] > 0
43+
assert timings["upload"] > 0
44+
45+
# dependent on above, TODO, mark as such.
46+
def test_s3_download_and_inference(self):
47+
dda = getDDA(
48+
minio=self.minio,
49+
stream_logs=True,
50+
root_cache=False,
51+
)
52+
print(dda)
53+
54+
# fp32 model is obviously bigger
55+
result = runTest(
56+
"txt2img",
57+
{"test_url": dda.url},
58+
{
59+
"MODEL_ID": "model",
60+
"MODEL_PRECISION": "fp16",
61+
"MODEL_URL": f"s3:///{self.minio.aws_s3_default_bucket}/model.tar.zst",
62+
},
63+
{"num_inference_steps": 1},
64+
)
65+
66+
dda.stop()
67+
assert result["image_base64"]

0 commit comments

Comments
 (0)