Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .changeset/ai-models-schema.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
"wrangler": minor
---

Add `wrangler ai models schema` command for fetching model schemas

You can now run `wrangler ai models schema <model>` to fetch the input and output schema for a Workers AI model from the public model catalog schema endpoint.
7 changes: 7 additions & 0 deletions .changeset/public-model-search.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
"wrangler": minor
---

Add `wrangler ai models list` command for querying the Workers AI model catalog

`wrangler ai models list` accepts `--search`, `--task`, `--author`, `--source`, and `--hide-experimental`, matching the public model catalog search endpoint.
157 changes: 152 additions & 5 deletions packages/wrangler/src/__tests__/ai.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ describe("ai help", () => {
🤖 Manage AI models

COMMANDS
wrangler ai models List catalog models
wrangler ai models Manage AI models
wrangler ai finetune Interact with finetune files

GLOBAL FLAGS
Expand Down Expand Up @@ -55,7 +55,7 @@ describe("ai help", () => {
🤖 Manage AI models

COMMANDS
wrangler ai models List catalog models
wrangler ai models Manage AI models
wrangler ai finetune Interact with finetune files

GLOBAL FLAGS
Expand All @@ -67,6 +67,51 @@ describe("ai help", () => {
-v, --version Show version number [boolean]"
`);
});

it("should show models help", async ({ expect }) => {
await runWrangler("ai models --help");
await endEventLoop();

expect(std.out).toMatchInlineSnapshot(`
"wrangler ai models

Manage AI models

COMMANDS
wrangler ai models list List catalog models
wrangler ai models schema <model> Get model schema

GLOBAL FLAGS
-c, --config Path to Wrangler configuration file [string]
--cwd Run as if Wrangler was started in the specified directory instead of the current working directory [string]
-e, --env Environment to use for operations, and for selecting .env and .dev.vars files [string]
--env-file Path to an .env file to load - can be specified multiple times - values from earlier files are overridden by values in later files [array]
-h, --help Show help [boolean]
-v, --version Show version number [boolean]"
`);
});

it("should show schema help without model list flags", async ({ expect }) => {
await runWrangler("ai models schema --help");
await endEventLoop();

expect(std.out).toMatchInlineSnapshot(`
"wrangler ai models schema <model>

Get model schema

POSITIONALS
model The model to fetch a schema for [string] [required]

GLOBAL FLAGS
-c, --config Path to Wrangler configuration file [string]
--cwd Run as if Wrangler was started in the specified directory instead of the current working directory [string]
-e, --env Environment to use for operations, and for selecting .env and .dev.vars files [string]
--env-file Path to an .env file to load - can be specified multiple times - values from earlier files are overridden by values in later files [array]
-h, --help Show help [boolean]
-v, --version Show version number [boolean]"
`);
});
});

describe("ai commands", () => {
Expand Down Expand Up @@ -111,6 +156,23 @@ describe("ai commands", () => {
});

it("should handle model list", async ({ expect }) => {
mockAISearchRequest();
await runWrangler("ai models list");
expect(std.out).toMatchInlineSnapshot(`
"
⛅️ wrangler x.x.x
──────────────────
┌─┬─┬─┬─┐
│ model │ name │ description │ task │
├─┼─┼─┼─┤
│ 429b9e8b-d99e-44de-91ad-706cf8183658 │ @cloudflare/embeddings_bge_large_en │ │ │
├─┼─┼─┼─┤
│ 7f9a76e1-d120-48dd-a565-101d328bbb02 │ @cloudflare/resnet50 │ │ Image Classification │
└─┴─┴─┴─┘"
`);
});

it("should handle legacy model list", async ({ expect }) => {
mockAISearchRequest();
await runWrangler("ai models");
expect(std.out).toMatchInlineSnapshot(`
Expand All @@ -127,13 +189,59 @@ describe("ai commands", () => {
`);
});

it("should query model list with filters", async ({ expect }) => {
const requests = mockAISearchRequest();
await runWrangler(
'ai models list --search resnet --task "Image Classification" --author cloudflare --source 1 --hide-experimental --json'
);

expect(requests).toHaveLength(1);
const searchParams = new URL(requests[0].url).searchParams;
expect(searchParams.get("per_page")).toBe("50");
expect(searchParams.get("page")).toBe("1");
expect(searchParams.get("search")).toBe("resnet");
expect(searchParams.get("task")).toBe("Image Classification");
expect(searchParams.get("author")).toBe("cloudflare");
expect(searchParams.get("source")).toBe("1");
expect(searchParams.get("hide_experimental")).toBe("true");
expect(std.out).toContain("@cloudflare/resnet50");
});

it("should handle model schema", async ({ expect }) => {
const requests = mockAISchemaRequest();
await runWrangler('ai models schema "@cloudflare/resnet50"');

expect(requests).toHaveLength(1);
const searchParams = new URL(requests[0].url).searchParams;
expect(searchParams.get("model")).toBe("@cloudflare/resnet50");
expect(std.out).toMatchInlineSnapshot(`
"{
"input": {
"type": "object",
"properties": {
"image": {
"type": "string",
"format": "binary"
}
}
},
"output": {
"type": "array",
"items": {
"type": "object"
}
}
}"
`);
});

it("should truncate model description", async ({ expect }) => {
const original = process.stdout.columns;
// Arbitrary fixed value for testing
process.stdout.columns = 186;

mockAIOverflowRequest();
await runWrangler("ai models");
await runWrangler("ai models list");
expect(std.out).toMatchInlineSnapshot(`
"
⛅️ wrangler x.x.x
Expand All @@ -154,7 +262,7 @@ describe("ai commands", () => {
// Arbitrary fixed value for testing
process.stdout.columns = 186;
mockAIPaginatedRequest();
await runWrangler("ai models");
await runWrangler("ai models list");
expect(std.out).toMatchInlineSnapshot(`
"
⛅️ wrangler x.x.x
Expand Down Expand Up @@ -321,10 +429,12 @@ function mockAIListFinetuneRequest() {
}

function mockAISearchRequest() {
const requests: Request[] = [];
msw.use(
http.get(
"*/accounts/:accountId/ai/models/search",
() => {
({ request }) => {
requests.push(request);
return HttpResponse.json(
createFetchResult(
[
Expand Down Expand Up @@ -356,6 +466,43 @@ function mockAISearchRequest() {
{ once: true }
)
);
return requests;
}

function mockAISchemaRequest() {
const requests: Request[] = [];
msw.use(
http.get(
"*/accounts/:accountId/ai/models/schema",
({ request }) => {
requests.push(request);
return HttpResponse.json(
createFetchResult(
{
input: {
type: "object",
properties: {
image: {
type: "string",
format: "binary",
},
},
},
output: {
type: "array",
items: {
type: "object",
},
},
},
true
)
);
},
{ once: true }
)
);
return requests;
}

function mockAIOverflowRequest() {
Expand Down
121 changes: 95 additions & 26 deletions packages/wrangler/src/ai/listCatalog.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,32 @@ import { createCommand } from "../core/create-command";
import { logger } from "../logger";
import { requireAuth } from "../user";
import { listCatalogEntries, truncateDescription } from "./utils";
import type { Config } from "@cloudflare/workers-utils";

type ListModelsArgs = {
author?: string;
hideExperimental?: boolean;
json?: boolean;
search?: string;
source?: number;
task?: string;
};

export const aiModelsCommand = createCommand({
metadata: {
description: "Manage AI models",
status: "stable",
owner: "Product: AI",
},
behaviour: {
printBanner: true,
},
async handler(_args, { config }) {
await listModels({}, config);
},
});

export const aiModelsListCommand = createCommand({
metadata: {
description: "List catalog models",
status: "stable",
Expand All @@ -18,32 +42,77 @@ export const aiModelsCommand = createCommand({
description: "Return output as JSON",
default: false,
},
search: {
type: "string",
description: "Search models by name or description",
},
task: {
type: "string",
description: "Filter by task name",
},
author: {
type: "string",
description: "Filter by author",
},
source: {
type: "number",
description: "Filter by source ID",
},
"hide-experimental": {
type: "boolean",
description: "Hide experimental models",
default: false,
},
},
async handler({ json }, { config }) {
const accountId = await requireAuth(config);
const entries = await listCatalogEntries(config, accountId);

if (json) {
logger.log(JSON.stringify(entries, null, 2));
} else {
if (entries.length === 0) {
logger.log(`No models found.`);
} else {
logger.table(
entries.map((entry) => ({
model: entry.id,
name: entry.name,
description: truncateDescription(
entry.description,
entry.id.length +
entry.name.length +
(entry.task ? entry.task.name.length : 0) +
10
),
task: entry.task ? entry.task.name : "",
}))
);
}
}
async handler(
{ author, hideExperimental, json, search, source, task },
{ config }
) {
await listModels(
{ author, hideExperimental, json, search, source, task },
config
);
},
});

async function listModels(
{
author,
hideExperimental,
json = false,
search,
source,
task,
}: ListModelsArgs,
config: Config
) {
const accountId = await requireAuth(config);
const entries = await listCatalogEntries(config, accountId, {
author,
hideExperimental,
search,
source,
task,
});

if (json) {
logger.json(entries);
} else if (entries.length === 0) {
logger.log(`No models found.`);
} else {
logger.table(
entries.map((entry) => ({
model: entry.id,
name: entry.name,
description: truncateDescription(
entry.description,
entry.id.length +
entry.name.length +
(entry.task ? entry.task.name.length : 0) +
10
),
task: entry.task ? entry.task.name : "",
}))
);
}
}
Loading
Loading