Skip to content

Commit abf7836

Browse files
fix(providers): adapt assistant history for pi-ai (#17)
Signed-off-by: Sun-sunshine06 <Sun-sunshine06@users.noreply.github.com> Co-authored-by: Sun-sunshine06 <Sun-sunshine06@users.noreply.github.com>
1 parent 3af16c2 commit abf7836

2 files changed

Lines changed: 216 additions & 11 deletions

File tree

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
import type { ChatMessage, ModelRef } from '@open-codesign/shared';
2+
import { afterEach, describe, expect, it, vi } from 'vitest';
3+
4+
const getModelMock = vi.fn();
5+
const completeSimpleMock = vi.fn();
6+
7+
vi.mock('@mariozechner/pi-ai', () => ({
8+
getModel: (...args: unknown[]) => getModelMock(...args),
9+
completeSimple: (...args: unknown[]) => completeSimpleMock(...args),
10+
}));
11+
12+
import { complete } from './index';
13+
14+
const MODEL: ModelRef = { provider: 'openai', modelId: 'gpt-4o' };
15+
16+
afterEach(() => {
17+
getModelMock.mockReset();
18+
completeSimpleMock.mockReset();
19+
});
20+
21+
describe('complete', () => {
22+
it('adapts shared chat history into pi-ai context for follow-up turns', async () => {
23+
getModelMock.mockReturnValue({
24+
id: 'gpt-4o',
25+
api: 'openai-completions',
26+
provider: 'openai',
27+
});
28+
completeSimpleMock.mockImplementationOnce(async (_model, context) => {
29+
expect(context.systemPrompt).toBe('You are open-codesign.');
30+
expect(context.messages).toEqual([
31+
{
32+
role: 'user',
33+
content: '介绍一下你自己',
34+
timestamp: 2,
35+
},
36+
{
37+
role: 'assistant',
38+
content: [{ type: 'text', text: '我是一个设计助手。' }],
39+
api: 'openai-completions',
40+
provider: 'openai',
41+
model: 'gpt-4o',
42+
usage: {
43+
input: 0,
44+
output: 0,
45+
cacheRead: 0,
46+
cacheWrite: 0,
47+
totalTokens: 0,
48+
cost: {
49+
input: 0,
50+
output: 0,
51+
cacheRead: 0,
52+
cacheWrite: 0,
53+
total: 0,
54+
},
55+
},
56+
stopReason: 'stop',
57+
timestamp: 3,
58+
},
59+
{
60+
role: 'user',
61+
content: '你可以干什么',
62+
timestamp: 4,
63+
},
64+
]);
65+
66+
return {
67+
role: 'assistant',
68+
content: [{ type: 'text', text: '我可以帮你生成设计稿。' }],
69+
api: 'openai-completions',
70+
provider: 'openai',
71+
model: 'gpt-4o',
72+
usage: {
73+
input: 12,
74+
output: 34,
75+
cacheRead: 0,
76+
cacheWrite: 0,
77+
totalTokens: 46,
78+
cost: {
79+
input: 0,
80+
output: 0,
81+
cacheRead: 0,
82+
cacheWrite: 0,
83+
total: 0.01,
84+
},
85+
},
86+
stopReason: 'stop',
87+
timestamp: Date.now(),
88+
};
89+
});
90+
91+
const messages: ChatMessage[] = [
92+
{ role: 'system', content: 'You are open-codesign.' },
93+
{ role: 'user', content: '介绍一下你自己' },
94+
{ role: 'assistant', content: '我是一个设计助手。' },
95+
{ role: 'user', content: '你可以干什么' },
96+
];
97+
98+
const result = await complete(MODEL, messages, { apiKey: 'sk-test' });
99+
100+
expect(result).toEqual({
101+
content: '我可以帮你生成设计稿。',
102+
inputTokens: 12,
103+
outputTokens: 34,
104+
costUsd: 0.01,
105+
});
106+
});
107+
});

packages/providers/src/index.ts

Lines changed: 109 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
/**
22
* Wrappers around @mariozechner/pi-ai that fill capability gaps documented
33
* in docs/research/05-pi-ai-boundary.md. App code MUST go through this
4-
* package never import a provider SDK directly.
4+
* package - never import a provider SDK directly.
55
*
66
* Tier 1 implementations: minimum viable. Tier 2 features tracked separately.
77
*/
@@ -21,6 +21,70 @@ export interface GenerateResult {
2121
costUsd: number;
2222
}
2323

24+
interface PiTextContent {
25+
type: 'text';
26+
text: string;
27+
}
28+
29+
interface PiUsage {
30+
input: number;
31+
output: number;
32+
cacheRead: number;
33+
cacheWrite: number;
34+
totalTokens: number;
35+
cost: {
36+
input: number;
37+
output: number;
38+
cacheRead: number;
39+
cacheWrite: number;
40+
total: number;
41+
};
42+
}
43+
44+
interface PiUserMessage {
45+
role: 'user';
46+
content: string | PiTextContent[];
47+
timestamp: number;
48+
}
49+
50+
interface PiAssistantMessage {
51+
role: 'assistant';
52+
content: Array<{ type: string; text?: string }>;
53+
api: string;
54+
provider: string;
55+
model: string;
56+
usage: PiUsage;
57+
stopReason: 'stop' | 'length' | 'toolUse' | 'error' | 'aborted';
58+
errorMessage?: string;
59+
timestamp: number;
60+
}
61+
62+
interface PiContext {
63+
systemPrompt?: string;
64+
messages: Array<PiUserMessage | PiAssistantMessage>;
65+
}
66+
67+
interface PiModel {
68+
id: string;
69+
api: string;
70+
provider: string;
71+
}
72+
73+
const EMPTY_USAGE: PiUsage = {
74+
input: 0,
75+
output: 0,
76+
cacheRead: 0,
77+
cacheWrite: 0,
78+
totalTokens: 0,
79+
cost: {
80+
input: 0,
81+
output: 0,
82+
cacheRead: 0,
83+
cacheWrite: 0,
84+
total: 0,
85+
},
86+
};
87+
2488
/**
2589
* Single non-streaming completion. Tier 1: thin shim, no caching, no retry.
2690
* Tier 2 will swap to pi-ai's streaming API and emit ArtifactEvents directly.
@@ -37,17 +101,12 @@ export async function complete(
37101
}
38102

39103
const pi = (await import('@mariozechner/pi-ai')) as unknown as {
40-
getModel: (provider: string, modelId: string) => unknown;
104+
getModel: (provider: string, modelId: string) => PiModel | undefined;
41105
completeSimple: (
42-
model: unknown,
43-
context: { messages: ChatMessage[] },
106+
model: PiModel,
107+
context: PiContext,
44108
opts: { apiKey: string; baseUrl?: string; signal?: AbortSignal },
45-
) => Promise<{
46-
stopReason?: string;
47-
errorMessage?: string;
48-
content: Array<{ type: string; text?: string }>;
49-
usage?: { input?: number; output?: number; cost?: { total?: number } };
50-
}>;
109+
) => Promise<PiAssistantMessage>;
51110
};
52111

53112
const piModel = pi.getModel(model.provider, model.modelId);
@@ -64,7 +123,7 @@ export async function complete(
64123
if (opts.baseUrl !== undefined) piOpts.baseUrl = opts.baseUrl;
65124
if (opts.signal !== undefined) piOpts.signal = opts.signal;
66125

67-
const result = await pi.completeSimple(piModel, { messages }, piOpts);
126+
const result = await pi.completeSimple(piModel, toPiContext(messages, piModel), piOpts);
68127

69128
if (result.stopReason === 'error') {
70129
throw new CodesignError(result.errorMessage ?? 'Provider returned an error', 'PROVIDER_ERROR');
@@ -83,6 +142,45 @@ export async function complete(
83142
};
84143
}
85144

145+
function toPiContext(messages: ChatMessage[], model: PiModel): PiContext {
146+
const systemPrompt = messages
147+
.filter((message) => message.role === 'system')
148+
.map((message) => message.content.trim())
149+
.filter((content) => content.length > 0)
150+
.join('\n\n');
151+
152+
return {
153+
...(systemPrompt.length > 0 ? { systemPrompt } : {}),
154+
messages: messages.flatMap((message, index) => {
155+
const timestamp = index + 1;
156+
157+
if (message.role === 'system') {
158+
return [];
159+
}
160+
161+
if (message.role === 'user') {
162+
return {
163+
role: 'user',
164+
content: message.content,
165+
timestamp,
166+
};
167+
}
168+
169+
return {
170+
role: 'assistant',
171+
content:
172+
message.content.trim().length === 0 ? [] : [{ type: 'text', text: message.content }],
173+
api: model.api,
174+
provider: model.provider,
175+
model: model.id,
176+
usage: EMPTY_USAGE,
177+
stopReason: 'stop',
178+
timestamp,
179+
};
180+
}),
181+
};
182+
}
183+
86184
/**
87185
* Detect API provider from a pasted key prefix. Used by the onboarding flow
88186
* to spare the user from picking a provider manually.

0 commit comments

Comments
 (0)