|
| 1 | +import type { EmbeddingBackend } from "./EmbeddingBackend"; |
| 2 | + |
/** Number of components in an embedding when the caller does not specify `dimension`. */
export const DEFAULT_DUMMY_EMBEDDING_DIMENSION = 1024;

/**
 * Block size of SHA-256 in bytes. Used as the default boundary to which the
 * encoded text is zero-padded before hashing.
 */
export const SHA256_BLOCK_BYTES = 64;

/** Width in bytes of the unsigned counter appended to the hash input. */
const COUNTER_BYTES = 4;

/** Number of bytes in one SHA-256 digest; each byte yields one embedding component. */
const SHA256_DIGEST_BYTES = 32;

/** Divisor that maps a byte (0..255) onto 0..2, so that subtracting 1 gives -1..1. */
const BYTE_TO_UNIT_SCALE = 127.5;
| 9 | + |
/**
 * Construction options for `DeterministicDummyEmbeddingBackend`.
 *
 * Both fields are optional; when supplied they must be positive integers or
 * the constructor throws.
 */
export interface DeterministicDummyEmbeddingBackendOptions {
  /**
   * Byte boundary to which the UTF-8 encoded text is zero-padded before
   * hashing. Defaults to `SHA256_BLOCK_BYTES`.
   */
  blockBytes?: number;

  /**
   * Number of components in each produced embedding. Defaults to
   * `DEFAULT_DUMMY_EMBEDDING_DIMENSION`.
   */
  dimension?: number;
}
| 14 | + |
/**
 * Validates that `value` is an integer strictly greater than zero.
 *
 * @param name - Label of the value being checked; used as the prefix of the error message.
 * @param value - Number to validate.
 * @throws Error if `value` is not an integer (including `NaN` and infinities) or is `<= 0`.
 */
function assertPositiveInteger(name: string, value: number): void {
  const isPositiveInteger = Number.isInteger(value) && value > 0;
  if (isPositiveInteger) {
    return;
  }
  throw new Error(`${name} must be a positive integer`);
}
| 20 | + |
/**
 * Looks up the Web Crypto `SubtleCrypto` implementation on the global object.
 *
 * @returns The `globalThis.crypto.subtle` instance.
 * @throws Error if `globalThis.crypto` or its `subtle` property is missing.
 */
function getSubtleCrypto(): SubtleCrypto {
  const webCrypto = globalThis.crypto?.subtle;
  if (webCrypto) {
    return webCrypto;
  }
  throw new Error("SubtleCrypto is required for DeterministicDummyEmbeddingBackend");
}
| 28 | + |
/**
 * Extends `input` with trailing zero bytes so its length is a multiple of
 * `blockBytes`.
 *
 * When the length is already a multiple (which includes empty input) the
 * very same array instance is returned, not a copy. Otherwise a new array is
 * allocated and `input` is left untouched.
 *
 * Because the padding is plain zeros, two inputs that differ only by trailing
 * zero bytes inside the same block produce identical output.
 *
 * @param input - Bytes to align.
 * @param blockBytes - Alignment boundary in bytes; callers are expected to pass a positive integer.
 * @returns `input` itself if already aligned, otherwise a zero-padded copy.
 */
function padBytesToBoundary(input: Uint8Array, blockBytes: number): Uint8Array {
  const overhang = input.byteLength % blockBytes;
  if (overhang !== 0) {
    // A fresh Uint8Array is zero-filled, so only the original bytes need copying.
    const aligned = new Uint8Array(input.byteLength + (blockBytes - overhang));
    aligned.set(input, 0);
    return aligned;
  }
  return input;
}
| 40 | + |
/**
 * Linearly maps a byte value onto the interval [-1, 1]: 0 becomes -1 and 255
 * becomes 1.
 *
 * @param byteValue - Value expected to be in 0..255; no range check is performed.
 * @returns `byteValue / BYTE_TO_UNIT_SCALE - 1`.
 */
function byteToUnitFloat(byteValue: number): number {
  const scaled = byteValue / BYTE_TO_UNIT_SCALE;
  return scaled - 1;
}
| 44 | + |
/**
 * Embedding backend that derives vectors from SHA-256 instead of a model.
 *
 * For each text the UTF-8 bytes are zero-padded to a multiple of
 * `blockBytes`, then hashed repeatedly with a 4-byte big-endian counter
 * (0, 1, 2, ...) appended. Every digest byte becomes one component in
 * [-1, 1]; hashing continues until `dimension` components are filled. The
 * same text and options therefore always yield the same vector.
 */
export class DeterministicDummyEmbeddingBackend implements EmbeddingBackend {
  readonly kind = "dummy-sha256" as const;
  readonly dimension: number;

  private readonly blockBytes: number;
  // Resolved once at construction so a missing Web Crypto fails fast.
  private readonly subtle = getSubtleCrypto();
  private readonly encoder = new TextEncoder();

  /**
   * @param options - Optional overrides for the embedding dimension and padding boundary.
   * @throws Error if SubtleCrypto is unavailable, or if `dimension` / `blockBytes`
   *   resolve to anything other than a positive integer.
   */
  constructor(options: DeterministicDummyEmbeddingBackendOptions = {}) {
    this.dimension = options.dimension ?? DEFAULT_DUMMY_EMBEDDING_DIMENSION;
    this.blockBytes = options.blockBytes ?? SHA256_BLOCK_BYTES;

    assertPositiveInteger("dimension", this.dimension);
    assertPositiveInteger("blockBytes", this.blockBytes);
  }

  /**
   * Embeds every text independently.
   *
   * @param texts - Strings to embed; an empty array yields an empty result.
   * @returns One `Float32Array` of length `dimension` per input, in input order.
   */
  async embed(texts: string[]): Promise<Float32Array[]> {
    return Promise.all(texts.map((text) => this.embedOne(text)));
  }

  /**
   * Produces the embedding for a single text.
   *
   * @param text - String to embed.
   * @returns A `Float32Array` of length `dimension` with components in [-1, 1].
   */
  private async embedOne(text: string): Promise<Float32Array> {
    const sourceBytes = padBytesToBoundary(
      this.encoder.encode(text),
      this.blockBytes,
    );

    // The hash input is "padded text || counter". The text part never changes,
    // so build the buffer once and only rewrite the trailing counter bytes
    // instead of re-allocating and re-copying the text for every digest.
    const payload = new Uint8Array(sourceBytes.byteLength + COUNTER_BYTES);
    payload.set(sourceBytes, 0);
    const counterView = new DataView(
      payload.buffer,
      payload.byteOffset + sourceBytes.byteLength,
      COUNTER_BYTES,
    );

    const embedding = new Float32Array(this.dimension);
    let writeIndex = 0;

    for (let counter = 0; writeIndex < this.dimension; counter++) {
      // Big-endian so the byte layout does not depend on the host platform.
      counterView.setUint32(0, counter, false);

      // The digest is awaited before the next counter is written, so the
      // shared payload is never modified while a hash is in flight.
      const digest = new Uint8Array(
        await this.subtle.digest("SHA-256", payload),
      );

      // The last digest may only be partially consumed.
      const take = Math.min(SHA256_DIGEST_BYTES, this.dimension - writeIndex);
      for (const byte of digest.subarray(0, take)) {
        embedding[writeIndex] = byteToUnitFloat(byte);
        writeIndex++;
      }
    }

    return embedding;
  }
}
0 commit comments