|
1 | 1 | import type { WeightedOneOfOption } from "./schema"; |
2 | 2 | import { isWeightedOption } from "./schema"; |
| 3 | +import { getCurrentMessageContext } from "./message-context"; |
| 4 | +import { hasValueBeenUsed, markValueAsUsed } from "./unique-selection"; |
3 | 5 | import { random } from "./utils"; |
4 | 6 |
|
| 7 | +type PrimitiveKey = string | number | boolean; |
| 8 | +type UniqueIdentifierResolver<T> = |
| 9 | + | keyof any |
| 10 | + | ((value: T) => PrimitiveKey | null | undefined); |
| 11 | + |
| 12 | +export interface OneOfUniqueBy<T> { |
| 13 | + collection: string; |
| 14 | + itemId?: UniqueIdentifierResolver<T>; |
| 15 | +} |
| 16 | + |
| 17 | +export interface OneOfOptions<T> { |
| 18 | + unique?: boolean; |
| 19 | + uniqueBy?: OneOfUniqueBy<T>; |
| 20 | +} |
| 21 | + |
| 22 | +type NormalizedWeightedOption<T> = { |
| 23 | + value: T; |
| 24 | + weight?: number; |
| 25 | + uniqueKey?: string; |
| 26 | +}; |
| 27 | + |
| 28 | +type ResolvedUniqueBy<T> = { |
| 29 | + collection: string; |
| 30 | + itemId: UniqueIdentifierResolver<T>; |
| 31 | +}; |
| 32 | + |
5 | 33 | export function optional<T>(message: T): T | (() => []) { |
6 | 34 | return random() < 0.5 ? message : () => []; |
7 | 35 | } |
@@ -31,13 +59,27 @@ export function randomSample<T>(n: number, array: T[]): T[] { |
31 | 59 |
|
32 | 60 | return result; |
33 | 61 | } |
34 | | -export function oneOf<T>(options: Array<WeightedOneOfOption<T>>): T { |
| 62 | +export function oneOf<T>( |
| 63 | + options: Array<WeightedOneOfOption<T>>, |
| 64 | + config?: OneOfOptions<T> |
| 65 | +): T { |
35 | 66 | if (options.length === 0) { |
36 | 67 | throw new Error("oneOf requires at least one option"); |
37 | 68 | } |
38 | 69 |
|
| 70 | + const rawUniqueBy = config?.uniqueBy; |
| 71 | + const enforceUnique = Boolean(config?.unique ?? rawUniqueBy); |
| 72 | + |
| 73 | + if (enforceUnique && !rawUniqueBy) { |
| 74 | + throw new Error( |
| 75 | + "oneOf unique mode requires a uniqueBy option with a collection name" |
| 76 | + ); |
| 77 | + } |
| 78 | + |
| 79 | + const uniqueBy = resolveUniqueBy(rawUniqueBy); |
| 80 | + |
39 | 81 | const normalized = options.map((option) => |
40 | | - isWeightedOption(option) ? option : { value: option } |
| 82 | + isWeightedOption(option) ? { ...option } : { value: option } |
41 | 83 | ); |
42 | 84 |
|
43 | 85 | let providedWeightTotal = 0; |
@@ -81,34 +123,145 @@ export function oneOf<T>(options: Array<WeightedOneOfOption<T>>): T { |
81 | 123 | } |
82 | 124 | } |
83 | 125 |
|
84 | | - const totalWeight = normalized.reduce( |
| 126 | + const candidateBase = normalized.map<NormalizedWeightedOption<T>>((option) => ({ |
| 127 | + value: option.value, |
| 128 | + weight: option.weight, |
| 129 | + })); |
| 130 | + |
| 131 | + let candidateOptions = candidateBase; |
| 132 | + |
| 133 | + if (enforceUnique && uniqueBy) { |
| 134 | + candidateOptions = candidateBase |
| 135 | + .map((option) => ({ |
| 136 | + ...option, |
| 137 | + uniqueKey: buildUniqueKey(option.value, uniqueBy), |
| 138 | + })) |
| 139 | + .filter( |
| 140 | + (option) => !hasValueBeenUsed(uniqueBy.collection, option.uniqueKey!) |
| 141 | + ); |
| 142 | + |
| 143 | + if (candidateOptions.length === 0) { |
| 144 | + throw new Error( |
| 145 | + `oneOf uniqueBy collection "${uniqueBy.collection}" is exhausted` |
| 146 | + ); |
| 147 | + } |
| 148 | + } |
| 149 | + |
| 150 | + const totalWeight = candidateOptions.reduce( |
85 | 151 | (sum, option) => sum + (option.weight ?? 0), |
86 | 152 | 0 |
87 | 153 | ); |
88 | 154 |
|
89 | 155 | if (totalWeight <= 0) { |
90 | | - const fallback = normalized[Math.floor(random() * normalized.length)]; |
| 156 | + const fallback = |
| 157 | + candidateOptions[Math.floor(random() * candidateOptions.length)]; |
91 | 158 | if (!fallback) { |
92 | 159 | throw new Error("oneOf failed to select a fallback option"); |
93 | 160 | } |
| 161 | + recordUniqueSelection(fallback, uniqueBy, enforceUnique); |
94 | 162 | return fallback.value; |
95 | 163 | } |
96 | 164 |
|
97 | 165 | const needle = random() * totalWeight; |
98 | 166 | let cumulative = 0; |
99 | 167 |
|
100 | | - for (const option of normalized) { |
| 168 | + for (const option of candidateOptions) { |
101 | 169 | cumulative += option.weight ?? 0; |
102 | 170 |
|
103 | 171 | if (needle <= cumulative) { |
| 172 | + recordUniqueSelection(option, uniqueBy, enforceUnique); |
104 | 173 | return option.value; |
105 | 174 | } |
106 | 175 | } |
107 | 176 |
|
108 | | - const lastOption = normalized[normalized.length - 1]; |
| 177 | + const lastOption = candidateOptions[candidateOptions.length - 1]; |
109 | 178 | if (!lastOption) { |
110 | 179 | throw new Error("oneOf failed to resolve a selection"); |
111 | 180 | } |
112 | 181 |
|
| 182 | + recordUniqueSelection(lastOption, uniqueBy, enforceUnique); |
113 | 183 | return lastOption.value; |
114 | 184 | } |
| 185 | + |
| 186 | +function recordUniqueSelection<T>( |
| 187 | + option: NormalizedWeightedOption<T>, |
| 188 | + uniqueBy: ResolvedUniqueBy<T> | undefined, |
| 189 | + enforceUnique: boolean |
| 190 | +): void { |
| 191 | + if (!enforceUnique || !uniqueBy || !option.uniqueKey) { |
| 192 | + return; |
| 193 | + } |
| 194 | + |
| 195 | + const phase = getCurrentMessageContext()?.phase ?? "generate"; |
| 196 | + if (phase !== "generate") { |
| 197 | + return; |
| 198 | + } |
| 199 | + |
| 200 | + markValueAsUsed(uniqueBy.collection, option.uniqueKey); |
| 201 | +} |
| 202 | + |
| 203 | +function buildUniqueKey<T>( |
| 204 | + value: T, |
| 205 | + uniqueBy: ResolvedUniqueBy<T> |
| 206 | +): string { |
| 207 | + const rawValue = |
| 208 | + typeof uniqueBy.itemId === "function" |
| 209 | + ? uniqueBy.itemId(value) |
| 210 | + : readProperty(value, uniqueBy.itemId); |
| 211 | + |
| 212 | + if ( |
| 213 | + rawValue === undefined || |
| 214 | + rawValue === null || |
| 215 | + (typeof rawValue !== "string" && |
| 216 | + typeof rawValue !== "number" && |
| 217 | + typeof rawValue !== "boolean") |
| 218 | + ) { |
| 219 | + throw new Error( |
| 220 | + `oneOf uniqueBy.itemId "${String( |
| 221 | + uniqueBy.itemId |
| 222 | + )}" must resolve to a string, number, or boolean` |
| 223 | + ); |
| 224 | + } |
| 225 | + |
| 226 | + const prefix = typeof rawValue; |
| 227 | + return `${prefix}:${String(rawValue)}`; |
| 228 | +} |
| 229 | + |
| 230 | +function readProperty<T>(value: T, key: keyof any): unknown { |
| 231 | + if ( |
| 232 | + value === null || |
| 233 | + (typeof value !== "object" && typeof value !== "function") |
| 234 | + ) { |
| 235 | + throw new Error( |
| 236 | + `oneOf uniqueBy.itemId "${String( |
| 237 | + key |
| 238 | + )}" requires the option value to be an object or function` |
| 239 | + ); |
| 240 | + } |
| 241 | + |
| 242 | + return (value as any)[key]; |
| 243 | +} |
| 244 | + |
| 245 | +function resolveUniqueBy<T>( |
| 246 | + uniqueBy?: OneOfUniqueBy<T> |
| 247 | +): ResolvedUniqueBy<T> | undefined { |
| 248 | + if (!uniqueBy) { |
| 249 | + return undefined; |
| 250 | + } |
| 251 | + |
| 252 | + if (typeof uniqueBy.collection !== "string") { |
| 253 | + throw new Error("oneOf uniqueBy.collection must be a non-empty string"); |
| 254 | + } |
| 255 | + |
| 256 | + const collection = uniqueBy.collection.trim(); |
| 257 | + if (!collection) { |
| 258 | + throw new Error("oneOf uniqueBy.collection must be a non-empty string"); |
| 259 | + } |
| 260 | + |
| 261 | + const itemId = uniqueBy.itemId ?? ("id" as UniqueIdentifierResolver<T>); |
| 262 | + |
| 263 | + return { |
| 264 | + collection, |
| 265 | + itemId, |
| 266 | + }; |
| 267 | +} |
0 commit comments