Skip to content

Commit 6741b3c

Browse files
author
Michal Warda
committed
WIP: dataset unique across different runs
1 parent d04f908 commit 6741b3c

8 files changed

Lines changed: 403 additions & 10 deletions

File tree

README.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,29 @@ const schema = () => [
221221

222222
`oneOf` accepts plain schema entries or `{ value, weight }` objects. Provide any subset of weights (summing to ≤ 1) and the remaining probability is spread evenly across unweighted entries.
223223

224+
#### Unique draws across a dataset
225+
226+
Pass a `uniqueBy` configuration when you need each option to be used at most once across every row/schema during generation:
227+
228+
```ts
229+
const toolOptions = [
230+
weatherTool.toolFunction(),
231+
calendarTool.toolFunction(),
232+
flightTool.toolFunction(),
233+
] as const;
234+
235+
const schema = () => [
236+
oneOf(toolOptions, {
237+
uniqueBy: {
238+
collection: "tools",
239+
itemId: "name",
240+
},
241+
}),
242+
];
243+
```
244+
245+
The `collection` name identifies the shared pool (so multiple `oneOf` calls can coordinate), and `itemId` can be either a property key or a function that returns a stable identifier; the identifier must resolve to a string, number, or boolean. Omit `itemId` to default to the common `id` field. Torque throws once the pool is exhausted rather than reusing an option, making it easy to guarantee that no option is repeated across the dataset.
246+
224247
> 💡 See weighted example: [`examples/weighted-one-of.ts`](examples/weighted-one-of.ts)
225248
> 💡 Full utilities demo: [`examples/composition-utilities.ts`](examples/composition-utilities.ts) | [▶️ Try in Browser](https://stackblitz.com/github/qforge-dev/torque/tree/main/stackblitz-templates/composition-utilities)
226249

bun.lock

Lines changed: 5 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

packages/torque/README.md

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,29 @@ const schema = () => [
230230

231231
`oneOf` accepts plain schema entries or `{ value, weight }` objects. Provide any subset of weights (summing to ≤ 1) and the remaining probability is spread evenly across unweighted entries.
232232

233+
#### Unique draws across a dataset
234+
235+
Pass a `uniqueBy` configuration when you need each option to be used at most once across every row/schema during generation:
236+
237+
```ts
238+
const toolOptions = [
239+
weatherTool.toolFunction(),
240+
calendarTool.toolFunction(),
241+
flightTool.toolFunction(),
242+
] as const;
243+
244+
const schema = () => [
245+
oneOf(toolOptions, {
246+
uniqueBy: {
247+
collection: "tools",
248+
itemId: "name",
249+
},
250+
}),
251+
];
252+
```
253+
254+
The `collection` name identifies the shared pool (so multiple `oneOf` calls can coordinate), and `itemId` can be either a property key or a function that returns a stable identifier; the identifier must resolve to a string, number, or boolean. Omit `itemId` to default to the common `id` field. Torque throws once the pool is exhausted rather than reusing an option, making it easy to guarantee that no option is repeated across the dataset.
255+
233256
> 💡 See weighted example: [`examples/weighted-one-of.ts`](examples/weighted-one-of.ts)
234257
> 💡 Full utilities demo: [`examples/composition-utilities.ts`](examples/composition-utilities.ts) | [▶️ Try in Browser](https://stackblitz.com/github/qforge-dev/torque/tree/main/stackblitz-templates/composition-utilities)
235258

packages/torque/src/dataset.ts

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ import { createWriter } from "./writer";
3232
import { createFormatter } from "./formatter";
3333
import { TokenCounterPool } from "./token-counting/tokenCounterPool";
3434
import { hoistSystemMessages } from "./ai-message-order";
35+
import { runWithMessageContext } from "./message-context";
36+
import { runWithUniqueSelectionScope } from "./unique-selection";
3537

3638
const DEFAULT_TOKEN_COUNTER_WORKERS = 3;
3739

@@ -417,7 +419,9 @@ async function checkMessageSchemaStructure(
417419
generationContext,
418420
};
419421

420-
const message = await messageFactory(checkContext);
422+
const message = await runWithMessageContext(checkContext, () =>
423+
messageFactory(checkContext)
424+
);
421425
if (message === null) return structure;
422426

423427
if (Array.isArray(message)) {
@@ -458,7 +462,9 @@ async function checkMessageSchemaStructure(
458462
generationId: string;
459463
}[] = [];
460464
for (const tc of message.toolCalls) {
461-
const toolCall = await tc(checkContext);
465+
const toolCall = await runWithMessageContext(checkContext, () =>
466+
tc(checkContext)
467+
);
462468
toolCallStructures.push({
463469
toolCallId: toolCall.toolCallId,
464470
toolName: toolCall.toolName,
@@ -609,7 +615,9 @@ async function convertMessageSchemaToDatasetMessage(
609615
generationContext,
610616
};
611617

612-
const message = await messageFactory(context);
618+
const message = await runWithMessageContext(context, () =>
619+
messageFactory(context)
620+
);
613621

614622
if (message === null) return acc;
615623
if (Array.isArray(message)) {
@@ -733,7 +741,9 @@ async function convertMessageSchemaToDatasetMessage(
733741
if (message.toolCalls && message.toolCalls.length > 0) {
734742
const toolCallParts: IToolCallSchema<any>[] = [];
735743
for (const tc of message.toolCalls) {
736-
const toolCall = await tc(context);
744+
const toolCall = await runWithMessageContext(context, () =>
745+
tc(context)
746+
);
737747
toolCallParts.push(toolCall);
738748
}
739749

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import { AsyncLocalStorage } from "async_hooks";
2+
import type { IMessageSchemaContext } from "./types";
3+
4+
type MaybePromise<T> = T | Promise<T>;
5+
6+
const messageContextStorage = new AsyncLocalStorage<IMessageSchemaContext>();
7+
8+
/**
 * Runs `fn` with `context` installed as the current store of the module-level
 * `AsyncLocalStorage`, so code called from `fn` can retrieve it through
 * `getCurrentMessageContext()` without receiving it as an argument.
 *
 * @param context - Message schema context to expose while `fn` runs.
 * @param fn - Callback to execute; may return a value or a promise.
 * @returns Whatever `fn` returns (not awaited here).
 */
export function runWithMessageContext<T>(
9+
context: IMessageSchemaContext,
10+
fn: () => MaybePromise<T>
11+
): MaybePromise<T> {
12+
const currentStore = messageContextStorage.getStore();
13+
// The very same context object is already the active store: call `fn`
// directly instead of opening a nested `run()` scope.
if (currentStore === context) {
14+
return fn();
15+
}
16+
17+
return messageContextStorage.run(context, fn);
18+
}
19+
20+
/**
 * Returns the message schema context installed by the enclosing
 * `runWithMessageContext` call, if any.
 *
 * @returns The current store of the module-level `AsyncLocalStorage`, or
 * `undefined` when no store is active.
 */
export function getCurrentMessageContext(): IMessageSchemaContext | undefined {
21+
return messageContextStorage.getStore();
22+
}

packages/torque/src/schema-rng.ts

Lines changed: 159 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,35 @@
11
import type { WeightedOneOfOption } from "./schema";
22
import { isWeightedOption } from "./schema";
3+
import { getCurrentMessageContext } from "./message-context";
4+
import { hasValueBeenUsed, markValueAsUsed } from "./unique-selection";
35
import { random } from "./utils";
46

7+
type PrimitiveKey = string | number | boolean;
8+
type UniqueIdentifierResolver<T> =
9+
| keyof any
10+
| ((value: T) => PrimitiveKey | null | undefined);
11+
12+
/**
 * Configuration for unique selection in `oneOf`.
 */
export interface OneOfUniqueBy<T> {
13+
/**
 * Name of the pool that used identifiers are tracked under (passed to
 * `hasValueBeenUsed` / `markValueAsUsed`). It is trimmed and must be a
 * non-empty string, otherwise `oneOf` throws.
 */
collection: string;
14+
/**
 * How to derive an option's identifier: either a property key read from the
 * option value, or a function called with the option value. The result must
 * be a string, number, or boolean. Defaults to the `"id"` property.
 */
itemId?: UniqueIdentifierResolver<T>;
15+
}
16+
17+
/**
 * Optional second argument of `oneOf`.
 */
export interface OneOfOptions<T> {
18+
/**
 * Explicit switch for unique selection. When omitted, uniqueness is enforced
 * whenever `uniqueBy` is provided. `true` without `uniqueBy` makes `oneOf`
 * throw; `false` turns enforcement off even if `uniqueBy` is set.
 */
unique?: boolean;
19+
/** Pool name and identifier resolver used when uniqueness is enforced. */
uniqueBy?: OneOfUniqueBy<T>;
20+
}
21+
22+
type NormalizedWeightedOption<T> = {
23+
value: T;
24+
weight?: number;
25+
uniqueKey?: string;
26+
};
27+
28+
type ResolvedUniqueBy<T> = {
29+
collection: string;
30+
itemId: UniqueIdentifierResolver<T>;
31+
};
32+
533
/**
 * Randomly keeps or drops a schema entry.
 *
 * @param message - The entry to keep.
 * @returns `message` when `random()` is below 0.5, otherwise a function that
 * returns an empty array.
 */
// NOTE(review): assumes `random()` from "./utils" yields values in [0, 1),
// which would make this an even 50/50 split — confirm.
export function optional<T>(message: T): T | (() => []) {
634
return random() < 0.5 ? message : () => [];
735
}
@@ -31,13 +59,27 @@ export function randomSample<T>(n: number, array: T[]): T[] {
3159

3260
return result;
3361
}
34-
export function oneOf<T>(options: Array<WeightedOneOfOption<T>>): T {
62+
export function oneOf<T>(
63+
options: Array<WeightedOneOfOption<T>>,
64+
config?: OneOfOptions<T>
65+
): T {
3566
if (options.length === 0) {
3667
throw new Error("oneOf requires at least one option");
3768
}
3869

70+
const rawUniqueBy = config?.uniqueBy;
71+
const enforceUnique = Boolean(config?.unique ?? rawUniqueBy);
72+
73+
if (enforceUnique && !rawUniqueBy) {
74+
throw new Error(
75+
"oneOf unique mode requires a uniqueBy option with a collection name"
76+
);
77+
}
78+
79+
const uniqueBy = resolveUniqueBy(rawUniqueBy);
80+
3981
const normalized = options.map((option) =>
40-
isWeightedOption(option) ? option : { value: option }
82+
isWeightedOption(option) ? { ...option } : { value: option }
4183
);
4284

4385
let providedWeightTotal = 0;
@@ -81,34 +123,145 @@ export function oneOf<T>(options: Array<WeightedOneOfOption<T>>): T {
81123
}
82124
}
83125

84-
const totalWeight = normalized.reduce(
126+
const candidateBase = normalized.map<NormalizedWeightedOption<T>>((option) => ({
127+
value: option.value,
128+
weight: option.weight,
129+
}));
130+
131+
let candidateOptions = candidateBase;
132+
133+
if (enforceUnique && uniqueBy) {
134+
candidateOptions = candidateBase
135+
.map((option) => ({
136+
...option,
137+
uniqueKey: buildUniqueKey(option.value, uniqueBy),
138+
}))
139+
.filter(
140+
(option) => !hasValueBeenUsed(uniqueBy.collection, option.uniqueKey!)
141+
);
142+
143+
if (candidateOptions.length === 0) {
144+
throw new Error(
145+
`oneOf uniqueBy collection "${uniqueBy.collection}" is exhausted`
146+
);
147+
}
148+
}
149+
150+
const totalWeight = candidateOptions.reduce(
85151
(sum, option) => sum + (option.weight ?? 0),
86152
0
87153
);
88154

89155
if (totalWeight <= 0) {
90-
const fallback = normalized[Math.floor(random() * normalized.length)];
156+
const fallback =
157+
candidateOptions[Math.floor(random() * candidateOptions.length)];
91158
if (!fallback) {
92159
throw new Error("oneOf failed to select a fallback option");
93160
}
161+
recordUniqueSelection(fallback, uniqueBy, enforceUnique);
94162
return fallback.value;
95163
}
96164

97165
const needle = random() * totalWeight;
98166
let cumulative = 0;
99167

100-
for (const option of normalized) {
168+
for (const option of candidateOptions) {
101169
cumulative += option.weight ?? 0;
102170

103171
if (needle <= cumulative) {
172+
recordUniqueSelection(option, uniqueBy, enforceUnique);
104173
return option.value;
105174
}
106175
}
107176

108-
const lastOption = normalized[normalized.length - 1];
177+
const lastOption = candidateOptions[candidateOptions.length - 1];
109178
if (!lastOption) {
110179
throw new Error("oneOf failed to resolve a selection");
111180
}
112181

182+
recordUniqueSelection(lastOption, uniqueBy, enforceUnique);
113183
return lastOption.value;
114184
}
185+
186+
/**
 * Marks the selected option's unique key as used in its collection.
 *
 * Does nothing when uniqueness is not enforced, when no `uniqueBy`
 * configuration was resolved, or when the option carries no `uniqueKey`.
 *
 * @param option - The option that `oneOf` is about to return.
 * @param uniqueBy - Resolved unique configuration, if any.
 * @param enforceUnique - Whether unique selection is active for this call.
 */
function recordUniqueSelection<T>(
187+
option: NormalizedWeightedOption<T>,
188+
uniqueBy: ResolvedUniqueBy<T> | undefined,
189+
enforceUnique: boolean
190+
): void {
191+
if (!enforceUnique || !uniqueBy || !option.uniqueKey) {
192+
return;
193+
}
194+
195+
// With no active message context (or no phase on it) the phase is treated as
// "generate".
const phase = getCurrentMessageContext()?.phase ?? "generate";
196+
// Only selections made in the "generate" phase are recorded.
// NOTE(review): presumably so that non-generating passes (e.g. the schema
// structure check) do not consume the pool — confirm.
if (phase !== "generate") {
197+
return;
198+
}
199+
200+
markValueAsUsed(uniqueBy.collection, option.uniqueKey);
201+
}
202+
203+
/**
 * Derives the string key under which an option value is tracked for unique
 * selection.
 *
 * The identifier is obtained by calling `uniqueBy.itemId` when it is a
 * function, or by reading it as a property of `value` otherwise.
 *
 * @param value - The option value to identify.
 * @param uniqueBy - Resolved unique configuration supplying `itemId`.
 * @returns A key of the form `<typeof identifier>:<identifier>`.
 * @throws Error when the identifier is `undefined`, `null`, or not a string,
 * number, or boolean (and whatever `readProperty` throws for the property
 * form).
 */
function buildUniqueKey<T>(
204+
value: T,
205+
uniqueBy: ResolvedUniqueBy<T>
206+
): string {
207+
const rawValue =
208+
typeof uniqueBy.itemId === "function"
209+
? uniqueBy.itemId(value)
210+
: readProperty(value, uniqueBy.itemId);
211+
212+
if (
213+
rawValue === undefined ||
214+
rawValue === null ||
215+
(typeof rawValue !== "string" &&
216+
typeof rawValue !== "number" &&
217+
typeof rawValue !== "boolean")
218+
) {
219+
throw new Error(
220+
`oneOf uniqueBy.itemId "${String(
221+
uniqueBy.itemId
222+
)}" must resolve to a string, number, or boolean`
223+
);
224+
}
225+
226+
// Prefixing with the primitive type keeps e.g. the number 1 ("number:1")
// distinct from the string "1" ("string:1").
const prefix = typeof rawValue;
227+
return `${prefix}:${String(rawValue)}`;
228+
}
229+
230+
/**
 * Reads `key` from an option value after checking that the value can carry
 * properties.
 *
 * @param value - The option value to read from.
 * @param key - Property key to look up.
 * @returns The property value, untyped (`undefined` if the key is absent).
 * @throws Error when `value` is `null` or is neither an object nor a function.
 */
function readProperty<T>(value: T, key: keyof any): unknown {
231+
if (
232+
value === null ||
233+
(typeof value !== "object" && typeof value !== "function")
234+
) {
235+
throw new Error(
236+
`oneOf uniqueBy.itemId "${String(
237+
key
238+
)}" requires the option value to be an object or function`
239+
);
240+
}
241+
242+
return (value as any)[key];
243+
}
244+
245+
/**
 * Validates a `uniqueBy` configuration and fills in its defaults.
 *
 * @param uniqueBy - The caller-supplied configuration, if any.
 * @returns `undefined` when no configuration was given; otherwise the trimmed
 * collection name together with the `itemId` resolver.
 * @throws Error when `collection` is not a string or is empty after trimming.
 */
function resolveUniqueBy<T>(
246+
uniqueBy?: OneOfUniqueBy<T>
247+
): ResolvedUniqueBy<T> | undefined {
248+
if (!uniqueBy) {
249+
return undefined;
250+
}
251+
252+
// Runtime guard for callers that bypass the type checker.
if (typeof uniqueBy.collection !== "string") {
253+
throw new Error("oneOf uniqueBy.collection must be a non-empty string");
254+
}
255+
256+
// Whitespace-only names are rejected; surrounding whitespace is dropped.
const collection = uniqueBy.collection.trim();
257+
if (!collection) {
258+
throw new Error("oneOf uniqueBy.collection must be a non-empty string");
259+
}
260+
261+
// Without an explicit resolver, options are identified by their `id` property.
const itemId = uniqueBy.itemId ?? ("id" as UniqueIdentifierResolver<T>);
262+
263+
return {
264+
collection,
265+
itemId,
266+
};
267+
}

0 commit comments

Comments
 (0)