Skip to content

Commit 0f7b0fc

Browse files
committed
support for sse v1 v2
1 parent b232f51 commit 0f7b0fc

19 files changed

Lines changed: 279 additions & 99 deletions

composer.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
"require": {
1818
"php": "^8.2",
1919
"guzzlehttp/guzzle": "^7.7",
20-
"nutgram/hydrator": ">=5.0",
20+
"nutgram/hydrator": ">=6.0",
2121
"phrity/websocket": "^1.7.2",
2222
"ext-fileinfo": "*"
2323
},

src/Client.php

Lines changed: 104 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,15 @@
66
use SergiX44\Gradio\Client\Endpoint;
77
use SergiX44\Gradio\Client\RemoteClient;
88
use SergiX44\Gradio\DTO\Config;
9+
use SergiX44\Gradio\DTO\Messages\Estimation;
10+
use SergiX44\Gradio\DTO\Messages\Message;
11+
use SergiX44\Gradio\DTO\Messages\ProcessCompleted;
12+
use SergiX44\Gradio\DTO\Messages\ProcessGenerating;
13+
use SergiX44\Gradio\DTO\Messages\ProcessStarts;
14+
use SergiX44\Gradio\DTO\Messages\QueueFull;
15+
use SergiX44\Gradio\DTO\Messages\SendData;
16+
use SergiX44\Gradio\DTO\Messages\SendHash;
917
use SergiX44\Gradio\DTO\Output;
10-
use SergiX44\Gradio\DTO\Websocket\Estimation;
11-
use SergiX44\Gradio\DTO\Websocket\Message;
12-
use SergiX44\Gradio\DTO\Websocket\ProcessCompleted;
13-
use SergiX44\Gradio\DTO\Websocket\ProcessGenerating;
14-
use SergiX44\Gradio\DTO\Websocket\ProcessStarts;
15-
use SergiX44\Gradio\DTO\Websocket\QueueFull;
16-
use SergiX44\Gradio\DTO\Websocket\SendData;
17-
use SergiX44\Gradio\DTO\Websocket\SendHash;
1818
use SergiX44\Gradio\Event\Event;
1919
use SergiX44\Gradio\Exception\GradioException;
2020
use SergiX44\Gradio\Exception\QueueFullException;
@@ -23,7 +23,9 @@ class Client extends RemoteClient
2323
{
2424
private const HTTP_PREDICT = 'run/predict';
2525

26-
private const WS_PREDICT = 'queue/join';
26+
private const QUEUE_JOIN = 'queue/join';
27+
28+
private const SSE_GET_DATA = 'queue/data';
2729

2830
private const HTTP_CONFIG = 'config';
2931

@@ -38,26 +40,19 @@ class Client extends RemoteClient
3840
public function __construct(string $src, string $hfToken = null, Config $config = null)
3941
{
4042
parent::__construct($src);
41-
$this->config = $config ?? $this->get(self::HTTP_CONFIG, dto: Config::class);
43+
$this->config = $config ?? $this->http('get', self::HTTP_CONFIG, dto: Config::class);
4244
$this->loadEndpoints($this->config->dependencies);
4345
$this->sessionHash = substr(md5(microtime()), 0, 11);
4446
$this->hfToken = $hfToken;
4547
}
4648

4749
protected function loadEndpoints(array $dependencies): void
4850
{
49-
foreach ($dependencies as $index => $dep) {
50-
$endpoint = new Endpoint(
51-
$this,
52-
$index,
53-
! empty($dep['api_name']) ? $dep['api_name'] : null,
54-
$dep['queue'] !== false,
55-
count($dep['inputs'])
56-
);
57-
51+
foreach ($dependencies as $index => $dp) {
52+
$endpoint = new Endpoint($this->config, $index, $dp);
5853
$this->endpoints[$index] = $endpoint;
59-
if ($endpoint->apiName !== null) {
60-
$this->endpoints[$endpoint->apiName] = $endpoint;
54+
if ($endpoint->apiName() !== null) {
55+
$this->endpoints[$endpoint->apiName()] = $endpoint;
6156
}
6257
}
6358
}
@@ -83,16 +78,24 @@ public function predict(array $arguments, string $apiName = null, int $fnIndex =
8378
return $this->submit($endpoint, $arguments);
8479
}
8580

86-
private function submit(Endpoint $endpoint, array $arguments): ?Output
81+
public function submit(Endpoint $endpoint, array $arguments): ?Output
8782
{
8883
$payload = $this->preparePayload($arguments);
8984
$this->fireEvent(Event::SUBMIT, $payload);
9085

91-
if ($endpoint->useWebsockets) {
92-
return $this->websocketLoop($endpoint, $payload);
86+
if ($endpoint->skipsQueue()) {
87+
return $this->http('post', $this->makeUri($endpoint), [
88+
'data' => $payload,
89+
'fn_index' => $endpoint->index,
90+
'session_hash' => $this->sessionHash,
91+
'event_data' => null,
92+
], dto: Output::class);
9393
}
9494

95-
return $this->post(self::HTTP_PREDICT, ['data' => $payload], Output::class);
95+
return match ($this->config->protocol) {
96+
'sse_v1', 'sse_v2' => $this->sseV1V2Loop($endpoint, $payload),
97+
default => $this->websocketLoop($endpoint, $payload),
98+
};
9699
}
97100

98101
private function preparePayload(array $arguments): array
@@ -124,16 +127,26 @@ private function preparePayload(array $arguments): array
124127
}, $arguments);
125128
}
126129

130+
protected function makeUri(Endpoint $endpoint): string
131+
{
132+
$name = $endpoint->apiName();
133+
if ($name !== null) {
134+
$name = str_replace('/', '', $name);
135+
return "run/$name";
136+
}
137+
138+
return self::HTTP_PREDICT;
139+
}
140+
127141
/**
128142
* @throws GradioException
129143
* @throws QueueFullException
130144
* @throws \JsonException
131145
*/
132146
private function websocketLoop(Endpoint $endpoint, array $payload): ?Output
133147
{
134-
$ws = $this->ws(self::WS_PREDICT);
148+
$ws = $this->ws(self::QUEUE_JOIN);
135149

136-
$message = null;
137150
while (true) {
138151
$data = $ws->receive();
139152

@@ -183,4 +196,68 @@ private function websocketLoop(Endpoint $endpoint, array $payload): ?Output
183196

184197
return $message?->output;
185198
}
199+
200+
private function sseV1V2Loop(Endpoint $endpoint, array $payload): ?Output
201+
{
202+
$response = $this->httpRaw('post', self::QUEUE_JOIN, [
203+
'data' => $payload,
204+
'fn_index' => $endpoint->index,
205+
'session_hash' => $this->sessionHash,
206+
]);
207+
208+
if ($response->getStatusCode() === 503) {
209+
throw new QueueFullException();
210+
}
211+
212+
if ($response->getStatusCode() !== 200) {
213+
throw new GradioException('Error joining the queue');
214+
}
215+
216+
// $data = $this->decodeResponse($response);
217+
// $eventId = $data['event_id'];
218+
219+
$response = $this->httpRaw('get', self::SSE_GET_DATA, ['session_hash' => $this->sessionHash], [
220+
'headers' => [
221+
'Accept' => 'text/event-stream',
222+
],
223+
'stream' => true,
224+
]);
225+
226+
$buffer = '';
227+
$message = null;
228+
while (!$response->getBody()->eof()) {
229+
$data = $response->getBody()->read(1);
230+
if ($data !== "\n") {
231+
$buffer .= $data;
232+
continue;
233+
}
234+
235+
// read second \n
236+
$response->getBody()->read(1);
237+
238+
// remove data:
239+
$buffer = str_replace('data: ', '', $buffer);
240+
$message = $this->hydrator->hydrateWithJson(Message::class, $buffer);
241+
242+
if ($message instanceof ProcessCompleted) {
243+
$this->fireEvent(Event::PROCESS_COMPLETED, [$message]);
244+
if ($message->success) {
245+
$this->fireEvent(Event::PROCESS_SUCCESS, [$message]);
246+
} else {
247+
$this->fireEvent(Event::PROCESS_FAILED, [$message]);
248+
}
249+
break;
250+
} elseif ($message instanceof ProcessStarts) {
251+
$this->fireEvent(Event::PROCESS_STARTS, [$message]);
252+
} elseif ($message instanceof ProcessGenerating) {
253+
$this->fireEvent(Event::PROCESS_GENERATING, [$message]);
254+
} elseif ($message instanceof Estimation) {
255+
$this->fireEvent(Event::QUEUE_ESTIMATION, [$message]);
256+
}
257+
258+
$buffer = '';
259+
}
260+
261+
return $message?->output;
262+
}
186263
}

src/Client/Endpoint.php

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,35 @@
22

33
namespace SergiX44\Gradio\Client;
44

5-
use SergiX44\Gradio\Client;
5+
use SergiX44\Gradio\DTO\Config;
66

77
readonly class Endpoint
88
{
9+
910
public function __construct(
10-
public Client $client,
11+
private Config $config,
1112
public int $index,
12-
public ?string $apiName,
13-
public bool $useWebsockets,
14-
public int $argsCount = 1,
13+
private readonly array $data
1514
) {
1615
}
16+
17+
public function __get(string $name): mixed
18+
{
19+
return $this->data[$name] ?? null;
20+
}
21+
22+
public function __isset(string $name): bool
23+
{
24+
return isset($this->data[$name]);
25+
}
26+
27+
public function skipsQueue(): bool
28+
{
29+
return !($this->data['queue'] ?? $this->config->enable_queue);
30+
}
31+
32+
public function apiName(): ?string
33+
{
34+
return !empty($this->data['api_name']) ? $this->data['api_name'] : null;
35+
}
1736
}

src/Client/RemoteClient.php

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@ abstract class RemoteClient extends RegisterEvents
2020
public function __construct(string $src)
2121
{
2222
if (
23-
! str_starts_with($src, 'http://') &&
24-
! str_starts_with($src, 'https://') &&
25-
! str_starts_with($src, 'ws://') &&
26-
! str_starts_with($src, 'wss://')
23+
!str_starts_with($src, 'http://') &&
24+
!str_starts_with($src, 'https://') &&
25+
!str_starts_with($src, 'ws://') &&
26+
!str_starts_with($src, 'wss://')
2727
) {
2828
throw new InvalidArgumentException('The src must not contain the protocol');
2929
}
@@ -36,37 +36,38 @@ public function __construct(string $src)
3636
'base_uri' => str_replace('ws', 'http', $this->src),
3737
'headers' => [
3838
'User-Agent' => 'gradio_client_php/1.0',
39+
'Accept' => 'application/json',
3940
],
4041
]);
4142
}
4243

