
Commit cfee788

feat(session): add native OpenAI runtime opt-in
1 parent b93cb0e commit cfee788

3 files changed: 241 additions & 79 deletions

packages/opencode/src/session/llm-native.ts

Lines changed: 7 additions & 3 deletions
@@ -11,6 +11,8 @@ type ToolInput = {
 
 export type RequestInput = {
   readonly model: Provider.Model
+  readonly apiKey?: string
+  readonly baseURL?: string
   readonly system?: readonly string[]
   readonly messages: readonly ModelMessage[]
   readonly tools?: Record<string, ToolInput>
@@ -154,14 +156,16 @@ const baseURL = (model: Provider.Model) => {
   throw new Error(`Native LLM request adapter requires a base URL for ${model.providerID}/${model.id}`)
 }
 
-export const model = (model: Provider.Model, headers?: Record<string, string>) => {
+export const model = (input: Provider.Model | RequestInput, headers?: Record<string, string>) => {
+  const model = "model" in input ? input.model : input
   const route = ROUTE[model.api.npm]
   if (!route) throw new Error(`Native LLM request adapter does not support provider package ${model.api.npm}`)
   return LLM.model({
     id: model.api.id,
     provider: model.providerID,
     route,
-    baseURL: baseURL(model),
+    baseURL: "model" in input && input.baseURL ? input.baseURL : baseURL(model),
+    apiKey: "model" in input ? input.apiKey : undefined,
     headers: Object.keys({ ...model.headers, ...headers }).length === 0 ? undefined : { ...model.headers, ...headers },
     limits: {
       context: model.limit.context,
@@ -173,7 +177,7 @@ export const model = (model: Provider.Model, headers?: Record<string, string>) =
 export const request = (input: RequestInput) => {
   const converted = messages(input.messages)
   return LLM.request({
-    model: model(input.model, input.headers),
+    model: model(input, input.headers),
     system: [...(input.system ?? []).map(SystemPart.make), ...converted.system],
     messages: converted.messages,
     tools: tools(input.tools),
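The new overload hinges on an `in`-operator narrowing: a `RequestInput` always declares a `model` property, while a bare `Provider.Model` never does, so `"model" in input` discriminates the union. A reduced sketch of the pattern, with `Provider.Model` trimmed to placeholder fields for illustration:

```ts
// Illustrative only: Model stands in for Provider.Model.
type Model = { providerID: string; id: string }
type RequestInput = { model: Model; apiKey?: string; baseURL?: string }

const resolve = (input: Model | RequestInput) => {
  // Only RequestInput declares `model`, so `in` narrows the union.
  const model = "model" in input ? input.model : input
  // Credentials and baseURL exist only on the RequestInput side.
  const apiKey = "model" in input ? input.apiKey : undefined
  return { model, apiKey }
}
```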

packages/opencode/src/session/llm.ts

Lines changed: 119 additions & 75 deletions
@@ -4,6 +4,7 @@ import { Context, Effect, Layer, Record } from "effect"
 import * as Stream from "effect/Stream"
 import { streamText, wrapLanguageModel, type ModelMessage, type Tool, tool, jsonSchema } from "ai"
 import type { LLMEvent } from "@opencode-ai/llm"
+import { LLMClient, RequestExecutor } from "@opencode-ai/llm/route"
 import { mergeDeep } from "remeda"
 import { GitLabWorkflowLanguageModel } from "gitlab-ai-provider"
 import { ProviderTransform } from "@/provider/transform"
@@ -20,12 +21,12 @@ import { Bus } from "@/bus"
 import { Wildcard } from "@/util/wildcard"
 import { SessionID } from "@/session/schema"
 import { Auth } from "@/auth"
-import { Installation } from "@/installation"
 import { InstallationVersion } from "@opencode-ai/core/installation/version"
 import { EffectBridge } from "@/effect/bridge"
 import * as Option from "effect/Option"
 import * as OtelTracer from "@effect/opentelemetry/Tracer"
 import { LLMAISDK } from "./llm-ai-sdk"
+import { LLMNative } from "./llm-native"
 
 const log = Log.create({ service: "llm" })
 export const OUTPUT_TOKEN_MAX = ProviderTransform.OUTPUT_TOKEN_MAX
@@ -34,6 +35,8 @@ export const OUTPUT_TOKEN_MAX = ProviderTransform.OUTPUT_TOKEN_MAX
 const mergeOptions = (target: Record<string, any>, source: Record<string, any> | undefined): Record<string, any> =>
   mergeDeep(target, source ?? {}) as Record<string, any>
 
+const runtime = () => (process.env.OPENCODE_LLM_RUNTIME === "native" ? "native" : "ai-sdk")
+
 export type StreamInput = {
   user: MessageV2.User
   sessionID: string
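The opt-in is deliberately strict: only the exact value `native` selects the new runtime, so any other value (or an unset variable) keeps the default AI SDK path. A minimal usage sketch mirroring the helper above:

```ts
// Mirrors the runtime() helper from the diff; illustrative usage only.
const runtime = () => (process.env.OPENCODE_LLM_RUNTIME === "native" ? "native" : "ai-sdk")

process.env.OPENCODE_LLM_RUNTIME = "native"
console.log(runtime()) // "native"

delete process.env.OPENCODE_LLM_RUNTIME
console.log(runtime()) // "ai-sdk" — the default
```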
@@ -333,86 +336,123 @@ const live: Layer.Layer<
         ? (yield* InstanceState.context).project.id
         : undefined
 
-      return streamText({
-        onError(error) {
-          l.error("stream error", {
-            error,
-          })
-        },
-        async experimental_repairToolCall(failed) {
-          const lower = failed.toolCall.toolName.toLowerCase()
-          if (lower !== failed.toolCall.toolName && sortedTools[lower]) {
-            l.info("repairing tool call", {
-              tool: failed.toolCall.toolName,
-              repaired: lower,
-            })
-            return {
-              ...failed.toolCall,
-              toolName: lower,
+      const requestHeaders = {
+        ...(input.model.providerID.startsWith("opencode")
+          ? {
+              ...(opencodeProjectID ? { "x-opencode-project": opencodeProjectID } : {}),
+              "x-opencode-session": input.sessionID,
+              "x-opencode-request": input.user.id,
+              "x-opencode-client": Flag.OPENCODE_CLIENT,
+              "User-Agent": `opencode/${InstallationVersion}`,
             }
-          }
-          return {
-            ...failed.toolCall,
-            input: JSON.stringify({
-              tool: failed.toolCall.toolName,
-              error: failed.error.message,
+          : {
+              "x-session-affinity": input.sessionID,
+              ...(input.parentSessionID ? { "x-parent-session-id": input.parentSessionID } : {}),
+              "User-Agent": `opencode/${InstallationVersion}`,
             }),
-          toolName: "invalid",
-          }
-        },
-        temperature: params.temperature,
-        topP: params.topP,
-        topK: params.topK,
-        providerOptions: ProviderTransform.providerOptions(input.model, params.options),
-        activeTools: Object.keys(sortedTools).filter((x) => x !== "invalid"),
-        tools: sortedTools,
-        toolChoice: input.toolChoice,
-        maxOutputTokens: params.maxOutputTokens,
-        abortSignal: input.abort,
-        headers: {
-          ...(input.model.providerID.startsWith("opencode")
-            ? {
-                "x-opencode-project": opencodeProjectID,
-                "x-opencode-session": input.sessionID,
-                "x-opencode-request": input.user.id,
-                "x-opencode-client": Flag.OPENCODE_CLIENT,
-                "User-Agent": `opencode/${InstallationVersion}`,
+        ...input.model.headers,
+        ...headers,
+      }
+
+      if (runtime() === "native") {
+        if (input.model.providerID !== "openai" || input.model.api.npm !== "@ai-sdk/openai") {
+          return yield* Effect.fail(new Error("Native LLM runtime currently only supports OpenAI models"))
+        }
+        if (Object.keys(sortedTools).length > 0) {
+          return yield* Effect.fail(new Error("Native LLM runtime does not support tools yet"))
+        }
+        const apiKey =
+          info?.type === "api" ? info.key : typeof item.options.apiKey === "string" ? item.options.apiKey : undefined
+        if (!apiKey) return yield* Effect.fail(new Error("Native LLM runtime requires API key auth for OpenAI"))
+        const baseURL = typeof item.options.baseURL === "string" ? item.options.baseURL : undefined
+        return {
+          type: "native" as const,
+          stream: LLMClient.stream(
+            LLMNative.request({
+              model: input.model,
+              apiKey,
+              baseURL,
+              system: isOpenaiOauth ? system : [],
+              messages: ProviderTransform.message(messages, input.model, options),
+              toolChoice: input.toolChoice,
+              temperature: params.temperature,
+              topP: params.topP,
+              topK: params.topK,
+              maxOutputTokens: params.maxOutputTokens,
+              providerOptions: ProviderTransform.providerOptions(input.model, params.options),
+              headers: requestHeaders,
+            }),
+          ).pipe(Stream.provide(LLMClient.layer), Stream.provide(RequestExecutor.defaultLayer)),
+        }
+      }
+
+      return {
+        type: "ai-sdk" as const,
+        result: streamText({
+          onError(error) {
+            l.error("stream error", {
+              error,
+            })
+          },
+          async experimental_repairToolCall(failed) {
+            const lower = failed.toolCall.toolName.toLowerCase()
+            if (lower !== failed.toolCall.toolName && sortedTools[lower]) {
+              l.info("repairing tool call", {
+                tool: failed.toolCall.toolName,
+                repaired: lower,
+              })
+              return {
+                ...failed.toolCall,
+                toolName: lower,
               }
-            : {
-                "x-session-affinity": input.sessionID,
-                ...(input.parentSessionID ? { "x-parent-session-id": input.parentSessionID } : {}),
-                "User-Agent": `opencode/${InstallationVersion}`,
+            }
+            return {
+              ...failed.toolCall,
+              input: JSON.stringify({
+                tool: failed.toolCall.toolName,
+                error: failed.error.message,
               }),
-          ...input.model.headers,
-          ...headers,
-        },
-        maxRetries: input.retries ?? 0,
-        messages,
-        model: wrapLanguageModel({
-          model: language,
-          middleware: [
-            {
-              specificationVersion: "v3" as const,
-              async transformParams(args) {
-                if (args.type === "stream") {
-                  // @ts-expect-error
-                  args.params.prompt = ProviderTransform.message(args.params.prompt, input.model, options)
-                }
-                return args.params
+              toolName: "invalid",
+            }
+          },
+          temperature: params.temperature,
+          topP: params.topP,
+          topK: params.topK,
+          providerOptions: ProviderTransform.providerOptions(input.model, params.options),
+          activeTools: Object.keys(sortedTools).filter((x) => x !== "invalid"),
+          tools: sortedTools,
+          toolChoice: input.toolChoice,
+          maxOutputTokens: params.maxOutputTokens,
+          abortSignal: input.abort,
+          headers: requestHeaders,
+          maxRetries: input.retries ?? 0,
+          messages,
+          model: wrapLanguageModel({
+            model: language,
+            middleware: [
+              {
+                specificationVersion: "v3" as const,
+                async transformParams(args) {
+                  if (args.type === "stream") {
+                    // @ts-expect-error
+                    args.params.prompt = ProviderTransform.message(args.params.prompt, input.model, options)
+                  }
+                  return args.params
+                },
               },
+            ],
+          }),
+          experimental_telemetry: {
+            isEnabled: cfg.experimental?.openTelemetry,
+            functionId: "session.llm",
+            tracer: telemetryTracer,
+            metadata: {
+              userId: cfg.username ?? "unknown",
+              sessionId: input.sessionID,
             },
-          ],
-        }),
-        experimental_telemetry: {
-          isEnabled: cfg.experimental?.openTelemetry,
-          functionId: "session.llm",
-          tracer: telemetryTracer,
-          metadata: {
-            userId: cfg.username ?? "unknown",
-            sessionId: input.sessionID,
           },
-        },
-      })
+        }),
+      }
     })
 
     const stream: Interface["stream"] = (input) =>
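With this change the generator returns a tagged union rather than the bare `streamText` result, which is what lets the consumer below branch once on `type` and skip the AI SDK adapter entirely on the native path. A reduced sketch of that shape, with placeholders standing in for the real stream and ai-sdk types:

```ts
// Placeholder types: the real ones come from @opencode-ai/llm and the "ai" package.
type RunResult =
  | { readonly type: "native"; readonly stream: unknown }
  | { readonly type: "ai-sdk"; readonly result: { fullStream: AsyncIterable<unknown> } }

const dispatch = (result: RunResult) => {
  switch (result.type) {
    case "native":
      return result.stream // already an event stream in the shape the session consumes
    case "ai-sdk":
      return result.result.fullStream // still needs the LLMAISDK adapter below
  }
}
```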
@@ -426,8 +466,12 @@ const live: Layer.Layer<
 
       const result = yield* run({ ...input, abort: ctrl.signal })
 
+      if (result.type === "native") return result.stream
+
       const state = LLMAISDK.adapterState()
-      return Stream.fromAsyncIterable(result.fullStream, (e) => (e instanceof Error ? e : new Error(String(e)))).pipe(
+      return Stream.fromAsyncIterable(result.result.fullStream, (e) =>
+        e instanceof Error ? e : new Error(String(e)),
+      ).pipe(
         Stream.mapEffect((event) => LLMAISDK.toLLMEvents(state, event)),
         Stream.flatMap((events) => Stream.fromIterable(events)),
       )
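For context, the ai-sdk bridge pattern in isolation: effect's `Stream.fromAsyncIterable` takes the iterable plus a callback that normalizes whatever the iterable throws into the stream's error type, which is exactly the `e instanceof Error ? e : new Error(String(e))` mapping above. A self-contained sketch using a toy iterable in place of `fullStream`:

```ts
import * as Stream from "effect/Stream"
import * as Effect from "effect/Effect"

// Toy AsyncIterable standing in for streamText's fullStream.
async function* events() {
  yield { type: "text-delta" as const, text: "hello" }
  yield { type: "finish" as const }
}

// Same pattern as the diff: normalize thrown values into Error.
const stream = Stream.fromAsyncIterable(events(), (e) =>
  e instanceof Error ? e : new Error(String(e)),
)

// Drain once to demonstrate; the session pipes through mapEffect/flatMap instead.
Effect.runPromise(Stream.runCollect(stream)).then(console.log)
```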
