@@ -57,9 +57,9 @@ def predict_wrapper(
5757
5858 def generate_wrapper (
5959 self , request : PostModelOutputsRequest ) -> Iterator [service_pb2 .MultiOutputResponse ]:
60- list_dict_input , inference_params = self .parse_input_request (request )
6160 if self .download_request_urls :
6261 ensure_urls_downloaded (request )
62+ list_dict_input , inference_params = self .parse_input_request (request )
6363 outputs = self .generate (list_dict_input , inference_parameters = inference_params )
6464 for output in outputs :
6565 yield self .convert_output_to_proto (output )
@@ -71,13 +71,13 @@ def _preprocess_stream(
7171 input_data , _ = self .parse_input_request (req )
7272 yield input_data
7373
74- def stream_wrapper (self , request : Iterator [PostModelOutputsRequest ]
74+ def stream_wrapper (self , request_iterator : Iterator [PostModelOutputsRequest ]
7575 ) -> Iterator [service_pb2 .MultiOutputResponse ]:
76- first_request = next (request )
77- _ , inference_params = self .parse_input_request (first_request )
78- request_iterator = itertools .chain ([first_request ], request )
7976 if self .download_request_urls :
8077 request_iterator = readahead (map (ensure_urls_downloaded , request_iterator ))
78+ first_request = next (request_iterator )
79+ _ , inference_params = self .parse_input_request (first_request )
80+ request_iterator = itertools .chain ([first_request ], request_iterator )
8181 outputs = self .stream (self ._preprocess_stream (request_iterator ), inference_params )
8282 for output in outputs :
8383 yield self .convert_output_to_proto (output )
0 commit comments