diff --git a/src/index.ts b/src/index.ts
index 84dbd6b..9c19452 100644
--- a/src/index.ts
+++ b/src/index.ts
@@ -1,5 +1,3 @@
-import { connect, Connection, Tx, Config, FullResult } from '@tidbcloud/serverless';
-
 import {
   CompiledQuery,
   DatabaseConnection,
@@ -8,11 +6,22 @@ import {
   Driver,
   Kysely,
   MysqlAdapter,
-  MysqlQueryCompiler,
   MysqlIntrospector,
+  MysqlQueryCompiler,
   QueryCompiler,
   QueryResult,
 } from 'kysely';
+import { Config, Connection, FullResult, Tx, connect } from '@tidbcloud/serverless';
+import {
+  cosineDistance,
+  cosineSimilarity,
+  innerProduct,
+  l1Distance,
+  l2Distance,
+  negativeInnerProduct,
+  vectorFromSql,
+  vectorToSql,
+} from './vector';
 
 /**
  * Config for the TiDB Serverless dialect.
@@ -171,3 +180,14 @@ class TiDBServerlessTransaction {
     await this.#tx.rollback();
   }
 }
+
+export {
+  cosineDistance,
+  cosineSimilarity,
+  innerProduct,
+  l1Distance,
+  l2Distance,
+  negativeInnerProduct,
+  vectorFromSql,
+  vectorToSql,
+};
diff --git a/src/vector.ts b/src/vector.ts
new file mode 100644
index 0000000..1137988
--- /dev/null
+++ b/src/vector.ts
@@ -0,0 +1,62 @@
+import { ExpressionBuilder, ReferenceExpression } from 'kysely';
+
+type VectorLike = Float32Array | number[];
+
+export function l1Distance<DB, TB extends keyof DB, RE extends ReferenceExpression<DB, TB> = ReferenceExpression<DB, TB>>(
+  eb: ExpressionBuilder<DB, TB>,
+  column: RE,
+  value: VectorLike
+) {
+  return eb.fn('VEC_L1_DISTANCE', [column, (eb) => eb.val(vectorToSql(value))]);
+}
+
+export function l2Distance<DB, TB extends keyof DB, RE extends ReferenceExpression<DB, TB> = ReferenceExpression<DB, TB>>(
+  eb: ExpressionBuilder<DB, TB>,
+  column: RE,
+  value: VectorLike
+) {
+  return eb.fn('VEC_L2_DISTANCE', [column, (eb) => eb.val(vectorToSql(value))]);
+}
+
+export function negativeInnerProduct<
+  DB,
+  TB extends keyof DB,
+  RE extends ReferenceExpression<DB, TB> = ReferenceExpression<DB, TB>
+>(eb: ExpressionBuilder<DB, TB>, column: RE, value: VectorLike) {
+  return eb.fn('VEC_NEGATIVE_INNER_PRODUCT', [column, (eb) => eb.val(vectorToSql(value))]);
+}
+
+export function innerProduct<DB, TB extends keyof DB, RE extends ReferenceExpression<DB, TB> = ReferenceExpression<DB, TB>>(
+  eb: ExpressionBuilder<DB, TB>,
+  column: RE,
+  value: VectorLike
+) {
+  return eb.neg(negativeInnerProduct(eb, column, value));
+}
+
+export function cosineDistance<DB, TB extends keyof DB, RE extends ReferenceExpression<DB, TB> = ReferenceExpression<DB, TB>>(
+  eb: ExpressionBuilder<DB, TB>,
+  column: RE,
+  value: VectorLike
+) {
+  return eb.fn('VEC_COSINE_DISTANCE', [column, (eb) => eb.val(vectorToSql(value))]);
+}
+
+export function cosineSimilarity<DB, TB extends keyof DB, RE extends ReferenceExpression<DB, TB> = ReferenceExpression<DB, TB>>(
+  eb: ExpressionBuilder<DB, TB>,
+  column: RE,
+  value: VectorLike
+) {
+  return eb(eb.lit(1), '-', cosineDistance(eb, column, value));
+}
+
+export function vectorFromSql(value: string) {
+  return value
+    .substring(1, value.length - 1)
+    .split(',')
+    .map((v) => parseFloat(v));
+}
+
+export function vectorToSql(vector: VectorLike) {
+  return `[${vector.join(',')}]`;
+}