Skip to content

Commit 080df08

Browse files
bghirabghira
authored andcommitted
add missing model catalogue search parameters (search, task, author, source) (#13901)
Co-authored-by: bghira <bghira@users.github.com>
1 parent d4c7ae5 commit 080df08

7 files changed

Lines changed: 336 additions & 40 deletions

File tree

.changeset/ai-models-schema.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
---
2+
"wrangler": minor
3+
---
4+
5+
Add `wrangler ai models schema` command for fetching model schemas
6+
7+
You can now run `wrangler ai models schema <model>` to fetch the input and output schema for a Workers AI model from the public model catalog schema endpoint.

.changeset/public-model-search.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
---
2+
"wrangler": minor
3+
---
4+
5+
Add `wrangler ai models list` command for querying the Workers AI model catalog
6+
7+
`wrangler ai models list` accepts `--search`, `--task`, `--author`, `--source`, and `--hide-experimental`, matching the public model catalog search endpoint.

packages/wrangler/src/__tests__/ai.test.ts

Lines changed: 152 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ describe("ai help", () => {
2323
🤖 Manage AI models
2424
2525
COMMANDS
26-
wrangler ai models List catalog models
26+
wrangler ai models Manage AI models
2727
wrangler ai finetune Interact with finetune files
2828
2929
GLOBAL FLAGS
@@ -55,7 +55,7 @@ describe("ai help", () => {
5555
🤖 Manage AI models
5656
5757
COMMANDS
58-
wrangler ai models List catalog models
58+
wrangler ai models Manage AI models
5959
wrangler ai finetune Interact with finetune files
6060
6161
GLOBAL FLAGS
@@ -67,6 +67,51 @@ describe("ai help", () => {
6767
-v, --version Show version number [boolean]"
6868
`);
6969
});
70+
71+
it("should show models help", async ({ expect }) => {
72+
await runWrangler("ai models --help");
73+
await endEventLoop();
74+
75+
expect(std.out).toMatchInlineSnapshot(`
76+
"wrangler ai models
77+
78+
Manage AI models
79+
80+
COMMANDS
81+
wrangler ai models list List catalog models
82+
wrangler ai models schema <model> Get model schema
83+
84+
GLOBAL FLAGS
85+
-c, --config Path to Wrangler configuration file [string]
86+
--cwd Run as if Wrangler was started in the specified directory instead of the current working directory [string]
87+
-e, --env Environment to use for operations, and for selecting .env and .dev.vars files [string]
88+
--env-file Path to an .env file to load - can be specified multiple times - values from earlier files are overridden by values in later files [array]
89+
-h, --help Show help [boolean]
90+
-v, --version Show version number [boolean]"
91+
`);
92+
});
93+
94+
it("should show schema help without model list flags", async ({ expect }) => {
95+
await runWrangler("ai models schema --help");
96+
await endEventLoop();
97+
98+
expect(std.out).toMatchInlineSnapshot(`
99+
"wrangler ai models schema <model>
100+
101+
Get model schema
102+
103+
POSITIONALS
104+
model The model to fetch a schema for [string] [required]
105+
106+
GLOBAL FLAGS
107+
-c, --config Path to Wrangler configuration file [string]
108+
--cwd Run as if Wrangler was started in the specified directory instead of the current working directory [string]
109+
-e, --env Environment to use for operations, and for selecting .env and .dev.vars files [string]
110+
--env-file Path to an .env file to load - can be specified multiple times - values from earlier files are overridden by values in later files [array]
111+
-h, --help Show help [boolean]
112+
-v, --version Show version number [boolean]"
113+
`);
114+
});
70115
});
71116

72117
describe("ai commands", () => {
@@ -111,6 +156,23 @@ describe("ai commands", () => {
111156
});
112157

113158
it("should handle model list", async ({ expect }) => {
159+
mockAISearchRequest();
160+
await runWrangler("ai models list");
161+
expect(std.out).toMatchInlineSnapshot(`
162+
"
163+
⛅️ wrangler x.x.x
164+
──────────────────
165+
┌─┬─┬─┬─┐
166+
│ model │ name │ description │ task │
167+
├─┼─┼─┼─┤
168+
│ 429b9e8b-d99e-44de-91ad-706cf8183658 │ @cloudflare/embeddings_bge_large_en │ │ │
169+
├─┼─┼─┼─┤
170+
│ 7f9a76e1-d120-48dd-a565-101d328bbb02 │ @cloudflare/resnet50 │ │ Image Classification │
171+
└─┴─┴─┴─┘"
172+
`);
173+
});
174+
175+
it("should handle legacy model list", async ({ expect }) => {
114176
mockAISearchRequest();
115177
await runWrangler("ai models");
116178
expect(std.out).toMatchInlineSnapshot(`
@@ -127,13 +189,59 @@ describe("ai commands", () => {
127189
`);
128190
});
129191

192+
it("should query model list with filters", async ({ expect }) => {
193+
const requests = mockAISearchRequest();
194+
await runWrangler(
195+
'ai models list --search resnet --task "Image Classification" --author cloudflare --source 1 --hide-experimental --json'
196+
);
197+
198+
expect(requests).toHaveLength(1);
199+
const searchParams = new URL(requests[0].url).searchParams;
200+
expect(searchParams.get("per_page")).toBe("50");
201+
expect(searchParams.get("page")).toBe("1");
202+
expect(searchParams.get("search")).toBe("resnet");
203+
expect(searchParams.get("task")).toBe("Image Classification");
204+
expect(searchParams.get("author")).toBe("cloudflare");
205+
expect(searchParams.get("source")).toBe("1");
206+
expect(searchParams.get("hide_experimental")).toBe("true");
207+
expect(std.out).toContain("@cloudflare/resnet50");
208+
});
209+
210+
it("should handle model schema", async ({ expect }) => {
211+
const requests = mockAISchemaRequest();
212+
await runWrangler('ai models schema "@cloudflare/resnet50"');
213+
214+
expect(requests).toHaveLength(1);
215+
const searchParams = new URL(requests[0].url).searchParams;
216+
expect(searchParams.get("model")).toBe("@cloudflare/resnet50");
217+
expect(std.out).toMatchInlineSnapshot(`
218+
"{
219+
"input": {
220+
"type": "object",
221+
"properties": {
222+
"image": {
223+
"type": "string",
224+
"format": "binary"
225+
}
226+
}
227+
},
228+
"output": {
229+
"type": "array",
230+
"items": {
231+
"type": "object"
232+
}
233+
}
234+
}"
235+
`);
236+
});
237+
130238
it("should truncate model description", async ({ expect }) => {
131239
const original = process.stdout.columns;
132240
// Arbitrary fixed value for testing
133241
process.stdout.columns = 186;
134242

135243
mockAIOverflowRequest();
136-
await runWrangler("ai models");
244+
await runWrangler("ai models list");
137245
expect(std.out).toMatchInlineSnapshot(`
138246
"
139247
⛅️ wrangler x.x.x
@@ -154,7 +262,7 @@ describe("ai commands", () => {
154262
// Arbitrary fixed value for testing
155263
process.stdout.columns = 186;
156264
mockAIPaginatedRequest();
157-
await runWrangler("ai models");
265+
await runWrangler("ai models list");
158266
expect(std.out).toMatchInlineSnapshot(`
159267
"
160268
⛅️ wrangler x.x.x
@@ -321,10 +429,12 @@ function mockAIListFinetuneRequest() {
321429
}
322430

323431
function mockAISearchRequest() {
432+
const requests: Request[] = [];
324433
msw.use(
325434
http.get(
326435
"*/accounts/:accountId/ai/models/search",
327-
() => {
436+
({ request }) => {
437+
requests.push(request);
328438
return HttpResponse.json(
329439
createFetchResult(
330440
[
@@ -356,6 +466,43 @@ function mockAISearchRequest() {
356466
{ once: true }
357467
)
358468
);
469+
return requests;
470+
}
471+
472+
function mockAISchemaRequest() {
473+
const requests: Request[] = [];
474+
msw.use(
475+
http.get(
476+
"*/accounts/:accountId/ai/models/schema",
477+
({ request }) => {
478+
requests.push(request);
479+
return HttpResponse.json(
480+
createFetchResult(
481+
{
482+
input: {
483+
type: "object",
484+
properties: {
485+
image: {
486+
type: "string",
487+
format: "binary",
488+
},
489+
},
490+
},
491+
output: {
492+
type: "array",
493+
items: {
494+
type: "object",
495+
},
496+
},
497+
},
498+
true
499+
)
500+
);
501+
},
502+
{ once: true }
503+
)
504+
);
505+
return requests;
359506
}
360507

361508
function mockAIOverflowRequest() {

packages/wrangler/src/ai/listCatalog.ts

Lines changed: 95 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,32 @@ import { createCommand } from "../core/create-command";
22
import { logger } from "../logger";
33
import { requireAuth } from "../user";
44
import { listCatalogEntries, truncateDescription } from "./utils";
5+
import type { Config } from "@cloudflare/workers-utils";
6+
7+
type ListModelsArgs = {
8+
author?: string;
9+
hideExperimental?: boolean;
10+
json?: boolean;
11+
search?: string;
12+
source?: number;
13+
task?: string;
14+
};
515

616
export const aiModelsCommand = createCommand({
17+
metadata: {
18+
description: "Manage AI models",
19+
status: "stable",
20+
owner: "Product: AI",
21+
},
22+
behaviour: {
23+
printBanner: true,
24+
},
25+
async handler(_args, { config }) {
26+
await listModels({}, config);
27+
},
28+
});
29+
30+
export const aiModelsListCommand = createCommand({
731
metadata: {
832
description: "List catalog models",
933
status: "stable",
@@ -18,32 +42,77 @@ export const aiModelsCommand = createCommand({
1842
description: "Return output as JSON",
1943
default: false,
2044
},
45+
search: {
46+
type: "string",
47+
description: "Search models by name or description",
48+
},
49+
task: {
50+
type: "string",
51+
description: "Filter by task name",
52+
},
53+
author: {
54+
type: "string",
55+
description: "Filter by author",
56+
},
57+
source: {
58+
type: "number",
59+
description: "Filter by source ID",
60+
},
61+
"hide-experimental": {
62+
type: "boolean",
63+
description: "Hide experimental models",
64+
default: false,
65+
},
2166
},
22-
async handler({ json }, { config }) {
23-
const accountId = await requireAuth(config);
24-
const entries = await listCatalogEntries(config, accountId);
25-
26-
if (json) {
27-
logger.log(JSON.stringify(entries, null, 2));
28-
} else {
29-
if (entries.length === 0) {
30-
logger.log(`No models found.`);
31-
} else {
32-
logger.table(
33-
entries.map((entry) => ({
34-
model: entry.id,
35-
name: entry.name,
36-
description: truncateDescription(
37-
entry.description,
38-
entry.id.length +
39-
entry.name.length +
40-
(entry.task ? entry.task.name.length : 0) +
41-
10
42-
),
43-
task: entry.task ? entry.task.name : "",
44-
}))
45-
);
46-
}
47-
}
67+
async handler(
68+
{ author, hideExperimental, json, search, source, task },
69+
{ config }
70+
) {
71+
await listModels(
72+
{ author, hideExperimental, json, search, source, task },
73+
config
74+
);
4875
},
4976
});
77+
78+
async function listModels(
79+
{
80+
author,
81+
hideExperimental,
82+
json = false,
83+
search,
84+
source,
85+
task,
86+
}: ListModelsArgs,
87+
config: Config
88+
) {
89+
const accountId = await requireAuth(config);
90+
const entries = await listCatalogEntries(config, accountId, {
91+
author,
92+
hideExperimental,
93+
search,
94+
source,
95+
task,
96+
});
97+
98+
if (json) {
99+
logger.json(entries);
100+
} else if (entries.length === 0) {
101+
logger.log(`No models found.`);
102+
} else {
103+
logger.table(
104+
entries.map((entry) => ({
105+
model: entry.id,
106+
name: entry.name,
107+
description: truncateDescription(
108+
entry.description,
109+
entry.id.length +
110+
entry.name.length +
111+
(entry.task ? entry.task.name.length : 0) +
112+
10
113+
),
114+
task: entry.task ? entry.task.name : "",
115+
}))
116+
);
117+
}
118+
}

0 commit comments

Comments
 (0)