44import nucleus
55import os
66
7+ from itertools import zip_longest
8+
79import time
810
911
2123 "API Key to use. Defaults to NUCLEUS_PYTEST_API_KEY environment variable" ,
2224)
2325
# lower_bound=1: the upload helpers divide item counts by this value to derive
# their chunk sizes, so zero would raise ZeroDivisionError.
flags.DEFINE_integer(
    "job_parallelism",
    8,
    "Amount of concurrent jobs to use.",
    lower_bound=1,
)
27+
2428# Dataset upload flags
2529flags .DEFINE_enum (
2630 "create_or_reuse_dataset" ,
3539)
3640flags .DEFINE_integer (
3741 "num_dataset_items" ,
38- 100000 ,
42+ 10000000 ,
3943 "Number of dataset items to create if creating a dataset" ,
4044 lower_bound = 0 ,
4145)
4246flags .DEFINE_bool (
43- "cleanup_dataset" , True , "Whether to delete the dataset after the test."
47+ "cleanup_dataset" , False , "Whether to delete the dataset after the test."
4448)
4549
4650# Annotation upload flags
5458# Prediction upload flags
# Number of predictions generated for every dataset item; 0 is allowed.
flags.DEFINE_integer(
    "num_predictions_per_dataset_item",
    1,
    "Number of predictions per dataset item",
    lower_bound=0,
)
6165
66+ TIMINGS = {}
67+
68+
def chunk(iterable, chunk_size, fillvalue=None):
    """Yield successive lists of up to ``chunk_size`` items from ``iterable``.

    The input is consumed lazily, one chunk at a time, so generators are
    never fully materialized.

    Args:
        iterable: Any iterable; items (including ``None``) are passed through
            unchanged and in order.
        chunk_size: Maximum number of items per chunk. Values below 1 are
            treated as 1 so that a too-small size never silently drops the
            whole input.
        fillvalue: When not ``None``, the final short chunk is padded with
            this value up to ``chunk_size``. With the default ``None`` the
            final chunk is simply shorter.

    Yields:
        Lists of items. Lists (rather than one-shot iterators) can be
        iterated repeatedly and measured with ``len()``.
    """
    # Local import keeps this helper self-contained.
    from itertools import islice

    size = max(chunk_size, 1)
    source = iter(iterable)
    while True:
        block = list(islice(source, size))
        if not block:
            return
        missing = size - len(block)
        if fillvalue is not None and missing:
            block.extend([fillvalue] * missing)
        yield block
75+
6276
def client():
    """Construct a new ``NucleusClient`` using the ``api_key`` flag value."""
    api_key = FLAGS.api_key
    return nucleus.NucleusClient(api_key=api_key)
@@ -126,15 +140,23 @@ def create_or_get_dataset():
126140 dataset = client ().create_dataset ("Privacy Mode Load Test Dataset" )
127141 print ("Starting dataset item upload" )
128142 tic = time .time ()
129- job = dataset .append (
130- dataset_item_generator (), update = True , asynchronous = True
131- )
132- try :
133- job .sleep_until_complete (False )
134- except JobError :
135- print (job .errors ())
143+ chunk_size = FLAGS .num_dataset_items // FLAGS .job_parallelism
144+ jobs = []
145+ for dataset_item_chunk in chunk (dataset_item_generator (), chunk_size ):
146+ jobs .append (
147+ dataset .append (
148+ dataset_item_chunk , update = True , asynchronous = True
149+ )
150+ )
151+
152+ for job in jobs :
153+ try :
154+ job .sleep_until_complete (False )
155+ except JobError :
156+ print (job .errors ())
136157 toc = time .time ()
137158 print ("Finished dataset item upload: %s" % (toc - tic ))
159+ TIMINGS [f"Dataset Item Upload { FLAGS .num_dataset_items } " ] = toc - tic
138160 else :
139161 print (f"Reusing dataset { FLAGS .dataset_id } " )
140162 dataset = client ().get_dataset (FLAGS .dataset_id )
@@ -144,15 +166,26 @@ def create_or_get_dataset():
def upload_annotations(dataset: Dataset):
    """Upload the generated annotations to ``dataset`` as concurrent async jobs.

    The annotation stream is split into at most ``FLAGS.job_parallelism``
    chunks. Each chunk is submitted with ``dataset.annotate(...,
    asynchronous=True)`` and every job is then awaited. Errors from a failed
    job are printed rather than raised. The elapsed wall-clock time is printed
    and stored in ``TIMINGS``.

    Args:
        dataset: Dataset that receives the annotations.
    """
    print("Starting annotation upload")
    tic = time.time()
    num_annotations = (
        FLAGS.num_dataset_items * FLAGS.num_annotations_per_dataset_item
    )
    # Ceiling division keeps the job count within job_parallelism, and the
    # clamp to 1 avoids a zero chunk size (which would upload nothing) when
    # there are fewer annotations than jobs.
    parallelism = max(FLAGS.job_parallelism, 1)
    chunk_size = max(-(-num_annotations // parallelism), 1)

    # Submit every chunk first so the jobs run concurrently on the server.
    jobs = [
        dataset.annotate(
            list(annotation_chunk), update=False, asynchronous=True
        )
        for annotation_chunk in chunk(annotation_generator(), chunk_size)
    ]

    for job in jobs:
        try:
            job.sleep_until_complete(False)
        except JobError:
            print(job.errors())
    toc = time.time()
    print("Finished annotation upload: %s" % (toc - tic))
    TIMINGS[f"Annotation Upload {num_annotations}"] = toc - tic
156189
157190
158191def upload_predictions (dataset : Dataset ):
@@ -167,16 +200,24 @@ def upload_predictions(dataset: Dataset):
167200
168201 print ("Starting prediction upload" )
169202
170- job = run . predict (
171- list ( prediction_generator ()), update = True , asynchronous = True
203+ num_predictions = (
204+ FLAGS . num_dataset_items * FLAGS . num_predictions_per_dataset_item
172205 )
206+ chunk_size = num_predictions // FLAGS .job_parallelism
207+ jobs = []
208+ for prediction_chunk in chunk (prediction_generator (), chunk_size ):
209+ jobs .append (
210+ run .predict (list (prediction_chunk ), update = True , asynchronous = True )
211+ )
173212
174- try :
175- job .sleep_until_complete (False )
176- except JobError :
177- print (job .errors ())
213+ for job in jobs :
214+ try :
215+ job .sleep_until_complete (False )
216+ except JobError :
217+ print (job .errors ())
178218 toc = time .time ()
179219 print ("Finished prediction upload: %s" % (toc - tic ))
220+ TIMINGS [f"Prediction Upload { num_predictions } " ] = toc - tic
180221
181222
182223def main (unused_argv ):
@@ -194,6 +235,8 @@ def main(unused_argv):
194235 if FLAGS .cleanup_dataset and FLAGS .create_or_reuse_dataset == "create" :
195236 client ().delete_dataset (dataset .id )
196237
238+ print (TIMINGS )
239+
197240
198241if __name__ == "__main__" :
199242 app .run (main )
0 commit comments