66use SergiX44 \Gradio \Client \Endpoint ;
77use SergiX44 \Gradio \Client \RemoteClient ;
88use SergiX44 \Gradio \DTO \Config ;
9+ use SergiX44 \Gradio \DTO \Messages \Estimation ;
10+ use SergiX44 \Gradio \DTO \Messages \Message ;
11+ use SergiX44 \Gradio \DTO \Messages \ProcessCompleted ;
12+ use SergiX44 \Gradio \DTO \Messages \ProcessGenerating ;
13+ use SergiX44 \Gradio \DTO \Messages \ProcessStarts ;
14+ use SergiX44 \Gradio \DTO \Messages \QueueFull ;
15+ use SergiX44 \Gradio \DTO \Messages \SendData ;
16+ use SergiX44 \Gradio \DTO \Messages \SendHash ;
917use SergiX44 \Gradio \DTO \Output ;
10- use SergiX44 \Gradio \DTO \Websocket \Estimation ;
11- use SergiX44 \Gradio \DTO \Websocket \Message ;
12- use SergiX44 \Gradio \DTO \Websocket \ProcessCompleted ;
13- use SergiX44 \Gradio \DTO \Websocket \ProcessGenerating ;
14- use SergiX44 \Gradio \DTO \Websocket \ProcessStarts ;
15- use SergiX44 \Gradio \DTO \Websocket \QueueFull ;
16- use SergiX44 \Gradio \DTO \Websocket \SendData ;
17- use SergiX44 \Gradio \DTO \Websocket \SendHash ;
1818use SergiX44 \Gradio \Event \Event ;
1919use SergiX44 \Gradio \Exception \GradioException ;
2020use SergiX44 \Gradio \Exception \QueueFullException ;
@@ -23,7 +23,9 @@ class Client extends RemoteClient
2323{
2424 private const HTTP_PREDICT = 'run/predict ' ;
2525
26- private const WS_PREDICT = 'queue/join ' ;
26+ private const QUEUE_JOIN = 'queue/join ' ;
27+
28+ private const SSE_GET_DATA = 'queue/data ' ;
2729
2830 private const HTTP_CONFIG = 'config ' ;
2931
@@ -38,26 +40,19 @@ class Client extends RemoteClient
3840 public function __construct (string $ src , string $ hfToken = null , Config $ config = null )
3941 {
4042 parent ::__construct ($ src );
41- $ this ->config = $ config ?? $ this ->get ( self ::HTTP_CONFIG , dto: Config::class);
43+ $ this ->config = $ config ?? $ this ->http ( ' get ' , self ::HTTP_CONFIG , dto: Config::class);
4244 $ this ->loadEndpoints ($ this ->config ->dependencies );
4345 $ this ->sessionHash = substr (md5 (microtime ()), 0 , 11 );
4446 $ this ->hfToken = $ hfToken ;
4547 }
4648
4749 protected function loadEndpoints (array $ dependencies ): void
4850 {
49- foreach ($ dependencies as $ index => $ dep ) {
50- $ endpoint = new Endpoint (
51- $ this ,
52- $ index ,
53- ! empty ($ dep ['api_name ' ]) ? $ dep ['api_name ' ] : null ,
54- $ dep ['queue ' ] !== false ,
55- count ($ dep ['inputs ' ])
56- );
57-
51+ foreach ($ dependencies as $ index => $ dp ) {
52+ $ endpoint = new Endpoint ($ this ->config , $ index , $ dp );
5853 $ this ->endpoints [$ index ] = $ endpoint ;
59- if ($ endpoint ->apiName !== null ) {
60- $ this ->endpoints [$ endpoint ->apiName ] = $ endpoint ;
54+ if ($ endpoint ->apiName () !== null ) {
55+ $ this ->endpoints [$ endpoint ->apiName () ] = $ endpoint ;
6156 }
6257 }
6358 }
@@ -83,16 +78,24 @@ public function predict(array $arguments, string $apiName = null, int $fnIndex =
8378 return $ this ->submit ($ endpoint , $ arguments );
8479 }
8580
86- private function submit (Endpoint $ endpoint , array $ arguments ): ?Output
81+ public function submit (Endpoint $ endpoint , array $ arguments ): ?Output
8782 {
8883 $ payload = $ this ->preparePayload ($ arguments );
8984 $ this ->fireEvent (Event::SUBMIT , $ payload );
9085
91- if ($ endpoint ->useWebsockets ) {
92- return $ this ->websocketLoop ($ endpoint , $ payload );
86+ if ($ endpoint ->skipsQueue ()) {
87+ return $ this ->http ('post ' , $ this ->makeUri ($ endpoint ), [
88+ 'data ' => $ payload ,
89+ 'fn_index ' => $ endpoint ->index ,
90+ 'session_hash ' => $ this ->sessionHash ,
91+ 'event_data ' => null ,
92+ ], dto: Output::class);
9393 }
9494
95- return $ this ->post (self ::HTTP_PREDICT , ['data ' => $ payload ], Output::class);
95+ return match ($ this ->config ->protocol ) {
96+ 'sse_v1 ' , 'sse_v2 ' => $ this ->sseV1V2Loop ($ endpoint , $ payload ),
97+ default => $ this ->websocketLoop ($ endpoint , $ payload ),
98+ };
9699 }
97100
98101 private function preparePayload (array $ arguments ): array
@@ -124,16 +127,26 @@ private function preparePayload(array $arguments): array
124127 }, $ arguments );
125128 }
126129
130+ protected function makeUri (Endpoint $ endpoint ): string
131+ {
132+ $ name = $ endpoint ->apiName ();
133+ if ($ name !== null ) {
134+ $ name = str_replace ('/ ' , '' , $ name );
135+ return "run/ $ name " ;
136+ }
137+
138+ return self ::HTTP_PREDICT ;
139+ }
140+
127141 /**
128142 * @throws GradioException
129143 * @throws QueueFullException
130144 * @throws \JsonException
131145 */
132146 private function websocketLoop (Endpoint $ endpoint , array $ payload ): ?Output
133147 {
134- $ ws = $ this ->ws (self ::WS_PREDICT );
148+ $ ws = $ this ->ws (self ::QUEUE_JOIN );
135149
136- $ message = null ;
137150 while (true ) {
138151 $ data = $ ws ->receive ();
139152
@@ -183,4 +196,68 @@ private function websocketLoop(Endpoint $endpoint, array $payload): ?Output
183196
184197 return $ message ?->output;
185198 }
199+
200+ private function sseV1V2Loop (Endpoint $ endpoint , array $ payload ): ?Output
201+ {
202+ $ response = $ this ->httpRaw ('post ' , self ::QUEUE_JOIN , [
203+ 'data ' => $ payload ,
204+ 'fn_index ' => $ endpoint ->index ,
205+ 'session_hash ' => $ this ->sessionHash ,
206+ ]);
207+
208+ if ($ response ->getStatusCode () === 503 ) {
209+ throw new QueueFullException ();
210+ }
211+
212+ if ($ response ->getStatusCode () !== 200 ) {
213+ throw new GradioException ('Error joining the queue ' );
214+ }
215+
216+ // $data = $this->decodeResponse($response);
217+ // $eventId = $data['event_id'];
218+
219+ $ response = $ this ->httpRaw ('get ' , self ::SSE_GET_DATA , ['session_hash ' => $ this ->sessionHash ], [
220+ 'headers ' => [
221+ 'Accept ' => 'text/event-stream ' ,
222+ ],
223+ 'stream ' => true ,
224+ ]);
225+
226+ $ buffer = '' ;
227+ $ message = null ;
228+ while (!$ response ->getBody ()->eof ()) {
229+ $ data = $ response ->getBody ()->read (1 );
230+ if ($ data !== "\n" ) {
231+ $ buffer .= $ data ;
232+ continue ;
233+ }
234+
235+ // read second \n
236+ $ response ->getBody ()->read (1 );
237+
238+ // remove data:
239+ $ buffer = str_replace ('data: ' , '' , $ buffer );
240+ $ message = $ this ->hydrator ->hydrateWithJson (Message::class, $ buffer );
241+
242+ if ($ message instanceof ProcessCompleted) {
243+ $ this ->fireEvent (Event::PROCESS_COMPLETED , [$ message ]);
244+ if ($ message ->success ) {
245+ $ this ->fireEvent (Event::PROCESS_SUCCESS , [$ message ]);
246+ } else {
247+ $ this ->fireEvent (Event::PROCESS_FAILED , [$ message ]);
248+ }
249+ break ;
250+ } elseif ($ message instanceof ProcessStarts) {
251+ $ this ->fireEvent (Event::PROCESS_STARTS , [$ message ]);
252+ } elseif ($ message instanceof ProcessGenerating) {
253+ $ this ->fireEvent (Event::PROCESS_GENERATING , [$ message ]);
254+ } elseif ($ message instanceof Estimation) {
255+ $ this ->fireEvent (Event::QUEUE_ESTIMATION , [$ message ]);
256+ }
257+
258+ $ buffer = '' ;
259+ }
260+
261+ return $ message ?->output;
262+ }
186263}
0 commit comments