|
1 | | -import type { ModelProfile } from "./core/ModelProfile"; |
2 | | -import { |
3 | | - ModelProfileResolver, |
4 | | - type ModelProfileResolverOptions, |
5 | | - type ResolveModelProfileInput, |
6 | | -} from "./core/ModelProfileResolver"; |
7 | | - |
8 | | -export type QueryScope = "broad" | "normal" | "narrow" | "default"; |
9 | | - |
10 | | -export interface ProjectionHead { |
11 | | - dimIn: number; |
12 | | - dimOut: number; |
13 | | - bits?: number; |
14 | | - // Byte offset for the projection head in a flattened projection buffer. |
15 | | - offset: number; |
16 | | -} |
17 | | - |
18 | | -export interface RoutingPolicy { |
19 | | - broad: ProjectionHead; |
20 | | - normal: ProjectionHead; |
21 | | - narrow: ProjectionHead; |
22 | | -} |
23 | | - |
24 | | -export interface ResolvedRoutingPolicy { |
25 | | - modelProfile: ModelProfile; |
26 | | - routingPolicy: RoutingPolicy; |
27 | | -} |
28 | | - |
29 | | -export interface ResolveRoutingPolicyOptions { |
30 | | - resolver?: ModelProfileResolver; |
31 | | - resolverOptions?: ModelProfileResolverOptions; |
32 | | - routingPolicyOverrides?: Partial<RoutingPolicyDerivation>; |
33 | | -} |
34 | | - |
35 | | -export interface RoutingPolicyDerivation { |
36 | | - broadDimRatio: number; |
37 | | - normalDimRatio: number; |
38 | | - narrowDimRatio: number; |
39 | | - broadHashBits: number; |
40 | | - dimAlignment: number; |
41 | | - minProjectionDim: number; |
42 | | -} |
43 | | - |
44 | | -export const DEFAULT_ROUTING_POLICY_DERIVATION: RoutingPolicyDerivation = |
45 | | - Object.freeze({ |
46 | | - broadDimRatio: 1 / 8, |
47 | | - normalDimRatio: 1 / 4, |
48 | | - narrowDimRatio: 1 / 2, |
49 | | - broadHashBits: 128, |
50 | | - dimAlignment: 8, |
51 | | - minProjectionDim: 8, |
52 | | - }); |
53 | | - |
54 | | -function assertPositiveInteger(name: string, value: number): void { |
55 | | - if (!Number.isInteger(value) || value <= 0) { |
56 | | - throw new Error(`${name} must be a positive integer`); |
57 | | - } |
58 | | -} |
59 | | - |
60 | | -function assertPositiveFinite(name: string, value: number): void { |
61 | | - if (!Number.isFinite(value) || value <= 0) { |
62 | | - throw new Error(`${name} must be positive and finite`); |
63 | | - } |
64 | | -} |
65 | | - |
66 | | -function alignDown(value: number, alignment: number): number { |
67 | | - return Math.floor(value / alignment) * alignment; |
68 | | -} |
69 | | - |
70 | | -function deriveProjectionDim( |
71 | | - dimIn: number, |
72 | | - ratio: number, |
73 | | - derivation: RoutingPolicyDerivation, |
74 | | -): number { |
75 | | - const raw = Math.floor(dimIn * ratio); |
76 | | - const aligned = alignDown(raw, derivation.dimAlignment); |
77 | | - const bounded = Math.max(derivation.minProjectionDim, aligned); |
78 | | - return Math.min(dimIn, bounded); |
79 | | -} |
80 | | - |
81 | | -function validateDerivation(derivation: RoutingPolicyDerivation): void { |
82 | | - assertPositiveFinite("broadDimRatio", derivation.broadDimRatio); |
83 | | - assertPositiveFinite("normalDimRatio", derivation.normalDimRatio); |
84 | | - assertPositiveFinite("narrowDimRatio", derivation.narrowDimRatio); |
85 | | - assertPositiveInteger("broadHashBits", derivation.broadHashBits); |
86 | | - assertPositiveInteger("dimAlignment", derivation.dimAlignment); |
87 | | - assertPositiveInteger("minProjectionDim", derivation.minProjectionDim); |
88 | | -} |
89 | | - |
90 | | -export function createRoutingPolicy( |
91 | | - modelProfile: Pick<ModelProfile, "embeddingDimension">, |
92 | | - overrides: Partial<RoutingPolicyDerivation> = {}, |
93 | | -): RoutingPolicy { |
94 | | - assertPositiveInteger("embeddingDimension", modelProfile.embeddingDimension); |
95 | | - |
96 | | - const derivation: RoutingPolicyDerivation = { |
97 | | - ...DEFAULT_ROUTING_POLICY_DERIVATION, |
98 | | - ...overrides, |
99 | | - }; |
100 | | - |
101 | | - validateDerivation(derivation); |
102 | | - |
103 | | - const dimIn = modelProfile.embeddingDimension; |
104 | | - const broadDim = deriveProjectionDim(dimIn, derivation.broadDimRatio, derivation); |
105 | | - const normalDim = deriveProjectionDim(dimIn, derivation.normalDimRatio, derivation); |
106 | | - const narrowDim = deriveProjectionDim(dimIn, derivation.narrowDimRatio, derivation); |
107 | | - |
108 | | - const broadOffset = 0; |
109 | | - const normalOffset = broadOffset + broadDim * dimIn; |
110 | | - const narrowOffset = normalOffset + normalDim * dimIn; |
111 | | - |
112 | | - return { |
113 | | - broad: { |
114 | | - dimIn, |
115 | | - dimOut: broadDim, |
116 | | - bits: derivation.broadHashBits, |
117 | | - offset: broadOffset, |
118 | | - }, |
119 | | - normal: { |
120 | | - dimIn, |
121 | | - dimOut: normalDim, |
122 | | - offset: normalOffset, |
123 | | - }, |
124 | | - narrow: { |
125 | | - dimIn, |
126 | | - dimOut: narrowDim, |
127 | | - offset: narrowOffset, |
128 | | - }, |
129 | | - }; |
130 | | -} |
131 | | - |
132 | | -export function resolveRoutingPolicyForModel( |
133 | | - input: ResolveModelProfileInput, |
134 | | - options: ResolveRoutingPolicyOptions = {}, |
135 | | -): ResolvedRoutingPolicy { |
136 | | - const resolver = |
137 | | - options.resolver ?? new ModelProfileResolver(options.resolverOptions); |
138 | | - const modelProfile = resolver.resolve(input); |
139 | | - |
140 | | - return { |
141 | | - modelProfile, |
142 | | - routingPolicy: createRoutingPolicy( |
143 | | - modelProfile, |
144 | | - options.routingPolicyOverrides, |
145 | | - ), |
146 | | - }; |
147 | | -} |
| 1 | +export * from "./lib/Policy"; |
0 commit comments