Skip to content

Commit ce3827f

Browse files
committed
fix(dreambooth): runtime_dls path fix; integration tests
1 parent bef109f commit ce3827f

2 files changed

Lines changed: 60 additions & 0 deletions

File tree

api/app.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,11 @@ def inference(all_inputs: dict) -> dict:
339339
"message": 'Called with callInput { train: "dreambooth" } but built with USE_DREAMBOOTH=0',
340340
}
341341
}
342+
343+
if RUNTIME_DOWNLOADS:
344+
if os.path.isdir(model_dir):
345+
normalized_model_id = model_dir
346+
342347
torch.set_grad_enabled(True)
343348
result = result | TrainDreamBooth(
344349
normalized_model_id, pipeline, model_inputs, call_inputs
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
from .lib import getMinio, getDDA, AWS_S3_DEFAULT_BUCKET
2+
from test import runTest
3+
4+
5+
def test_training_s3():
6+
minio = getMinio("global")
7+
dda = getDDA(
8+
minio=minio,
9+
stream_logs=True,
10+
)
11+
print(dda)
12+
13+
# fp32 model is obviously bigger
14+
result = runTest(
15+
"dreambooth",
16+
{"test_url": dda.url},
17+
{
18+
"MODEL_ID": "stabilityai/stable-diffusion-2-1-base",
19+
"MODEL_REVISION": "",
20+
"MODEL_PRECISION": "",
21+
"MODEL_URL": "s3://",
22+
"train": "dreambooth",
23+
"dest_url": f"s3:///{AWS_S3_DEFAULT_BUCKET}/model.tar.zst",
24+
},
25+
{"max_train_steps": 1},
26+
)
27+
28+
dda.stop()
29+
minio.stop()
30+
timings = result["$timings"]
31+
assert timings["training"] > 0
32+
assert timings["upload"] > 0
33+
34+
35+
def test_inference():
36+
dda = getDDA(
37+
stream_logs=True,
38+
root_cache=False,
39+
)
40+
print(dda)
41+
42+
# fp32 model is obviously bigger
43+
result = runTest(
44+
"txt2img",
45+
{"test_url": dda.url},
46+
{
47+
"MODEL_ID": "model",
48+
"MODEL_PRECISION": "fp16",
49+
"MODEL_URL": f"s3:///{AWS_S3_DEFAULT_BUCKET}/model.tar.zst",
50+
},
51+
{"num_inference_steps": 1},
52+
)
53+
54+
dda.stop()
55+
assert result["image_base64"]

0 commit comments

Comments
 (0)