43-
protected function get(string $uri, array $params = [], string $dto = null)
44+
protected function http(string $method, string $uri, array $params = [], array $opt = [], ?string $dto = null)
4445
{
45-
$response = $this->httpClient->get($uri, ['query' => $params]);
46-
47-
return $this->parseResponse($response, $dto);
46+
$response = $this->httpRaw($method, $uri, $params, $opt);
47+
return $this->decodeResponse($response, $dto);
4848
}
4949

50-
protected function post(string $uri, array $params = [], string $dto = null)
50+
protected function httpRaw(string $method, string $uri, array $params = [], array $opt = [])
5151
{
52-
$response = $this->httpClient->post($uri, ['json' => $params]);
53-
54-
return $this->parseResponse($response, $dto);
52+
$keyContent = $method === 'get' ? 'query' : 'json';
53+
return $this->httpClient->request($method, $uri, array_merge([
54+
$keyContent => $params,
55+
], $opt));
5556
}
5657

5758
protected function ws(string $uri, array $options = []): EnhancedClient
5859
{
5960
return new EnhancedClient(str_replace('http', 'ws', $this->src).$uri, $options);
6061
}
6162

62-
private function parseResponse(ResponseInterface $response, string $mapTo = null): mixed
63+
protected function decodeResponse(ResponseInterface|string $response, string $mapTo = null): mixed
6364
{
64-
$body = $response->getBody()->getContents();
65+
$body = $response instanceof ResponseInterface ? $response->getBody()->getContents() : $response;
6566

6667
if ($mapTo !== null) {
6768
return $this->hydrator->hydrateWithJson($mapTo, $body);
6869
}
6970

70-
return json_decode($body, flags: JSON_THROW_ON_ERROR);
71+
return json_decode($body, true, flags: JSON_THROW_ON_ERROR);
7172
}
7273
}

