Skip to content

Commit dc30514

Browse files
committed
fix(llm-catalog): harden sync tests and baseModelName catalog merge
- Use postgresTest + @internal/testcontainers for sync integration tests - Assert catalog null clears baseModelName (gemini-pro regression) - Exclude *.test.ts from package tsc; vitest timeout for containers - Keep conditional updateMany and undefined-only merges in sync.ts - Refresh pnpm-lock for testcontainers devDependency
1 parent 7ef732e commit dc30514

File tree

6 files changed

+212
-114
lines changed

6 files changed

+212
-114
lines changed

internal-packages/llm-model-catalog/package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
"@trigger.dev/database": "workspace:*"
1111
},
1212
"devDependencies": {
13+
"@internal/testcontainers": "workspace:*",
1314
"vitest": "3.1.4"
1415
},
1516
"scripts": {
Lines changed: 190 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -1,121 +1,203 @@
1-
import { describe, it, expect, vi } from "vitest";
2-
import { syncLlmCatalog } from "./sync.js";
1+
import type { PrismaClient } from "@trigger.dev/database";
2+
import { postgresTest } from "@internal/testcontainers";
3+
import { generateFriendlyId } from "@trigger.dev/core/v3/isomorphic";
4+
import { describe, expect } from "vitest";
35
import { defaultModelPrices } from "./defaultPrices.js";
6+
import { modelCatalog } from "./modelCatalog.js";
7+
import { syncLlmCatalog } from "./sync.js";
48

5-
const gpt4oDef = defaultModelPrices.find((m) => m.modelName === "gpt-4o");
6-
if (!gpt4oDef) {
7-
throw new Error("expected gpt-4o in defaultModelPrices");
9+
function getGpt4oDefinition() {
10+
const def = defaultModelPrices.find((m) => m.modelName === "gpt-4o");
11+
if (def === undefined) {
12+
throw new Error("expected gpt-4o in defaultModelPrices");
13+
}
14+
return def;
815
}
916

10-
describe("syncLlmCatalog", () => {
11-
it("rebuilds pricing tiers and prices for existing default-source models", async () => {
12-
const existingId = "existing-gpt4o";
13-
14-
const llmModelUpdate = vi.fn();
15-
const llmPricingTierDeleteMany = vi.fn();
16-
const llmPricingTierCreate = vi.fn();
17-
18-
const prisma = {
19-
llmModel: {
20-
findFirst: vi.fn(async (args: { where: { modelName: string } }) => {
21-
if (args.where.modelName === "gpt-4o") {
22-
return {
23-
id: existingId,
24-
source: "default",
25-
provider: "openai",
26-
description: "stale description",
27-
contextWindow: 999,
28-
maxOutputTokens: 888,
29-
capabilities: ["legacy"],
30-
isHidden: true,
31-
baseModelName: "legacy-base",
32-
};
33-
}
34-
return null;
35-
}),
36-
},
37-
$transaction: vi.fn(async (fn: (tx: unknown) => Promise<void>) => {
38-
await fn({
39-
llmModel: { update: llmModelUpdate },
40-
llmPricingTier: {
41-
deleteMany: llmPricingTierDeleteMany,
42-
create: llmPricingTierCreate,
43-
},
44-
});
45-
}),
46-
};
47-
48-
await syncLlmCatalog(prisma as never);
49-
50-
expect(prisma.$transaction).toHaveBeenCalledTimes(1);
51-
52-
expect(llmModelUpdate).toHaveBeenCalledWith({
53-
where: { id: existingId },
54-
data: expect.objectContaining({
55-
matchPattern: gpt4oDef.matchPattern,
56-
startDate: gpt4oDef.startDate ? new Date(gpt4oDef.startDate) : null,
57-
}),
58-
});
59-
60-
expect(llmPricingTierDeleteMany).toHaveBeenCalledWith({
61-
where: { modelId: existingId },
62-
});
63-
64-
expect(llmPricingTierCreate).toHaveBeenCalledTimes(gpt4oDef.pricingTiers.length);
65-
66-
const firstTier = gpt4oDef.pricingTiers[0];
67-
expect(llmPricingTierCreate).toHaveBeenCalledWith({
68-
data: {
69-
modelId: existingId,
70-
name: firstTier.name,
71-
isDefault: firstTier.isDefault,
72-
priority: firstTier.priority,
73-
conditions: firstTier.conditions,
74-
prices: {
75-
create: expect.arrayContaining(
76-
Object.entries(firstTier.prices).map(([usageType, price]) => ({
77-
modelId: existingId,
78-
usageType,
79-
price,
80-
}))
81-
),
82-
},
17+
const gpt4oDef = getGpt4oDefinition();
18+
19+
function getGeminiProDefinition() {
20+
const def = defaultModelPrices.find((m) => m.modelName === "gemini-pro");
21+
if (def === undefined) {
22+
throw new Error("expected gemini-pro in defaultModelPrices");
23+
}
24+
return def;
25+
}
26+
27+
const geminiProDef = getGeminiProDefinition();
28+
29+
/** If sync used `catalog?.baseModelName ?? existing.baseModelName`, sync would keep this string instead of clearing to null. */
30+
const STALE_BASE_MODEL_NAME = "wrong-base-model-sentinel";
31+
32+
const STALE_INPUT_PRICE = 0.099;
33+
const STALE_OUTPUT_PRICE = 0.088;
34+
35+
async function createGpt4oWithStalePricing(
36+
prisma: PrismaClient,
37+
source: "default" | "admin"
38+
) {
39+
const model = await prisma.llmModel.create({
40+
data: {
41+
friendlyId: generateFriendlyId("llm_model"),
42+
projectId: null,
43+
modelName: gpt4oDef.modelName,
44+
matchPattern: "^stale-pattern$",
45+
startDate: gpt4oDef.startDate ? new Date(gpt4oDef.startDate) : null,
46+
source,
47+
provider: "stale-provider",
48+
description: "stale description",
49+
contextWindow: 111,
50+
maxOutputTokens: 222,
51+
capabilities: ["stale-cap"],
52+
isHidden: true,
53+
baseModelName: "stale-base",
54+
},
55+
});
56+
57+
await prisma.llmPricingTier.create({
58+
data: {
59+
modelId: model.id,
60+
name: "Standard",
61+
isDefault: true,
62+
priority: 0,
63+
conditions: [],
64+
prices: {
65+
create: [
66+
{ modelId: model.id, usageType: "input", price: STALE_INPUT_PRICE },
67+
{ modelId: model.id, usageType: "output", price: STALE_OUTPUT_PRICE },
68+
],
8369
},
84-
});
70+
},
71+
});
8572

86-
const createCall = llmPricingTierCreate.mock.calls[0][0] as {
87-
data: { prices: { create: { usageType: string; price: number; modelId: string }[] } };
88-
};
89-
expect(createCall.data.prices.create).toHaveLength(Object.keys(firstTier.prices).length);
73+
return model;
74+
}
75+
76+
async function createGeminiProWithStaleBaseModelName(prisma: PrismaClient) {
77+
const catalogEntry = modelCatalog[geminiProDef.modelName];
78+
expect(catalogEntry).toBeDefined();
79+
expect(catalogEntry.baseModelName).toBeNull();
80+
81+
const model = await prisma.llmModel.create({
82+
data: {
83+
friendlyId: generateFriendlyId("llm_model"),
84+
projectId: null,
85+
modelName: geminiProDef.modelName,
86+
matchPattern: "^stale-gemini-pattern$",
87+
startDate: geminiProDef.startDate ? new Date(geminiProDef.startDate) : null,
88+
source: "default",
89+
provider: "stale-provider",
90+
description: "stale description",
91+
contextWindow: 111,
92+
maxOutputTokens: 222,
93+
capabilities: ["stale-cap"],
94+
isHidden: true,
95+
baseModelName: STALE_BASE_MODEL_NAME,
96+
},
9097
});
9198

92-
it("does not rebuild pricing for non-default source models", async () => {
93-
const prisma = {
94-
llmModel: {
95-
findFirst: vi.fn(async (args: { where: { modelName: string } }) => {
96-
if (args.where.modelName === "gpt-4o") {
97-
return {
98-
id: "admin-edited",
99-
source: "admin",
100-
provider: null,
101-
description: null,
102-
contextWindow: null,
103-
maxOutputTokens: null,
104-
capabilities: [],
105-
isHidden: false,
106-
baseModelName: null,
107-
};
108-
}
109-
return null;
110-
}),
99+
const tier = geminiProDef.pricingTiers[0];
100+
await prisma.llmPricingTier.create({
101+
data: {
102+
modelId: model.id,
103+
name: tier.name,
104+
isDefault: tier.isDefault,
105+
priority: tier.priority,
106+
conditions: tier.conditions,
107+
prices: {
108+
create: Object.entries(tier.prices).map(([usageType, price]) => ({
109+
modelId: model.id,
110+
usageType,
111+
price,
112+
})),
111113
},
112-
$transaction: vi.fn(),
113-
};
114+
},
115+
});
114116

115-
const result = await syncLlmCatalog(prisma as never);
117+
return model;
118+
}
116119

117-
expect(prisma.$transaction).not.toHaveBeenCalled();
118-
expect(result.modelsUpdated).toBe(0);
119-
expect(result.modelsSkipped).toBeGreaterThan(0);
120+
async function loadGpt4oWithTiers(prisma: PrismaClient) {
121+
return prisma.llmModel.findFirst({
122+
where: { projectId: null, modelName: gpt4oDef.modelName },
123+
include: {
124+
pricingTiers: {
125+
include: { prices: true },
126+
orderBy: { priority: "asc" },
127+
},
128+
},
120129
});
130+
}
131+
132+
function expectBundledGpt4oPricing(model: NonNullable<Awaited<ReturnType<typeof loadGpt4oWithTiers>>>) {
133+
expect(model.matchPattern).toBe(gpt4oDef.matchPattern);
134+
expect(model.pricingTiers).toHaveLength(gpt4oDef.pricingTiers.length);
135+
136+
const dbTier = model.pricingTiers[0];
137+
const defTier = gpt4oDef.pricingTiers[0];
138+
expect(dbTier.name).toBe(defTier.name);
139+
expect(dbTier.isDefault).toBe(defTier.isDefault);
140+
expect(dbTier.priority).toBe(defTier.priority);
141+
142+
const priceByType = new Map(dbTier.prices.map((p) => [p.usageType, Number(p.price)]));
143+
for (const [usageType, expected] of Object.entries(defTier.prices)) {
144+
expect(priceByType.get(usageType)).toBeCloseTo(expected, 12);
145+
}
146+
expect(priceByType.size).toBe(Object.keys(defTier.prices).length);
147+
}
148+
149+
describe("syncLlmCatalog", () => {
150+
postgresTest(
151+
"rebuilds gpt-4o pricing tiers from bundled defaults when source is default",
152+
async ({ prisma }) => {
153+
await createGpt4oWithStalePricing(prisma, "default");
154+
155+
const result = await syncLlmCatalog(prisma);
156+
157+
expect(result.modelsUpdated).toBe(1);
158+
expect(result.modelsSkipped).toBe(defaultModelPrices.length - 1);
159+
160+
const after = await loadGpt4oWithTiers(prisma);
161+
expect(after).not.toBeNull();
162+
expectBundledGpt4oPricing(after!);
163+
}
164+
);
165+
166+
postgresTest(
167+
"does not replace pricing tiers when model source is not default",
168+
async ({ prisma }) => {
169+
await createGpt4oWithStalePricing(prisma, "admin");
170+
171+
const result = await syncLlmCatalog(prisma);
172+
173+
expect(result.modelsUpdated).toBe(0);
174+
expect(result.modelsSkipped).toBeGreaterThanOrEqual(1);
175+
176+
const after = await loadGpt4oWithTiers(prisma);
177+
expect(after).not.toBeNull();
178+
expect(after!.matchPattern).toBe("^stale-pattern$");
179+
expect(after!.pricingTiers).toHaveLength(1);
180+
const prices = after!.pricingTiers[0].prices;
181+
const input = prices.find((p) => p.usageType === "input");
182+
const output = prices.find((p) => p.usageType === "output");
183+
expect(Number(input?.price)).toBeCloseTo(STALE_INPUT_PRICE, 12);
184+
expect(Number(output?.price)).toBeCloseTo(STALE_OUTPUT_PRICE, 12);
185+
expect(prices).toHaveLength(2);
186+
}
187+
);
188+
189+
postgresTest(
190+
"clears baseModelName when bundled catalog has null (regression for nullish-coalescing merge)",
191+
async ({ prisma }) => {
192+
await createGeminiProWithStaleBaseModelName(prisma);
193+
194+
await syncLlmCatalog(prisma);
195+
196+
const after = await prisma.llmModel.findFirst({
197+
where: { projectId: null, modelName: geminiProDef.modelName },
198+
});
199+
expect(after).not.toBeNull();
200+
expect(after!.baseModelName).toBeNull();
201+
}
202+
);
121203
});

internal-packages/llm-model-catalog/src/sync.ts

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,9 @@ export async function syncLlmCatalog(prisma: PrismaClient): Promise<{
5252

5353
const catalog = modelCatalog[modelDef.modelName];
5454

55-
await prisma.$transaction(async (tx) => {
56-
await tx.llmModel.update({
57-
where: { id: existing.id },
55+
const applied = await prisma.$transaction(async (tx) => {
56+
const updateResult = await tx.llmModel.updateMany({
57+
where: { id: existing.id, source: "default" },
5858
data: {
5959
// Update match pattern and start date from Langfuse (may have changed)
6060
matchPattern: modelDef.matchPattern,
@@ -70,20 +70,31 @@ export async function syncLlmCatalog(prisma: PrismaClient): Promise<{
7070
: catalog.maxOutputTokens,
7171
capabilities: catalog?.capabilities ?? existing.capabilities,
7272
isHidden: catalog?.isHidden ?? existing.isHidden,
73-
baseModelName: catalog?.baseModelName ?? existing.baseModelName,
73+
baseModelName:
74+
catalog?.baseModelName === undefined
75+
? existing.baseModelName
76+
: catalog.baseModelName,
7477
},
7578
});
7679

80+
if (updateResult.count !== 1) {
81+
return false;
82+
}
83+
7784
await tx.llmPricingTier.deleteMany({ where: { modelId: existing.id } });
7885

7986
for (const tier of modelDef.pricingTiers) {
8087
await tx.llmPricingTier.create({
8188
data: pricingTierCreateData(existing.id, tier),
8289
});
8390
}
91+
92+
return true;
8493
});
8594

86-
modelsUpdated++;
95+
if (applied) {
96+
modelsUpdated++;
97+
}
8798
}
8899

89100
return { modelsUpdated, modelsSkipped };

internal-packages/llm-model-catalog/tsconfig.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,5 @@
1515
"strict": true,
1616
"resolveJsonModule": true
1717
},
18-
"exclude": ["node_modules"]
18+
"exclude": ["node_modules", "**/*.test.ts"]
1919
}

internal-packages/llm-model-catalog/vitest.config.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,6 @@ export default defineConfig({
1111
singleThread: true,
1212
},
1313
},
14+
testTimeout: 120_000,
1415
},
1516
});

pnpm-lock.yaml

Lines changed: 3 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)