File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -339,6 +339,11 @@ def inference(all_inputs: dict) -> dict:
339339 "message" : 'Called with callInput { train: "dreambooth" } but built with USE_DREAMBOOTH=0' ,
340340 }
341341 }
342+
343+ if RUNTIME_DOWNLOADS :
344+ if os .path .isdir (model_dir ):
345+ normalized_model_id = model_dir
346+
342347 torch .set_grad_enabled (True )
343348 result = result | TrainDreamBooth (
344349 normalized_model_id , pipeline , model_inputs , call_inputs
Original file line number Diff line number Diff line change 1+ from .lib import getMinio , getDDA , AWS_S3_DEFAULT_BUCKET
2+ from test import runTest
3+
4+
5+ def test_training_s3 ():
6+ minio = getMinio ("global" )
7+ dda = getDDA (
8+ minio = minio ,
9+ stream_logs = True ,
10+ )
11+ print (dda )
12+
13+ # fp32 model is obviously bigger
14+ result = runTest (
15+ "dreambooth" ,
16+ {"test_url" : dda .url },
17+ {
18+ "MODEL_ID" : "stabilityai/stable-diffusion-2-1-base" ,
19+ "MODEL_REVISION" : "" ,
20+ "MODEL_PRECISION" : "" ,
21+ "MODEL_URL" : "s3://" ,
22+ "train" : "dreambooth" ,
23+ "dest_url" : f"s3:///{ AWS_S3_DEFAULT_BUCKET } /model.tar.zst" ,
24+ },
25+ {"max_train_steps" : 1 },
26+ )
27+
28+ dda .stop ()
29+ minio .stop ()
30+ timings = result ["$timings" ]
31+ assert timings ["training" ] > 0
32+ assert timings ["upload" ] > 0
33+
34+
35+ def test_inference ():
36+ dda = getDDA (
37+ stream_logs = True ,
38+ root_cache = False ,
39+ )
40+ print (dda )
41+
42+ # fp32 model is obviously bigger
43+ result = runTest (
44+ "txt2img" ,
45+ {"test_url" : dda .url },
46+ {
47+ "MODEL_ID" : "model" ,
48+ "MODEL_PRECISION" : "fp16" ,
49+ "MODEL_URL" : f"s3:///{ AWS_S3_DEFAULT_BUCKET } /model.tar.zst" ,
50+ },
51+ {"num_inference_steps" : 1 },
52+ )
53+
54+ dda .stop ()
55+ assert result ["image_base64" ]
You can’t perform that action at this time.
0 commit comments