src/DTO/Config.php

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,23 @@ class Config
3333
public array $dependencies = [];
3434

3535
public ?string $root = null;
36+
37+
public ?string $protocol = null;
38+
39+
private array $_extra = [];
40+
41+
public function __set(string $name, $value): void
42+
{
43+
$this->_extra[$name] = $value;
44+
}
45+
46+
public function __get(string $name)
47+
{
48+
return $this->_extra[$name] ?? null;
49+
}
50+
51+
public function __isset(string $name): bool
52+
{
53+
return isset($this->_extra[$name]);
54+
}
3655
}

src/DTO/Messages/Estimation.php

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
<?php
2+
3+
namespace SergiX44\Gradio\DTO\Messages;
4+
5+
class Estimation extends Message
6+
{
7+
public ?int $rank = null;
8+
9+
public ?int $queue_size = null;
10+
11+
public ?float $avg_event_process_time = null;
12+
13+
public ?float $avg_event_concurrent_process_time = null;
14+
15+
public ?float $rank_eta = null;
16+
17+
public ?float $queue_eta = null;
18+
}

src/DTO/Messages/Log.php

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
<?php
2+
3+
namespace SergiX44\Gradio\DTO\Messages;
4+
5+
class Log extends Message
6+
{
7+
public ?string $log = null;
8+
9+
public ?string $level = null;
10+
11+
public ?string $event_id = null;
12+
}

src/DTO/Messages/Message.php

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
<?php
2+
3+
namespace SergiX44\Gradio\DTO\Messages;
4+
5+
use SergiX44\Gradio\DTO\Resolvers\MessageResolver;
6+
use SergiX44\Gradio\DTO\Resolvers\MessageType;
7+
use SergiX44\Hydrator\Resolver\EnumOrScalar;
8+
9+
#[MessageResolver]
10+
abstract class Message
11+
{
12+
#[EnumOrScalar]
13+
public MessageType|string $msg;
14+
15+
private array $_extra = [];
16+
17+
public function __set(string $name, $value): void
18+
{
19+
$this->_extra[$name] = $value;
20+
}
21+
22+
public function __get(string $name)
23+
{
24+
return $this->_extra[$name] ?? null;
25+
}
26+
27+
public function __isset(string $name): bool
28+
{
29+
return isset($this->_extra[$name]);
30+
}
31+
}

0 commit comments

Comments
 (0)