Skip to content

Commit 8262fd3

Browse files
committed
Improve sse protocol support
1 parent c4118c0 commit 8262fd3

4 files changed

Lines changed: 78 additions & 36 deletions

File tree

src/Client.php

Lines changed: 52 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class Client extends RemoteClient
2525

2626
private const QUEUE_JOIN = 'queue/join';
2727

28-
private const SSE_GET_DATA = 'queue/data';
28+
private const SSE_QUEUE_DATA = 'queue/data';
2929

3030
private const HTTP_CONFIG = 'config';
3131

@@ -62,7 +62,7 @@ public function getConfig(): Config
6262
return $this->config;
6363
}
6464

65-
public function predict(array $arguments, ?string $apiName = null, ?int $fnIndex = null, bool $raw = false): Output|array|null
65+
public function predict(array $arguments, ?string $apiName = null, ?int $fnIndex = null, bool $raw = false, ?int $triggerId = null): Output|array|null
6666
{
6767
if ($apiName === null && $fnIndex === null) {
6868
throw new InvalidArgumentException('You must provide an apiName or fnIndex');
@@ -75,10 +75,10 @@ public function predict(array $arguments, ?string $apiName = null, ?int $fnIndex
7575
throw new InvalidArgumentException('Endpoint not found');
7676
}
7777

78-
return $this->submit($endpoint, $arguments, $raw);
78+
return $this->submit($endpoint, $arguments, $raw, $triggerId);
7979
}
8080

81-
protected function submit(Endpoint $endpoint, array $arguments, bool $raw): Output|array|null
81+
protected function submit(Endpoint $endpoint, array $arguments, bool $raw, ?int $triggerId = null): Output|array|null
8282
{
8383
$payload = $this->preparePayload($arguments);
8484
$this->fireEvent(Event::SUBMIT, $payload);
@@ -88,12 +88,15 @@ protected function submit(Endpoint $endpoint, array $arguments, bool $raw): Outp
8888
'data' => $payload,
8989
'fn_index' => $endpoint->index,
9090
'session_hash' => $this->sessionHash,
91+
'trigger_id' => $triggerId,
92+
'event_data' => null,
9193
], dto: $raw ? null : Output::class);
9294
}
9395

9496
return match ($this->config->protocol) {
95-
'sse_v1', 'sse_v2' => $this->sseV1V2Loop($endpoint, $payload),
96-
default => $this->websocketLoop($endpoint, $payload),
97+
'sse', 'sse_v1', 'sse_v2', 'sse_v2.1', 'sse_v3' => $this->sseLoop($endpoint, $payload, $this->config->protocol, $triggerId),
98+
'ws' => $this->websocketLoop($endpoint, $payload),
99+
default => throw new GradioException('Unknown protocol '.$this->config->protocol),
97100
};
98101
}
99102

@@ -185,26 +188,34 @@ private function websocketLoop(Endpoint $endpoint, array $payload): ?Output
185188
return $message?->output;
186189
}
187190

