Skip to content

Commit 13caab3

Browse files
committed
feat: add Avian AI provider support
1 parent cf256e7 commit 13caab3

6 files changed

Lines changed: 139 additions & 0 deletions

File tree

packages/inference/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ Your access token should be kept private. If you need to protect it in front-end
4747
You can send inference requests to third-party providers with the inference client.
4848

4949
Currently, we support the following providers:
50+
- [Avian](https://avian.io)
5051
- [Fal.ai](https://fal.ai)
5152
- [Featherless AI](https://featherless.ai)
5253
- [Fireworks AI](https://fireworks.ai)
@@ -90,6 +91,7 @@ When authenticated with a Hugging Face access token, the request is routed throu
9091
When authenticated with a third-party provider key, the request is made directly against that provider's inference API.
9192

9293
Only a subset of models are supported when requesting third-party providers. You can check the list of supported models per pipeline tasks here:
94+
- [Avian supported models](https://huggingface.co/api/partners/avian/models)
9395
- [Fal.ai supported models](https://huggingface.co/api/partners/fal-ai/models)
9496
- [Featherless AI supported models](https://huggingface.co/api/partners/featherless-ai/models)
9597
- [Fireworks AI supported models](https://huggingface.co/api/partners/fireworks-ai/models)

packages/inference/src/lib/getProviderHelper.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import * as Avian from "../providers/avian.js";
12
import * as Baseten from "../providers/baseten.js";
23
import * as Clarifai from "../providers/clarifai.js";
34
import * as BlackForestLabs from "../providers/black-forest-labs.js";
@@ -62,6 +63,9 @@ import type { InferenceProvider, InferenceProviderOrPolicy, InferenceTask } from
6263
import { InferenceClientInputError } from "../errors.js";
6364

6465
export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask, TaskProviderHelper>>> = {
66+
avian: {
67+
conversational: new Avian.AvianConversationalTask(),
68+
},
6569
baseten: {
6670
conversational: new Baseten.BasetenConversationalTask(),
6771
},
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
/**
2+
* See the registered mapping of HF model ID => Avian model ID here:
3+
*
4+
* https://huggingface.co/api/partners/avian/models
5+
*
6+
* This is a publicly available mapping.
7+
*
8+
* If you want to try to run inference for a new model locally before it's registered on huggingface.co,
9+
* you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
10+
*
11+
* - If you work at Avian and want to update this mapping, please use the model mapping API we provide on huggingface.co
12+
* - If you're a community member and want to add a new supported HF model to Avian, please open an issue on the present repo
13+
* and we will tag Avian team members.
14+
*
15+
* Thanks!
16+
*/
17+
18+
import { BaseConversationalTask } from "./providerHelper.js";
19+
20+
export class AvianConversationalTask extends BaseConversationalTask {
21+
constructor() {
22+
super("avian", "https://api.avian.io/v1");
23+
}
24+
25+
override makeRoute(): string {
26+
return "/chat/completions";
27+
}
28+
}

packages/inference/src/providers/consts.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<
1818
* Example:
1919
* "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
2020
*/
21+
avian: {},
2122
baseten: {},
2223
"black-forest-labs": {},
2324
cerebras: {},

packages/inference/src/types.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ export interface Options {
4545
export type InferenceTask = Exclude<PipelineType, "other"> | "conversational";
4646

4747
export const INFERENCE_PROVIDERS = [
48+
"avian",
4849
"baseten",
4950
"black-forest-labs",
5051
"cerebras",
@@ -84,6 +85,7 @@ export type InferenceProviderOrPolicy = (typeof PROVIDERS_OR_POLICIES)[number];
8485
* Whenever possible, InferenceProvider should == org namespace
8586
*/
8687
export const PROVIDERS_HUB_ORGS: Record<InferenceProvider, string> = {
88+
avian: "aviandata",
8789
baseten: "baseten",
8890
"black-forest-labs": "black-forest-labs",
8991
cerebras: "cerebras",
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
import { describe, it, expect } from "vitest";
2+
import type { ChatCompletionStreamOutput } from "@huggingface/tasks";
3+
import { InferenceClient, chatCompletion, chatCompletionStream } from "../src/index.js";
4+
import { HARDCODED_MODEL_INFERENCE_MAPPING } from "../src/providers/consts.js";
5+
6+
const TIMEOUT = 60000 * 3;
7+
const env = import.meta.env;
8+
9+
if (!env.HF_TOKEN) {
10+
console.warn("Set HF_TOKEN in the env to run the tests for better rate limits");
11+
}
12+
13+
describe.skip.concurrent(
14+
"Avian",
15+
() => {
16+
const client = new InferenceClient(env.HF_AVIAN_KEY ?? "dummy");
17+
18+
HARDCODED_MODEL_INFERENCE_MAPPING["avian"] = {
19+
"zai-org/GLM-4.6": {
20+
provider: "avian",
21+
hfModelId: "zai-org/GLM-4.6",
22+
providerId: "glm-4.6",
23+
status: "live",
24+
task: "conversational",
25+
},
26+
};
27+
28+
it("chatCompletion", async () => {
29+
const res = await client.chatCompletion({
30+
model: "zai-org/GLM-4.6",
31+
provider: "avian",
32+
messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
33+
});
34+
if (res.choices && res.choices.length > 0) {
35+
const completion = res.choices[0].message?.content;
36+
expect(completion).toContain("two");
37+
}
38+
});
39+
40+
it("chatCompletion stream", async () => {
41+
const stream = client.chatCompletionStream({
42+
model: "zai-org/GLM-4.6",
43+
provider: "avian",
44+
messages: [{ role: "user", content: "Complete the equation 1 + 1 = , just the answer" }],
45+
stream: true,
46+
}) as AsyncGenerator<ChatCompletionStreamOutput>;
47+
48+
let fullResponse = "";
49+
for await (const chunk of stream) {
50+
if (chunk.choices && chunk.choices.length > 0) {
51+
const content = chunk.choices[0].delta?.content;
52+
if (content) {
53+
fullResponse += content;
54+
}
55+
}
56+
}
57+
58+
expect(fullResponse).toBeTruthy();
59+
expect(fullResponse.length).toBeGreaterThan(0);
60+
expect(fullResponse).toMatch(/(two|2)/i);
61+
});
62+
63+
it("chatCompletion - using chatCompletion function", async () => {
64+
const res = await chatCompletion({
65+
accessToken: env.HF_AVIAN_KEY ?? "dummy",
66+
model: "zai-org/GLM-4.6",
67+
provider: "avian",
68+
messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
69+
temperature: 0.1,
70+
});
71+
72+
expect(res).toBeDefined();
73+
expect(res.choices).toBeDefined();
74+
expect(res.choices?.length).toBeGreaterThan(0);
75+
76+
if (res.choices && res.choices.length > 0) {
77+
const completion = res.choices[0].message?.content;
78+
expect(completion).toBeDefined();
79+
expect(typeof completion).toBe("string");
80+
expect(completion).toMatch(/(two|2)/i);
81+
}
82+
});
83+
84+
it("chatCompletion stream - using chatCompletionStream function", async () => {
85+
const stream = chatCompletionStream({
86+
accessToken: env.HF_AVIAN_KEY ?? "dummy",
87+
model: "zai-org/GLM-4.6",
88+
provider: "avian",
89+
messages: [{ role: "user", content: "Complete the equation 1 + 1 = , just the answer" }],
90+
}) as AsyncGenerator<ChatCompletionStreamOutput>;
91+
92+
let out = "";
93+
for await (const chunk of stream) {
94+
if (chunk.choices && chunk.choices.length > 0) {
95+
out += chunk.choices[0].delta.content;
96+
}
97+
}
98+
expect(out).toMatch(/(two|2)/i);
99+
});
100+
},
101+
TIMEOUT
102+
);

0 commit comments

Comments
 (0)