Skip to content

Commit 9dcc581

Browse files
committed
fix: restore aisdk structured output options
1 parent 3a167a5 commit 9dcc581

3 files changed

Lines changed: 184 additions & 1 deletion

File tree

packages/core/lib/v3/external_clients/aisdk.ts

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,18 @@ import { CreateChatCompletionOptions, LLMClient } from "../llm/LLMClient.js";
1414
import { AvailableModel } from "../types/public/index.js";
1515
import { ChatCompletion } from "openai/resources";
1616

17+
function getGenerationSettings(
18+
options: CreateChatCompletionOptions["options"],
19+
) {
20+
return {
21+
temperature: options.temperature,
22+
maxOutputTokens: options.maxOutputTokens,
23+
topP: options.top_p,
24+
frequencyPenalty: options.frequency_penalty,
25+
presencePenalty: options.presence_penalty,
26+
};
27+
}
28+
1729
export class AISdkClient extends LLMClient {
1830
public type = "aisdk" as const;
1931
private model: LanguageModelV2;
@@ -86,6 +98,7 @@ export class AISdkClient extends LLMClient {
8698
model: this.model,
8799
messages: formattedMessages,
88100
schema: options.response_model.schema,
101+
...getGenerationSettings(options),
89102
});
90103

91104
return {
@@ -102,7 +115,7 @@ export class AISdkClient extends LLMClient {
102115

103116
const tools: Record<string, Tool> = {};
104117

105-
for (const rawTool of options.tools) {
118+
for (const rawTool of options.tools ?? []) {
106119
tools[rawTool.name] = {
107120
description: rawTool.description,
108121
inputSchema: rawTool.parameters,
@@ -113,6 +126,7 @@ export class AISdkClient extends LLMClient {
113126
model: this.model,
114127
messages: formattedMessages,
115128
tools,
129+
...getGenerationSettings(options),
116130
});
117131

118132
return {

packages/core/lib/v3/llm/aisdk.ts

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,11 +174,21 @@ export class AISdkClient extends LLMClient {
174174
: {}),
175175
};
176176
break;
177+
case "azure":
178+
providerOptions.azure = {
179+
strictJsonSchema: true,
180+
};
181+
break;
177182
case "google":
178183
providerOptions.google = {
179184
structuredOutputs: true,
180185
};
181186
break;
187+
case "vertex":
188+
providerOptions.vertex = {
189+
structuredOutputs: true,
190+
};
191+
break;
182192
case "anthropic":
183193
providerOptions.anthropic = {
184194
structuredOutputMode: "auto",
@@ -189,6 +199,11 @@ export class AISdkClient extends LLMClient {
189199
structuredOutputs: true,
190200
};
191201
break;
202+
case "cerebras":
203+
providerOptions.cerebras = {
204+
strictJsonSchema: true,
205+
};
206+
break;
192207
case "mistral":
193208
providerOptions.mistral = {
194209
structuredOutputs: true,
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
import type { LanguageModelV2 } from "@ai-sdk/provider";
2+
import { generateObject, generateText } from "ai";
3+
import { z } from "zod";
4+
import { beforeEach, describe, expect, it, vi } from "vitest";
5+
import { AISdkClient as ExternalAISdkClient } from "../../lib/v3/external_clients/aisdk.js";
6+
import { AISdkClient as LlmAISdkClient } from "../../lib/v3/llm/aisdk.js";
7+
8+
vi.mock("ai", async () => {
9+
const actual = await vi.importActual<typeof import("ai")>("ai");
10+
return {
11+
...actual,
12+
generateObject: vi.fn(),
13+
generateText: vi.fn(),
14+
};
15+
});
16+
17+
const mockGenerateObject = vi.mocked(generateObject);
18+
const mockGenerateText = vi.mocked(generateText);
19+
20+
function createModel(modelId = "openai/gpt-4.1") {
21+
return {
22+
modelId,
23+
specificationVersion: "v2",
24+
} as unknown as LanguageModelV2;
25+
}
26+
27+
describe("AISdk clients", () => {
28+
beforeEach(() => {
29+
mockGenerateObject.mockReset();
30+
mockGenerateText.mockReset();
31+
});
32+
33+
it("external AISdkClient forwards generation settings to generateObject", async () => {
34+
mockGenerateObject.mockResolvedValue({
35+
object: { ok: true },
36+
usage: {
37+
inputTokens: 1,
38+
outputTokens: 2,
39+
reasoningTokens: 0,
40+
cachedInputTokens: 0,
41+
totalTokens: 3,
42+
},
43+
} as never);
44+
45+
const client = new ExternalAISdkClient({
46+
model: createModel(),
47+
});
48+
49+
await client.createChatCompletion({
50+
options: {
51+
messages: [{ role: "user", content: "hello" }],
52+
response_model: {
53+
name: "test",
54+
schema: z.object({ ok: z.boolean() }),
55+
},
56+
temperature: 0.2,
57+
maxOutputTokens: 50,
58+
top_p: 0.8,
59+
frequency_penalty: 0.1,
60+
presence_penalty: 0.3,
61+
},
62+
logger: vi.fn(),
63+
});
64+
65+
expect(mockGenerateObject).toHaveBeenCalledWith(
66+
expect.objectContaining({
67+
temperature: 0.2,
68+
maxOutputTokens: 50,
69+
topP: 0.8,
70+
frequencyPenalty: 0.1,
71+
presencePenalty: 0.3,
72+
}),
73+
);
74+
});
75+
76+
it("external AISdkClient handles missing tools without throwing", async () => {
77+
mockGenerateText.mockResolvedValue({
78+
text: "ok",
79+
usage: {
80+
inputTokens: 1,
81+
outputTokens: 1,
82+
reasoningTokens: 0,
83+
cachedInputTokens: 0,
84+
totalTokens: 2,
85+
},
86+
} as never);
87+
88+
const client = new ExternalAISdkClient({
89+
model: createModel(),
90+
});
91+
92+
await expect(
93+
client.createChatCompletion({
94+
options: {
95+
messages: [{ role: "user", content: "hello" }],
96+
},
97+
logger: vi.fn(),
98+
}),
99+
).resolves.toEqual(
100+
expect.objectContaining({
101+
data: "ok",
102+
}),
103+
);
104+
105+
expect(mockGenerateText).toHaveBeenCalledWith(
106+
expect.objectContaining({
107+
tools: {},
108+
}),
109+
);
110+
});
111+
112+
it.each([
113+
["azure", { azure: { strictJsonSchema: true } }],
114+
["vertex", { vertex: { structuredOutputs: true } }],
115+
["cerebras", { cerebras: { strictJsonSchema: true } }],
116+
])(
117+
"main AISdkClient preserves structured provider options for %s",
118+
async (providerName, providerOptions) => {
119+
mockGenerateObject.mockResolvedValue({
120+
object: { ok: true },
121+
usage: {
122+
inputTokens: 1,
123+
outputTokens: 2,
124+
reasoningTokens: 0,
125+
cachedInputTokens: 0,
126+
totalTokens: 3,
127+
},
128+
} as never);
129+
130+
const client = new LlmAISdkClient({
131+
model: createModel(`${providerName}/model`),
132+
providerName,
133+
logger: vi.fn(),
134+
});
135+
136+
await client.createChatCompletion({
137+
options: {
138+
messages: [{ role: "user", content: "hello" }],
139+
response_model: {
140+
name: "test",
141+
schema: z.object({ ok: z.boolean() }),
142+
},
143+
},
144+
logger: vi.fn(),
145+
});
146+
147+
expect(mockGenerateObject).toHaveBeenCalledWith(
148+
expect.objectContaining({
149+
providerOptions,
150+
}),
151+
);
152+
},
153+
);
154+
});

0 commit comments

Comments
 (0)