@@ -25,7 +25,7 @@ class Client extends RemoteClient
2525
2626 private const QUEUE_JOIN = 'queue/join ' ;
2727
28- private const SSE_GET_DATA = 'queue/data ' ;
28+ private const SSE_QUEUE_DATA = 'queue/data ' ;
2929
3030 private const HTTP_CONFIG = 'config ' ;
3131
@@ -62,7 +62,7 @@ public function getConfig(): Config
6262 return $ this ->config ;
6363 }
6464
65- public function predict (array $ arguments , ?string $ apiName = null , ?int $ fnIndex = null , bool $ raw = false ): Output |array |null
65+ public function predict (array $ arguments , ?string $ apiName = null , ?int $ fnIndex = null , bool $ raw = false , ? int $ triggerId = null ): Output |array |null
6666 {
6767 if ($ apiName === null && $ fnIndex === null ) {
6868 throw new InvalidArgumentException ('You must provide an apiName or fnIndex ' );
@@ -75,10 +75,10 @@ public function predict(array $arguments, ?string $apiName = null, ?int $fnIndex
7575 throw new InvalidArgumentException ('Endpoint not found ' );
7676 }
7777
78- return $ this ->submit ($ endpoint , $ arguments , $ raw );
78+ return $ this ->submit ($ endpoint , $ arguments , $ raw, $ triggerId );
7979 }
8080
81- protected function submit (Endpoint $ endpoint , array $ arguments , bool $ raw ): Output |array |null
81+ protected function submit (Endpoint $ endpoint , array $ arguments , bool $ raw, ? int $ triggerId = null ): Output |array |null
8282 {
8383 $ payload = $ this ->preparePayload ($ arguments );
8484 $ this ->fireEvent (Event::SUBMIT , $ payload );
@@ -88,12 +88,15 @@ protected function submit(Endpoint $endpoint, array $arguments, bool $raw): Outp
8888 'data ' => $ payload ,
8989 'fn_index ' => $ endpoint ->index ,
9090 'session_hash ' => $ this ->sessionHash ,
91+ 'trigger_id ' => $ triggerId ,
92+ 'event_data ' => null ,
9193 ], dto: $ raw ? null : Output::class);
9294 }
9395
9496 return match ($ this ->config ->protocol ) {
95- 'sse_v1 ' , 'sse_v2 ' => $ this ->sseV1V2Loop ($ endpoint , $ payload ),
96- default => $ this ->websocketLoop ($ endpoint , $ payload ),
97+ 'sse ' , 'sse_v1 ' , 'sse_v2 ' , 'sse_v2.1 ' , 'sse_v3 ' => $ this ->sseLoop ($ endpoint , $ payload , $ this ->config ->protocol , $ triggerId ),
98+ 'ws ' => $ this ->websocketLoop ($ endpoint , $ payload ),
99+ default => throw new GradioException ('Unknown protocol ' .$ this ->config ->protocol ),
97100 };
98101 }
99102
@@ -185,26 +188,34 @@ private function websocketLoop(Endpoint $endpoint, array $payload): ?Output
185188 return $ message ?->output;
186189 }
187190
188- private function sseV1V2Loop (Endpoint $ endpoint , array $ payload ): ?Output
191+ private function sseLoop (Endpoint $ endpoint , array $ payload, string $ protocol , ? int $ triggerId ): ?Output
189192 {
190- $ response = $ this ->httpRaw ('post ' , self ::QUEUE_JOIN , [
191- 'data ' => $ payload ,
192- 'fn_index ' => $ endpoint ->index ,
193- 'session_hash ' => $ this ->sessionHash ,
194- ]);
193+ if ($ protocol === 'sse ' ) {
194+ $ getEndpoint = self ::QUEUE_JOIN ;
195+ } else {
196+ $ getEndpoint = self ::SSE_QUEUE_DATA ;
197+ $ response = $ this ->httpRaw ('post ' , self ::QUEUE_JOIN , [
198+ 'data ' => $ payload ,
199+ 'fn_index ' => $ endpoint ->index ,
200+ 'session_hash ' => $ this ->sessionHash ,
201+ ]);
195202
196- if ($ response ->getStatusCode () === 503 ) {
197- throw new QueueFullException ();
198- }
199203
200- if ($ response ->getStatusCode () !== 200 ) {
201- throw new GradioException ('Error joining the queue ' );
204+ if ($ response ->getStatusCode () === 503 ) {
205+ throw new QueueFullException ();
206+ }
207+
208+ if ($ response ->getStatusCode () !== 200 ) {
209+ throw new GradioException ('Error joining the queue ' );
210+ }
202211 }
203212
204- // $data = $this->decodeResponse($response);
205- // $eventId = $data['event_id'];
213+ $ params = ['session_hash ' => $ this ->sessionHash ];
214+ if ($ protocol === 'sse ' ) {
215+ $ params ['fn_index ' ] = $ endpoint ->index ;
216+ }
206217
207- $ response = $ this ->httpRaw ('get ' , self :: SSE_GET_DATA , [ ' session_hash ' => $ this -> sessionHash ] , [
218+ $ response = $ this ->httpRaw ('get ' , $ getEndpoint , $ params , [
208219 'headers ' => [
209220 'Accept ' => 'text/event-stream ' ,
210221 ],
@@ -213,7 +224,7 @@ private function sseV1V2Loop(Endpoint $endpoint, array $payload): ?Output
213224
214225 $ buffer = '' ;
215226 $ message = null ;
216- while (! $ response ->getBody ()->eof ()) {
227+ while (!$ response ->getBody ()->eof ()) {
217228 $ data = $ response ->getBody ()->read (1 );
218229 if ($ data !== "\n" ) {
219230 $ buffer .= $ data ;
@@ -228,7 +239,27 @@ private function sseV1V2Loop(Endpoint $endpoint, array $payload): ?Output
228239 $ buffer = str_replace ('data: ' , '' , $ buffer );
229240 $ message = $ this ->hydrator ->hydrateWithJson (Message::class, $ buffer );
230241
242+ if ($ message instanceof SendData && $ protocol === 'sse ' ) {
243+ $ sendData = $ this ->httpRaw ('post ' , self ::SSE_QUEUE_DATA , [
244+ 'data ' => $ payload ,
245+ 'fn_index ' => $ endpoint ->index ,
246+ 'session_hash ' => $ this ->sessionHash ,
247+ 'event_id ' => $ message ->event_id ,
248+ 'event_data ' => $ message ?->event_data,
249+ 'trigger_id ' => $ triggerId ,
250+ ]);
251+ if ($ sendData ->getStatusCode () !== 200 ) {
252+ throw new GradioException ('Error sending data ' );
253+ }
254+ $ buffer = '' ;
255+ continue ;
256+ }
257+
231258 if ($ message instanceof ProcessCompleted) {
259+ if (in_array ($ protocol , ['sse_v2 ' , 'sse_v2.1 ' ], true )) {
260+ $ response ->getBody ()->close ();
261+ }
262+
232263 $ this ->fireEvent (Event::PROCESS_COMPLETED , [$ message ]);
233264 if ($ message ->success ) {
234265 $ this ->fireEvent (Event::PROCESS_SUCCESS , [$ message ]);
0 commit comments