188-
private function sseV1V2Loop(Endpoint $endpoint, array $payload): ?Output
191+
private function sseLoop(Endpoint $endpoint, array $payload, string $protocol, ?int $triggerId): ?Output
189192
{
190-
$response = $this->httpRaw('post', self::QUEUE_JOIN, [
191-
'data' => $payload,
192-
'fn_index' => $endpoint->index,
193-
'session_hash' => $this->sessionHash,
194-
]);
193+
if ($protocol === 'sse') {
194+
$getEndpoint = self::QUEUE_JOIN;
195+
} else {
196+
$getEndpoint = self::SSE_QUEUE_DATA;
197+
$response = $this->httpRaw('post', self::QUEUE_JOIN, [
198+
'data' => $payload,
199+
'fn_index' => $endpoint->index,
200+
'session_hash' => $this->sessionHash,
201+
]);
195202

196-
if ($response->getStatusCode() === 503) {
197-
throw new QueueFullException();
198-
}
199203

200-
if ($response->getStatusCode() !== 200) {
201-
throw new GradioException('Error joining the queue');
204+
if ($response->getStatusCode() === 503) {
205+
throw new QueueFullException();
206+
}
207+
208+
if ($response->getStatusCode() !== 200) {
209+
throw new GradioException('Error joining the queue');
210+
}
202211
}
203212

204-
// $data = $this->decodeResponse($response);
205-
// $eventId = $data['event_id'];
213+
$params = ['session_hash' => $this->sessionHash];
214+
if ($protocol === 'sse') {
215+
$params['fn_index'] = $endpoint->index;
216+
}
206217

207-
$response = $this->httpRaw('get', self::SSE_GET_DATA, ['session_hash' => $this->sessionHash], [
218+
$response = $this->httpRaw('get', $getEndpoint, $params, [
208219
'headers' => [
209220
'Accept' => 'text/event-stream',
210221
],
@@ -213,7 +224,7 @@ private function sseV1V2Loop(Endpoint $endpoint, array $payload): ?Output
213224

214225
$buffer = '';
215226
$message = null;
216-
while (! $response->getBody()->eof()) {
227+
while (!$response->getBody()->eof()) {
217228
$data = $response->getBody()->read(1);
218229
if ($data !== "\n") {
219230
$buffer .= $data;
@@ -228,7 +239,27 @@ private function sseV1V2Loop(Endpoint $endpoint, array $payload): ?Output
228239
$buffer = str_replace('data: ', '', $buffer);
229240
$message = $this->hydrator->hydrateWithJson(Message::class, $buffer);
230241

242+
if ($message instanceof SendData && $protocol === 'sse') {
243+
$sendData = $this->httpRaw('post', self::SSE_QUEUE_DATA, [
244+
'data' => $payload,
245+
'fn_index' => $endpoint->index,
246+
'session_hash' => $this->sessionHash,
247+
'event_id' => $message->event_id,
248+
'event_data' => $message?->event_data,
249+
'trigger_id' => $triggerId,
250+
]);
251+
if ($sendData->getStatusCode() !== 200) {
252+
throw new GradioException('Error sending data');
253+
}
254+
$buffer = '';
255+
continue;
256+
}
257+
231258
if ($message instanceof ProcessCompleted) {
259+
if (in_array($protocol, ['sse_v2', 'sse_v2.1'], true)) {
260+
$response->getBody()->close();
261+
}
262+
232263
$this->fireEvent(Event::PROCESS_COMPLETED, [$message]);
233264
if ($message->success) {
234265
$this->fireEvent(Event::PROCESS_SUCCESS, [$message]);

src/DTO/Messages/SendData.php

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@
44

55
class SendData extends Message
66
{
7+
public ?string $event_id = null;
78
}

src/DTO/Output.php

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,23 @@ class Output
1212

1313
public array $data = [];
1414

15+
private array $_extra = [];
16+
17+
public function __set(string $name, $value): void
18+
{
19+
$this->_extra[$name] = $value;
20+
}
21+
22+
public function __get(string $name)
23+
{
24+
return $this->_extra[$name] ?? null;
25+
}
26+
27+
public function __isset(string $name): bool
28+
{
29+
return isset($this->_extra[$name]);
30+
}
31+
1532
public function getOutputs(): array
1633
{
1734
return $this->data ?? [];

tests/ExampleTest.php

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,29 +24,22 @@
2424
});
2525

2626
it('can test another model', function () {
27-
$client = new Client('https://ysharma-explore-llamav2-with-tgi.hf.space/--replicas/brc3o/');
27+
$client = new Client('https://ehristoforu-mixtral-46-7b-chat.hf.space');
2828

29-
$response = $client->predict([
30-
'list all names of the week in all languages', // str in 'parameter_28' Textbox component
31-
'', // str in 'Optional system prompt' Textbox component
32-
0.9, // float (numeric value between 0.0 and 1.0) in 'Temperature' Slider component
33-
4096, // float (numeric value between 0 and 4096) in 'Max new tokens' Slider component
34-
0.6, // float (numeric value between 0.0 and 1) in 'Top-p (nucleus sampling)' Slider component
35-
1.2, // float (numeric value between 1.0 and 2.0) in 'Repetition penalty' Slider component
36-
], '/chat');
29+
$client->predict([], fnIndex: 5, raw: true);
30+
$client->predict(['hi'], fnIndex: 1, raw: true);
31+
$client->predict([null, []], fnIndex: 2, raw: true);
32+
$response = $client->predict([null, null, "", 0.9, 256, 0.9, 1.2], fnIndex: 3);
33+
$client->predict([], fnIndex: 6, raw: true);
3734

3835
$outputs = $response->getOutputs();
3936

4037
expect($client)->toBeInstanceOf(Client::class);
4138
});
4239

4340
it('can test fnindexsudgugdhs', function () {
44-
$client = new Client('https://ysharma-explore-llamav2-with-tgi.hf.space/--replicas/brc3o/');
45-
46-
$client->predict([], fnIndex: 6, raw: true);
47-
$client->predict(['hi'], fnIndex: 2, raw: true);
48-
$client->predict([null, null], fnIndex: 3, raw: true);
49-
$response = $client->predict([null, null, '', 0.9, 256, 0.6, 1.2], fnIndex: 4);
41+
$client = new Client('https://deepseek-ai-deepseek-vl-7b.hf.space');
42+
$response = $client->predict([[["Hello!", null]], 0, 0, 0, 0, 0, 'DeepSeek-VL 7B'], apiName: '/predict');
5043

5144
$value = $response->getOutput();
5245

0 commit comments

Comments
 (0)