diff --git a/.changeset/warm-avian-provider.md b/.changeset/warm-avian-provider.md new file mode 100644 index 0000000000..4419363e63 --- /dev/null +++ b/.changeset/warm-avian-provider.md @@ -0,0 +1,5 @@ +--- +"@huggingface/inference": minor +--- + +Add Avian as a new AI provider diff --git a/packages/inference/README.md b/packages/inference/README.md index 6f139e87d3..b7848479dd 100644 --- a/packages/inference/README.md +++ b/packages/inference/README.md @@ -47,6 +47,7 @@ Your access token should be kept private. If you need to protect it in front-end You can send inference requests to third-party providers with the inference client. Currently, we support the following providers: +- [Avian](https://avian.io) - [Fal.ai](https://fal.ai) - [Featherless AI](https://featherless.ai) - [Fireworks AI](https://fireworks.ai) @@ -90,6 +91,7 @@ When authenticated with a Hugging Face access token, the request is routed throu When authenticated with a third-party provider key, the request is made directly against that provider's inference API. Only a subset of models are supported when requesting third-party providers. 
You can check the list of supported models per pipeline tasks here: +- [Avian supported models](https://huggingface.co/api/partners/avian/models) - [Fal.ai supported models](https://huggingface.co/api/partners/fal-ai/models) - [Featherless AI supported models](https://huggingface.co/api/partners/featherless-ai/models) - [Fireworks AI supported models](https://huggingface.co/api/partners/fireworks-ai/models) diff --git a/packages/inference/src/lib/getProviderHelper.ts b/packages/inference/src/lib/getProviderHelper.ts index df7a943239..2ad2ef6761 100644 --- a/packages/inference/src/lib/getProviderHelper.ts +++ b/packages/inference/src/lib/getProviderHelper.ts @@ -1,3 +1,4 @@ +import * as Avian from "../providers/avian.js"; import * as Baseten from "../providers/baseten.js"; import * as Clarifai from "../providers/clarifai.js"; import * as BlackForestLabs from "../providers/black-forest-labs.js"; @@ -62,6 +63,9 @@ import type { InferenceProvider, InferenceProviderOrPolicy, InferenceTask } from import { InferenceClientInputError } from "../errors.js"; export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask, TaskProviderHelper>>> = { + avian: { + conversational: new Avian.AvianConversationalTask(), + }, baseten: { conversational: new Baseten.BasetenConversationalTask(), }, diff --git a/packages/inference/src/providers/avian.ts b/packages/inference/src/providers/avian.ts new file mode 100644 index 0000000000..eadb5e5832 --- /dev/null +++ b/packages/inference/src/providers/avian.ts @@ -0,0 +1,24 @@ +/** + * See the registered mapping of HF model ID => Avian model ID here: + * + * https://huggingface.co/api/partners/avian/models + * + * This is a publicly available mapping. + * + * If you want to try to run inference for a new model locally before it's registered on huggingface.co, + * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
+ * + * - If you work at Avian and want to update this mapping, please use the model mapping API we provide on huggingface.co + * - If you're a community member and want to add a new supported HF model to Avian, please open an issue on the present repo + * and we will tag Avian team members. + * + * Thanks! + */ + +import { BaseConversationalTask } from "./providerHelper.js"; + +export class AvianConversationalTask extends BaseConversationalTask { + constructor() { + super("avian", "https://api.avian.io"); + } +} diff --git a/packages/inference/src/providers/consts.ts b/packages/inference/src/providers/consts.ts index 8e7a4ab498..34f58e4063 100644 --- a/packages/inference/src/providers/consts.ts +++ b/packages/inference/src/providers/consts.ts @@ -18,6 +18,7 @@ export const HARDCODED_MODEL_INFERENCE_MAPPING: Record< * Example: * "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct", */ + avian: {}, baseten: {}, "black-forest-labs": {}, cerebras: {}, diff --git a/packages/inference/src/types.ts b/packages/inference/src/types.ts index ebd100e8a7..84439c3f34 100644 --- a/packages/inference/src/types.ts +++ b/packages/inference/src/types.ts @@ -45,6 +45,7 @@ export interface Options { export type InferenceTask = Exclude<PipelineType, "other"> | "conversational"; export const INFERENCE_PROVIDERS = [ + "avian", "baseten", "black-forest-labs", "cerebras", @@ -84,6 +85,7 @@ export type InferenceProviderOrPolicy = (typeof PROVIDERS_OR_POLICIES)[number]; * Whenever possible, InferenceProvider should == org namespace */ export const PROVIDERS_HUB_ORGS: Record<InferenceProvider, string> = { + avian: "aviandata", baseten: "baseten", "black-forest-labs": "black-forest-labs", cerebras: "cerebras", diff --git a/packages/inference/test/avian.test.ts b/packages/inference/test/avian.test.ts new file mode 100644 index 0000000000..c39ef8d153 --- /dev/null +++ b/packages/inference/test/avian.test.ts @@ -0,0 +1,102 @@ +import { describe, it, expect } from "vitest"; +import type { ChatCompletionStreamOutput } from
"@huggingface/tasks"; +import { InferenceClient, chatCompletion, chatCompletionStream } from "../src/index.js"; +import { HARDCODED_MODEL_INFERENCE_MAPPING } from "../src/providers/consts.js"; + +const TIMEOUT = 60000 * 3; +const env = import.meta.env; + +if (!env.HF_TOKEN) { + console.warn("Set HF_TOKEN in the env to run the tests for better rate limits"); +} + +describe.skip.concurrent( + "Avian", + () => { + const client = new InferenceClient(env.HF_AVIAN_KEY ?? "dummy"); + + HARDCODED_MODEL_INFERENCE_MAPPING["avian"] = { + "zai-org/GLM-4.6": { + provider: "avian", + hfModelId: "zai-org/GLM-4.6", + providerId: "glm-4.6", + status: "live", + task: "conversational", + }, + }; + + it("chatCompletion", async () => { + const res = await client.chatCompletion({ + model: "zai-org/GLM-4.6", + provider: "avian", + messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }], + }); + if (res.choices && res.choices.length > 0) { + const completion = res.choices[0].message?.content; + expect(completion).toContain("two"); + } + }); + + it("chatCompletion stream", async () => { + const stream = client.chatCompletionStream({ + model: "zai-org/GLM-4.6", + provider: "avian", + messages: [{ role: "user", content: "Complete the equation 1 + 1 = , just the answer" }], + stream: true, + }) as AsyncGenerator<ChatCompletionStreamOutput>; + + let fullResponse = ""; + for await (const chunk of stream) { + if (chunk.choices && chunk.choices.length > 0) { + const content = chunk.choices[0].delta?.content; + if (content) { + fullResponse += content; + } + } + } + + expect(fullResponse).toBeTruthy(); + expect(fullResponse.length).toBeGreaterThan(0); + expect(fullResponse).toMatch(/(two|2)/i); + }); + + it("chatCompletion - using chatCompletion function", async () => { + const res = await chatCompletion({ + accessToken: env.HF_AVIAN_KEY ??
"dummy", + model: "zai-org/GLM-4.6", + provider: "avian", + messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }], + temperature: 0.1, + }); + + expect(res).toBeDefined(); + expect(res.choices).toBeDefined(); + expect(res.choices?.length).toBeGreaterThan(0); + + if (res.choices && res.choices.length > 0) { + const completion = res.choices[0].message?.content; + expect(completion).toBeDefined(); + expect(typeof completion).toBe("string"); + expect(completion).toMatch(/(two|2)/i); + } + }); + + it("chatCompletion stream - using chatCompletionStream function", async () => { + const stream = chatCompletionStream({ + accessToken: env.HF_AVIAN_KEY ?? "dummy", + model: "zai-org/GLM-4.6", + provider: "avian", + messages: [{ role: "user", content: "Complete the equation 1 + 1 = , just the answer" }], + }) as AsyncGenerator<ChatCompletionStreamOutput>; + + let out = ""; + for await (const chunk of stream) { + if (chunk.choices && chunk.choices.length > 0) { + out += chunk.choices[0].delta.content; + } + } + expect(out).toMatch(/(two|2)/i); + }); + }, + TIMEOUT +);