diff --git a/.github/workflows/release-feature-branch.yaml b/.github/workflows/release-feature-branch.yaml index 414f98f053..db482b5fc5 100644 --- a/.github/workflows/release-feature-branch.yaml +++ b/.github/workflows/release-feature-branch.yaml @@ -237,7 +237,7 @@ jobs: - uses: actions/setup-node@v4 with: - node-version: '18.18' + node-version: '22' registry-url: 'https://registry.npmjs.org' - uses: pnpm/action-setup@v3 @@ -334,7 +334,7 @@ jobs: - uses: actions/setup-node@v4 with: - node-version: '18.18' + node-version: '22' registry-url: 'https://registry.npmjs.org' - uses: pnpm/action-setup@v3 @@ -415,4 +415,4 @@ jobs: working-directory: ${{ matrix.package }} shell: bash env: - NODE_AUTH_TOKEN: ${{ secrets.NPM_ACCESS_TOKEN }} + NODE_AUTH_TOKEN: ${{ secrets.NPM_ACCESS_TOKEN }} \ No newline at end of file diff --git a/.github/workflows/release-latest.yaml b/.github/workflows/release-latest.yaml index 13e37d4e58..6f31f99c5b 100644 --- a/.github/workflows/release-latest.yaml +++ b/.github/workflows/release-latest.yaml @@ -360,7 +360,7 @@ jobs: - uses: actions/setup-node@v4 with: - node-version: '18.18' + node-version: '22' registry-url: 'https://registry.npmjs.org' - uses: pnpm/action-setup@v3 diff --git a/.github/workflows/unpublish-release-feature-branch.yaml b/.github/workflows/unpublish-release-feature-branch.yaml index e963a0a461..4c27303f55 100644 --- a/.github/workflows/unpublish-release-feature-branch.yaml +++ b/.github/workflows/unpublish-release-feature-branch.yaml @@ -21,7 +21,7 @@ jobs: - uses: actions/setup-node@v4 with: - node-version: '18.18' + node-version: '22' registry-url: 'https://registry.npmjs.org' - name: Unpublish diff --git a/.nvmrc b/.nvmrc index 4a58985bb4..8fdd954df9 100644 --- a/.nvmrc +++ b/.nvmrc @@ -1 +1 @@ -18.18 +22 \ No newline at end of file diff --git a/changelogs/drizzle-orm/0.44.0.md b/changelogs/drizzle-orm/0.44.0.md new file mode 100644 index 0000000000..bc04ee57db --- /dev/null +++ b/changelogs/drizzle-orm/0.44.0.md @@ -0,0 
+1,91 @@ +## Error handling + +Starting from this version, we’ve introduced a new `DrizzleQueryError` that wraps all errors from database drivers and provides a set of useful information: + +1. A proper stack trace to identify which exact `Drizzle` query failed +2. The generated SQL string and its parameters +3. The original stack trace from the driver that caused the DrizzleQueryError + +## Drizzle `cache` module + +Drizzle sends every query straight to your database by default. There are no hidden actions, no automatic caching or invalidation - you’ll always see exactly what runs. If you want caching, you must opt in. + +By default, Drizzle uses a explicit caching strategy (i.e. `global: false`), so nothing is ever cached unless you ask. This prevents surprises or hidden performance traps in your application. Alternatively, you can flip on all caching (global: true) so that every select will look in cache first. + +Out first native integration was built together with Upstash team and let you natively use `upstash` as a cache for your drizzle queries + +```ts +import { upstashCache } from "drizzle-orm/cache/upstash"; +import { drizzle } from "drizzle-orm/..."; + +const db = drizzle(process.env.DB_URL!, { + cache: upstashCache({ + // 👇 Redis credentials (optional — can also be pulled from env vars) + url: '', + token: '', + // 👇 Enable caching for all queries by default (optional) + global: true, + // 👇 Default cache behavior (optional) + config: { ex: 60 } + }) +}); +``` + +You can also implement your own cache, as Drizzle exposes all the necessary APIs, such as get, put, mutate, etc. +You can find full implementation details on the [website](https://orm.drizzle.team/docs/cache#custom-cache) + +```ts +import Keyv from "keyv"; +export class TestGlobalCache extends Cache { + private globalTtl: number = 1000; + // This object will be used to store which query keys were used + // for a specific table, so we can later use it for invalidation. 
+ private usedTablesPerKey: Record = {}; + constructor(private kv: Keyv = new Keyv()) { + super(); + } + // For the strategy, we have two options: + // - 'explicit': The cache is used only when .$withCache() is added to a query. + // - 'all': All queries are cached globally. + // The default behavior is 'explicit'. + override strategy(): "explicit" | "all" { + return "all"; + } + // This function accepts query and parameters that cached into key param, + // allowing you to retrieve response values for this query from the cache. + override async get(key: string): Promise { + ... + } + // This function accepts several options to define how cached data will be stored: + // - 'key': A hashed query and parameters. + // - 'response': An array of values returned by Drizzle from the database. + // - 'tables': An array of tables involved in the select queries. This information is needed for cache invalidation. + // + // For example, if a query uses the "users" and "posts" tables, you can store this information. Later, when the app executes + // any mutation statements on these tables, you can remove the corresponding key from the cache. + // If you're okay with eventual consistency for your queries, you can skip this option. + override async put( + key: string, + response: any, + tables: string[], + config?: CacheConfig, + ): Promise { + ... + } + // This function is called when insert, update, or delete statements are executed. + // You can either skip this step or invalidate queries that used the affected tables. + // + // The function receives an object with two keys: + // - 'tags': Used for queries labeled with a specific tag, allowing you to invalidate by that tag. + // - 'tables': The actual tables affected by the insert, update, or delete statements, + // helping you track which tables have changed since the last cache update. + override async onMutate(params: { + tags: string | string[]; + tables: string | string[] | Table | Table[]; + }): Promise { + ... 
+ } +} +``` + +For more usage example you can check our [docs](https://orm.drizzle.team/docs/cache#cache-usage-examples) \ No newline at end of file diff --git a/drizzle-orm/package.json b/drizzle-orm/package.json index 66aae63cf4..411fd39db1 100644 --- a/drizzle-orm/package.json +++ b/drizzle-orm/package.json @@ -1,6 +1,6 @@ { "name": "drizzle-orm", - "version": "0.43.1", + "version": "0.44.0", "description": "Drizzle ORM package for SQL databases", "type": "module", "scripts": { @@ -64,14 +64,15 @@ "better-sqlite3": ">=7", "bun-types": "*", "expo-sqlite": ">=14.0.0", - "gel": ">=2", "knex": "*", "kysely": "*", "mysql2": ">=2", "pg": ">=8", "postgres": ">=3", "sql.js": ">=1", - "sqlite3": ">=5" + "sqlite3": ">=5", + "gel": ">=2", + "@upstash/redis": ">=1.34.7" }, "peerDependenciesMeta": { "mysql2": { @@ -157,6 +158,9 @@ }, "@prisma/client": { "optional": true + }, + "@upstash/redis": { + "optional": true } }, "devDependencies": { @@ -173,11 +177,12 @@ "@planetscale/database": "^1.16.0", "@prisma/client": "5.14.0", "@tidbcloud/serverless": "^0.1.1", - "@types/better-sqlite3": "^7.6.4", + "@types/better-sqlite3": "^7.6.12", "@types/node": "^20.2.5", "@types/pg": "^8.10.1", "@types/react": "^18.2.45", "@types/sql.js": "^1.4.4", + "@upstash/redis": "^1.34.3", "@vercel/postgres": "^0.8.0", "@xata.io/client": "^0.29.3", "better-sqlite3": "^11.9.1", diff --git a/drizzle-orm/src/aws-data-api/pg/driver.ts b/drizzle-orm/src/aws-data-api/pg/driver.ts index 1a02723a65..7395930ad0 100644 --- a/drizzle-orm/src/aws-data-api/pg/driver.ts +++ b/drizzle-orm/src/aws-data-api/pg/driver.ts @@ -21,6 +21,7 @@ import { AwsDataApiSession } from './session.ts'; export interface PgDriverOptions { logger?: Logger; + cache?: Cache; database: string; resourceArn: string; secretArn: string; @@ -118,9 +119,13 @@ function construct = Record db).$client = client; + ( db).$cache = config.cache; + if (( db).$cache) { + ( db).$cache['invalidate'] = config.cache?.onMutate; + } return db as any; } diff 
--git a/drizzle-orm/src/aws-data-api/pg/session.ts b/drizzle-orm/src/aws-data-api/pg/session.ts index 974f6d3ff6..6c915967f7 100644 --- a/drizzle-orm/src/aws-data-api/pg/session.ts +++ b/drizzle-orm/src/aws-data-api/pg/session.ts @@ -5,6 +5,9 @@ import { ExecuteStatementCommand, RollbackTransactionCommand, } from '@aws-sdk/client-rds-data'; +import type { Cache } from '~/cache/core/cache.ts'; +import { NoopCache } from '~/cache/core/cache.ts'; +import type { WithCacheConfig } from '~/cache/core/types.ts'; import { entityKind } from '~/entity.ts'; import type { Logger } from '~/logger.ts'; import { @@ -33,17 +36,23 @@ export class AwsDataApiPreparedQuery< constructor( private client: AwsDataApiClient, - queryString: string, + private queryString: string, private params: unknown[], private typings: QueryTypingsValue[], private options: AwsDataApiSessionOptions, + cache: Cache, + queryMetadata: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + } | undefined, + cacheConfig: WithCacheConfig | undefined, private fields: SelectedFieldsOrdered | undefined, /** @internal */ readonly transactionId: string | undefined, private _isResponseInArrayMode: boolean, private customResultMapper?: (rows: unknown[][]) => T['execute'], ) { - super({ sql: queryString, params }); + super({ sql: queryString, params }, cache, queryMetadata, cacheConfig); this.rawQuery = new ExecuteStatementCommand({ sql: queryString, parameters: [], @@ -108,7 +117,9 @@ export class AwsDataApiPreparedQuery< this.options.logger?.logQuery(this.rawQuery.input.sql!, this.rawQuery.input.parameters); - const result = await this.client.send(this.rawQuery); + const result = await this.queryWithCache(this.queryString, params, async () => { + return await this.client.send(this.rawQuery); + }); const rows = result.records?.map((row) => { return row.map((field) => getValueFromDataApi(field)); }) ?? 
[]; @@ -139,6 +150,7 @@ export class AwsDataApiPreparedQuery< export interface AwsDataApiSessionOptions { logger?: Logger; + cache?: Cache; database: string; resourceArn: string; secretArn: string; @@ -158,6 +170,7 @@ export class AwsDataApiSession< /** @internal */ readonly rawQuery: AwsDataApiQueryBase; + private cache: Cache; constructor( /** @internal */ @@ -174,6 +187,7 @@ export class AwsDataApiSession< resourceArn: options.resourceArn, database: options.database, }; + this.cache = options.cache ?? new NoopCache(); } prepareQuery< @@ -188,6 +202,8 @@ export class AwsDataApiSession< name: string | undefined, isResponseInArrayMode: boolean, customResultMapper?: (rows: unknown[][]) => T['execute'], + queryMetadata?: { type: 'select' | 'update' | 'delete' | 'insert'; tables: string[] }, + cacheConfig?: WithCacheConfig, transactionId?: string, ): AwsDataApiPreparedQuery { return new AwsDataApiPreparedQuery( @@ -196,6 +212,9 @@ export class AwsDataApiSession< query.params, query.typings ?? [], this.options, + this.cache, + queryMetadata, + cacheConfig, fields, transactionId ?? 
this.transactionId, isResponseInArrayMode, @@ -210,6 +229,8 @@ export class AwsDataApiSession< undefined, false, undefined, + undefined, + undefined, this.transactionId, ).execute(); } diff --git a/drizzle-orm/src/better-sqlite3/driver.ts b/drizzle-orm/src/better-sqlite3/driver.ts index 68eac53aca..fc91b8f934 100644 --- a/drizzle-orm/src/better-sqlite3/driver.ts +++ b/drizzle-orm/src/better-sqlite3/driver.ts @@ -29,7 +29,7 @@ export class BetterSQLite3Database = Rec function construct = Record>( client: Database, - config: DrizzleConfig = {}, + config: Omit, 'cache'> = {}, ): BetterSQLite3Database & { $client: Database; } { @@ -57,6 +57,10 @@ function construct = Record db).$client = client; + // ( db).$cache = config.cache; + // if (( db).$cache) { + // ( db).$cache['invalidate'] = config.cache?.onMutate; + // } return db as any; } diff --git a/drizzle-orm/src/better-sqlite3/session.ts b/drizzle-orm/src/better-sqlite3/session.ts index 8a02eb37e2..3f67934651 100644 --- a/drizzle-orm/src/better-sqlite3/session.ts +++ b/drizzle-orm/src/better-sqlite3/session.ts @@ -1,4 +1,6 @@ import type { Database, RunResult, Statement } from 'better-sqlite3'; +import { type Cache, NoopCache } from '~/cache/core/index.ts'; +import type { WithCacheConfig } from '~/cache/core/types.ts'; import { entityKind } from '~/entity.ts'; import type { Logger } from '~/logger.ts'; import { NoopLogger } from '~/logger.ts'; @@ -18,6 +20,7 @@ import { mapResultRow } from '~/utils.ts'; export interface BetterSQLiteSessionOptions { logger?: Logger; + cache?: Cache; } type PreparedQueryConfig = Omit; @@ -29,6 +32,7 @@ export class BetterSQLiteSession< static override readonly [entityKind]: string = 'BetterSQLiteSession'; private logger: Logger; + private cache: Cache; constructor( private client: Database, @@ -38,6 +42,7 @@ export class BetterSQLiteSession< ) { super(dialect); this.logger = options.logger ?? new NoopLogger(); + this.cache = options.cache ?? 
new NoopCache(); } prepareQuery>( @@ -46,12 +51,20 @@ export class BetterSQLiteSession< executeMethod: SQLiteExecuteMethod, isResponseInArrayMode: boolean, customResultMapper?: (rows: unknown[][]) => unknown, + queryMetadata?: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + }, + cacheConfig?: WithCacheConfig, ): PreparedQuery { const stmt = this.client.prepare(query.sql); return new PreparedQuery( stmt, query, this.logger, + this.cache, + queryMetadata, + cacheConfig, fields, executeMethod, isResponseInArrayMode, @@ -99,12 +112,18 @@ export class PreparedQuery private stmt: Statement, query: Query, private logger: Logger, + cache: Cache, + queryMetadata: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + } | undefined, + cacheConfig: WithCacheConfig | undefined, private fields: SelectedFieldsOrdered | undefined, executeMethod: SQLiteExecuteMethod, private _isResponseInArrayMode: boolean, private customResultMapper?: (rows: unknown[][]) => unknown, ) { - super('sync', executeMethod, query); + super('sync', executeMethod, query, cache, queryMetadata, cacheConfig); } run(placeholderValues?: Record): RunResult { diff --git a/drizzle-orm/src/bun-sql/driver.ts b/drizzle-orm/src/bun-sql/driver.ts index 1b2c42c4f4..8e930510f6 100644 --- a/drizzle-orm/src/bun-sql/driver.ts +++ b/drizzle-orm/src/bun-sql/driver.ts @@ -49,9 +49,13 @@ function construct = Record; ( db).$client = client; + ( db).$cache = config.cache; + if (( db).$cache) { + ( db).$cache['invalidate'] = config.cache?.onMutate; + } return db as any; } diff --git a/drizzle-orm/src/bun-sql/session.ts b/drizzle-orm/src/bun-sql/session.ts index 17fe520c41..6aa62faa0b 100644 --- a/drizzle-orm/src/bun-sql/session.ts +++ b/drizzle-orm/src/bun-sql/session.ts @@ -1,6 +1,8 @@ /// import type { SavepointSQL, SQL, TransactionSQL } from 'bun'; +import { type Cache, NoopCache } from '~/cache/core/index.ts'; +import type { WithCacheConfig } from '~/cache/core/types.ts'; import 
{ entityKind } from '~/entity.ts'; import type { Logger } from '~/logger.ts'; import { NoopLogger } from '~/logger.ts'; @@ -22,11 +24,17 @@ export class BunSQLPreparedQuery extends PgPrepar private queryString: string, private params: unknown[], private logger: Logger, + cache: Cache, + queryMetadata: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + } | undefined, + cacheConfig: WithCacheConfig | undefined, private fields: SelectedFieldsOrdered | undefined, private _isResponseInArrayMode: boolean, private customResultMapper?: (rows: unknown[][]) => T['execute'], ) { - super({ sql: queryString, params }); + super({ sql: queryString, params }, cache, queryMetadata, cacheConfig); } async execute(placeholderValues: Record | undefined = {}): Promise { @@ -42,18 +50,22 @@ export class BunSQLPreparedQuery extends PgPrepar const { fields, queryString: query, client, joinsNotNullableMap, customResultMapper } = this; if (!fields && !customResultMapper) { - return tracer.startActiveSpan('drizzle.driver.execute', () => { - return client.unsafe(query, params as any[]); + return tracer.startActiveSpan('drizzle.driver.execute', async () => { + return await this.queryWithCache(query, params, async () => { + return await client.unsafe(query, params as any[]); + }); }); } - const rows: any[] = await tracer.startActiveSpan('drizzle.driver.execute', () => { + const rows: any[] = await tracer.startActiveSpan('drizzle.driver.execute', async () => { span?.setAttributes({ 'drizzle.query.text': query, 'drizzle.query.params': JSON.stringify(params), }); - return client.unsafe(query, params as any[]).values(); + return await this.queryWithCache(query, params, async () => { + return client.unsafe(query, params as any[]).values(); + }); }); return tracer.startActiveSpan('drizzle.mapResponse', () => { @@ -72,12 +84,14 @@ export class BunSQLPreparedQuery extends PgPrepar 'drizzle.query.params': JSON.stringify(params), }); this.logger.logQuery(this.queryString, params); 
- return tracer.startActiveSpan('drizzle.driver.execute', () => { + return tracer.startActiveSpan('drizzle.driver.execute', async () => { span?.setAttributes({ 'drizzle.query.text': this.queryString, 'drizzle.query.params': JSON.stringify(params), }); - return this.client.unsafe(this.queryString, params as any[]); + return await this.queryWithCache(this.queryString, params, async () => { + return await this.client.unsafe(this.queryString, params as any[]); + }); }); }); } @@ -90,6 +104,7 @@ export class BunSQLPreparedQuery extends PgPrepar export interface BunSQLSessionOptions { logger?: Logger; + cache?: Cache; } export class BunSQLSession< @@ -100,6 +115,7 @@ export class BunSQLSession< static override readonly [entityKind]: string = 'BunSQLSession'; logger: Logger; + private cache: Cache; constructor( public client: TSQL, @@ -110,6 +126,7 @@ export class BunSQLSession< ) { super(dialect); this.logger = options.logger ?? new NoopLogger(); + this.cache = options.cache ?? new NoopCache(); } prepareQuery( @@ -118,12 +135,20 @@ export class BunSQLSession< name: string | undefined, isResponseInArrayMode: boolean, customResultMapper?: (rows: unknown[][]) => T['execute'], + queryMetadata?: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + }, + cacheConfig?: WithCacheConfig, ): PgPreparedQuery { return new BunSQLPreparedQuery( this.client, query.sql, query.params, this.logger, + this.cache, + queryMetadata, + cacheConfig, fields, isResponseInArrayMode, customResultMapper, diff --git a/drizzle-orm/src/cache/core/cache.ts b/drizzle-orm/src/cache/core/cache.ts new file mode 100644 index 0000000000..4f71a273f0 --- /dev/null +++ b/drizzle-orm/src/cache/core/cache.ts @@ -0,0 +1,78 @@ +import { entityKind } from '~/entity.ts'; +import type { Table } from '~/index.ts'; +import type { CacheConfig } from './types.ts'; + +export abstract class Cache { + static readonly [entityKind]: string = 'Cache'; + + abstract strategy(): 'explicit' | 'all'; + + /** + * 
Invoked if we should check cache for cached response + * @param sql + * @param tables + */ + abstract get( + key: string, + tables: string[], + isTag: boolean, + isAutoInvalidate?: boolean, + ): Promise; + + /** + * Invoked if new query should be inserted to cache + * @param sql + * @param tables + */ + abstract put( + hashedQuery: string, + response: any, + tables: string[], + isTag: boolean, + config?: CacheConfig, + ): Promise; + + /** + * Invoked if insert, update, delete was invoked + * @param tables + */ + abstract onMutate( + params: MutationOption, + ): Promise; +} + +export class NoopCache extends Cache { + override strategy() { + return 'all' as const; + } + + static override readonly [entityKind]: string = 'NoopCache'; + + override async get(_key: string): Promise { + return undefined; + } + override async put( + _hashedQuery: string, + _response: any, + _tables: string[], + _config?: any, + ): Promise { + // noop + } + override async onMutate(_params: MutationOption): Promise { + // noop + } +} + +export type MutationOption = { tags?: string | string[]; tables?: Table | Table[] | string | string[] }; + +export async function hashQuery(sql: string, params?: any[]) { + const dataToHash = `${sql}-${JSON.stringify(params)}`; + const encoder = new TextEncoder(); + const data = encoder.encode(dataToHash); + const hashBuffer = await crypto.subtle.digest('SHA-256', data); + const hashArray = [...new Uint8Array(hashBuffer)]; + const hashHex = hashArray.map((b) => b.toString(16).padStart(2, '0')).join(''); + + return hashHex; +} diff --git a/drizzle-orm/src/cache/core/index.ts b/drizzle-orm/src/cache/core/index.ts new file mode 100644 index 0000000000..ba0d58441d --- /dev/null +++ b/drizzle-orm/src/cache/core/index.ts @@ -0,0 +1 @@ +export * from './cache.ts'; diff --git a/drizzle-orm/src/cache/core/types.ts b/drizzle-orm/src/cache/core/types.ts new file mode 100644 index 0000000000..1b31193cec --- /dev/null +++ b/drizzle-orm/src/cache/core/types.ts @@ -0,0 +1,29 
@@ +export type CacheConfig = { + /** + * expire time, in seconds (a positive integer) + */ + ex?: number; + /** + * expire time, in milliseconds (a positive integer). + */ + px?: number; + /** + * Unix time at which the key will expire, in seconds (a positive integer). + */ + exat?: number; + /** + * Unix time at which the key will expire, in milliseconds (a positive integer) + */ + pxat?: number; + /** + * Retain the time to live associated with the key. + */ + keepTtl?: boolean; + /** + * Set an expiration (TTL or time to live) on one or more fields of a given hash key. + * Used for HEXPIRE command + */ + hexOptions?: 'NX' | 'nx' | 'XX' | 'xx' | 'GT' | 'gt' | 'LT' | 'lt'; +}; + +export type WithCacheConfig = { enable: boolean; config?: CacheConfig; tag?: string; autoInvalidate?: boolean }; diff --git a/drizzle-orm/src/cache/readme.md b/drizzle-orm/src/cache/readme.md new file mode 100644 index 0000000000..0c1a8b63e1 --- /dev/null +++ b/drizzle-orm/src/cache/readme.md @@ -0,0 +1,230 @@ +## Caching with Drizzle + +By default, Drizzle does not perform any implicit actions with your queries and mapping. There is no cache under the hood—each query is sent directly to your database, and you can actually see it. + +However, there are cases when you might want to implement a simple caching logic for specific queries or even for all queries. With Drizzle's cache option, you can define how and when the cache is used, how you store and retrieve data, and what actions to take when write statements are executed on the database. It's basically similar to `beforeQuery` hooks, that will be invoked before actual query will be executed. Additionally, Drizzle provides predefined logic for caching. Let's take a look at it. 
+ +To make cache work you would need to define cache callbacks in drizzle instance or use a predefined ones we have in Drizzle, like a `upstashCache()` that was built together with Upstash team + +### Cache overview + +**Using upstash cache with drizzle** + +```ts +const db = drizzle(process.env.DB_URL!, { cache: upstashCache() }) +``` + +You can also define custom logic for your cache behavior. This is an example of our NodeKV implementation for the Drizzle cache test suites + +```ts +const db = drizzle(process.env.DB_URL!, { cache: new TestGlobalCache() }) +``` + +```ts +import Keyv from 'keyv'; + +export class TestGlobalCache extends Cache { + private globalTtl: number = 1000; + // This object will be used to store which query keys were used + // for a specific table, so we can later use it for invalidation. + private usedTablesPerKey: Record = {}; + + constructor(private kv: Keyv = new Keyv()) { + super(); + } + + // For the strategy, we have two options: + // - 'explicit': The cache is used only when .$withCache() is added to a query. + // - 'all': All queries are cached globally. + // The default behavior is 'explicit'. + override strategy(): 'explicit' | 'all' { + return 'all'; + } + + // This function accepts query and parameters that cached into key param, + // allowing you to retrieve response values for this query from the cache. + override async get(key: string): Promise { + const res = await this.kv.get(key) ?? undefined; + return res; + } + + // This function accepts several options to define how cached data will be stored: + // - 'key': A hashed query and parameters. + // - 'response': An array of values returned by Drizzle from the database. + // - 'tables': An array of tables involved in the select queries. This information is needed for cache invalidation. + // + // For example, if a query uses the "users" and "posts" tables, you can store this information. 
Later, when the app executes + // any mutation statements on these tables, you can remove the corresponding key from the cache. + // If you're okay with eventual consistency for your queries, you can skip this option. + override async put(key: string, response: any, tables: string[], config?: CacheConfig): Promise { + await this.kv.set(key, response, config ? config.ex : this.globalTtl); + for (const table of tables) { + const keys = this.usedTablesPerKey[table]; + if (keys === undefined) { + this.usedTablesPerKey[table] = [key]; + } else { + keys.push(key); + } + } + } + + // This function is called when insert, update, or delete statements are executed. + // You can either skip this step or invalidate queries that used the affected tables. + // + // The function receives an object with two keys: + // - 'tags': Used for queries labeled with a specific tag, allowing you to invalidate by that tag. + // - 'tables': The actual tables affected by the insert, update, or delete statements, + // helping you track which tables have changed since the last cache update. + override async onMutate(params: { tags: string | string[], tables: string | string[] | Table | Table[]}): Promise { + const tagsArray = params.tags ? Array.isArray(params.tags) ? params.tags : [params.tags] : []; + const tablesArray = params.tables ? Array.isArray(params.tables) ? params.tables : [params.tables] : []; + + const keysToDelete = new Set(); + + for (const table of tablesArray) { + const tableName = is(table, Table) ? getTableName(table) : table as string; + const keys = this.usedTablesPerKey[tableName] ?? []; + for (const key of keys) keysToDelete.add(key); + } + + if (keysToDelete.size > 0 || tagsArray.length > 0) { + for (const tag of tagsArray) { + await this.kv.delete(tag); + } + + for (const key of keysToDelete) { + await this.kv.delete(key); + for (const table of tablesArray) { + const tableName = is(table, Table) ? 
getTableName(table) : table as string; + this.usedTablesPerKey[tableName] = []; + } + } + } + } +} +``` + +### Cache definition + +**Define cache credentials, but no cache will be used globally for all queries** + +```ts +const db = drizzle(process.env.DB_URL!, { cache: upstashCache({ url: '', token: '' }) }) +``` + +**Define cache credentials, and the cache will be used globally for all queries** + +```ts +const db = drizzle(process.env.DB_URL!, { cache: upstashCache({ url: '', token: '', global: true }) }) +``` + +**Define cache credentials with custom config values to be used for all queries, unless overridden** + +```ts +const db = drizzle(process.env.DB_URL!, { cache: upstashCache({ url: '', token: '', global: true, config: {} }) }) +``` + +These are all the possible config values that Drizzle supports with the cache layer + +```ts +export type CacheConfig = { + /** + * expire time, in seconds (a positive integer) + */ + ex?: number; + /** + * expire time, in milliseconds (a positive integer). + */ + px?: number; + /** + * Unix time at which the key will expire, in seconds (a positive integer). + */ + exat?: number; + /** + * Unix time at which the key will expire, in milliseconds (a positive integer) + */ + pxat?: number; + /** + * Retain the time to live associated with the key. + */ + keepTtl?: boolean; +}; + +``` + +### Cache usage + +Once you've provided all the necessary instructions to the Drizzle database instance, you can now use the cache with Drizzle + +**Case 1: Drizzle with global: false option** + +```ts +const db = drizzle(process.env.DB_URL!, { cache: upstashCache({ url: '', token: '' }) }) +``` + +In this case, the current query won't use the cache + +```ts +const res = await db.select().from(users) + +// However, any mutate operation will trigger the onMutate function in the cache +// and attempt to invalidate queries that used the tables involved in this mutation query. 
+await db.insert(users).value({ email: 'cacheman@upstash.com' }) +``` + +If you want the query to actually use the cache, you need to call `.$withCache()` + +```ts +const res = await db.select().from(users).$withCache() +``` + +`.$withCache` has a set of options you can use to manage and config this specific query strategy + +```ts +// rewrite the global config options for this specific query +.$withCache({ config: {} }) + +// give a query custom cache key instead of hashing query+params under the hood +.$withCache({ tag: 'custom_key' }) + +// disable autoinvalidation for this query, if you are fine with eventual consstnecy for this specific query +.$withCache({ autoInvalidate: false }) +``` + +**Case 2: Drizzle with global: true option** + +```ts +const db = drizzle(process.env.DB_URL!, { cache: upstashCache({ url: '', token: '', global: true }) }) +``` + +In this case, the current query will use the cache + +```ts +const res = await db.select().from(users) +``` + +If you want the query to disable cache for some specific query, you need to call `.$withCache(false)` + +```ts +// cache is disabled for this query +const res = await db.select().from(users).$withCache(false) +``` + +You can also use cache instance from a `db` to force invalidate specific tables or tags you've defined previously + +```ts +// Invalidate all queries that use the `users` table. You can do this with the Drizzle instance. +await db.$cache?.invalidate({ tables: users }); +// or +await db.$cache?.invalidate({ tables: [users, posts] }); + +// Invalidate all queries that use the `usersTable`. You can do this by using just the table name. +await db.$cache?.invalidate({ tables: 'usersTable' }); +// or +await db.$cache?.invalidate({ tables: ['usersTable' , 'postsTable' ] }); + +// You can also invalidate custom tags defined in any previously executed select queries. 
+await db.$cache?.invalidate({ tags: 'custom_key' }); +// or +await db.$cache?.invalidate({ tags: ['custom_key', 'custom_key1'] }); +``` diff --git a/drizzle-orm/src/cache/upstash/cache.ts b/drizzle-orm/src/cache/upstash/cache.ts new file mode 100644 index 0000000000..8e6a54c1a8 --- /dev/null +++ b/drizzle-orm/src/cache/upstash/cache.ts @@ -0,0 +1,213 @@ +import { Redis } from '@upstash/redis'; +import type { MutationOption } from '~/cache/core/index.ts'; +import { Cache } from '~/cache/core/index.ts'; +import { entityKind, is } from '~/entity.ts'; +import { OriginalName, Table } from '~/index.ts'; +import type { CacheConfig } from '../core/types.ts'; + +const getByTagScript = ` +local tagsMapKey = KEYS[1] -- tags map key +local tag = ARGV[1] -- tag + +local compositeTableName = redis.call('HGET', tagsMapKey, tag) +if not compositeTableName then + return nil +end + +local value = redis.call('HGET', compositeTableName, tag) +return value +`; + +const onMutateScript = ` +local tagsMapKey = KEYS[1] -- tags map key +local tables = {} -- initialize tables array +local tags = ARGV -- tags array + +for i = 2, #KEYS do + tables[#tables + 1] = KEYS[i] -- add all keys except the first one to tables +end + +if #tags > 0 then + for _, tag in ipairs(tags) do + if tag ~= nil and tag ~= '' then + local compositeTableName = redis.call('HGET', tagsMapKey, tag) + if compositeTableName then + redis.call('HDEL', compositeTableName, tag) + end + end + end + redis.call('HDEL', tagsMapKey, unpack(tags)) +end + +local keysToDelete = {} + +if #tables > 0 then + local compositeTableNames = redis.call('SUNION', unpack(tables)) + for _, compositeTableName in ipairs(compositeTableNames) do + keysToDelete[#keysToDelete + 1] = compositeTableName + end + for _, table in ipairs(tables) do + keysToDelete[#keysToDelete + 1] = table + end + redis.call('DEL', unpack(keysToDelete)) +end +`; + +type Script = ReturnType; + +type ExpireOptions = 'NX' | 'nx' | 'XX' | 'xx' | 'GT' | 'gt' | 'LT' | 'lt'; + 
+export class UpstashCache extends Cache { + static override readonly [entityKind]: string = 'UpstashCache'; + /** + * Prefix for sets which denote the composite table names for each unique table + * + * Example: In the composite table set of "table1", you may find + * `${compositeTablePrefix}table1,table2` and `${compositeTablePrefix}table1,table3` + */ + private static compositeTableSetPrefix = '__CTS__'; + /** + * Prefix for hashes which map hash or tags to cache values + */ + private static compositeTablePrefix = '__CT__'; + /** + * Key which holds the mapping of tags to composite table names + * + * Using this tagsMapKey, you can find the composite table name for a given tag + * and get the cache value for that tag: + * + * ```ts + * const compositeTable = redis.hget(tagsMapKey, 'tag1') + * console.log(compositeTable) // `${compositeTablePrefix}table1,table2` + * + * const cachevalue = redis.hget(compositeTable, 'tag1') + */ + private static tagsMapKey = '__tagsMap__'; + /** + * Queries whose auto invalidation is false aren't stored in their respective + * composite table hashes because those hashes are deleted when a mutation + * occurs on related tables. + * + * Instead, they are stored in a separate hash with the prefix + * `__nonAutoInvalidate__` to prevent them from being deleted when a mutation + */ + private static nonAutoInvalidateTablePrefix = '__nonAutoInvalidate__'; + + private luaScripts: { + getByTagScript: Script; + onMutateScript: Script; + }; + + private internalConfig: { seconds: number; hexOptions?: ExpireOptions }; + + constructor(public redis: Redis, config?: CacheConfig, protected useGlobally?: boolean) { + super(); + this.internalConfig = this.toInternalConfig(config); + this.luaScripts = { + getByTagScript: this.redis.createScript(getByTagScript, { readonly: true }), + onMutateScript: this.redis.createScript(onMutateScript), + }; + } + + public strategy() { + return this.useGlobally ? 
'all' : 'explicit'; + } + + private toInternalConfig(config?: CacheConfig): { seconds: number; hexOptions?: ExpireOptions } { + return config + ? { + seconds: config.ex!, + hexOptions: config.hexOptions, + } + : { + seconds: 1, + }; + } + + override async get( + key: string, + tables: string[], + isTag: boolean = false, + isAutoInvalidate?: boolean, + ): Promise { + if (!isAutoInvalidate) { + const result = await this.redis.hget(UpstashCache.nonAutoInvalidateTablePrefix, key); + return result === null ? undefined : result as any[]; + } + + if (isTag) { + const result = await this.luaScripts.getByTagScript.exec([UpstashCache.tagsMapKey], [key]); + return result === null ? undefined : result as any[]; + } + + // Normal cache lookup for the composite key + const compositeKey = this.getCompositeKey(tables); + const result = await this.redis.hget(compositeKey, key) ?? undefined; // Retrieve result for normal query + return result === null ? undefined : result as any[]; + } + + override async put( + key: string, + response: any, + tables: string[], + isTag: boolean = false, + config?: CacheConfig, + ): Promise { + const isAutoInvalidate = tables.length !== 0; + + const pipeline = this.redis.pipeline(); + const ttlSeconds = config && config.ex ? config.ex : this.internalConfig.seconds; + const hexOptions = config && config.hexOptions ? 
config.hexOptions : this.internalConfig?.hexOptions; + + if (!isAutoInvalidate) { + if (isTag) { + pipeline.hset(UpstashCache.tagsMapKey, { [key]: UpstashCache.nonAutoInvalidateTablePrefix }); + pipeline.hexpire(UpstashCache.tagsMapKey, key, ttlSeconds, hexOptions); + } + + pipeline.hset(UpstashCache.nonAutoInvalidateTablePrefix, { [key]: response }); + pipeline.hexpire(UpstashCache.nonAutoInvalidateTablePrefix, key, ttlSeconds, hexOptions); + await pipeline.exec(); + return; + } + + const compositeKey = this.getCompositeKey(tables); + + pipeline.hset(compositeKey, { [key]: response }); // Store the result with the tag under the composite key + pipeline.hexpire(compositeKey, key, ttlSeconds, hexOptions); // Set expiration for the composite key + + if (isTag) { + pipeline.hset(UpstashCache.tagsMapKey, { [key]: compositeKey }); // Store the tag and its composite key in the map + pipeline.hexpire(UpstashCache.tagsMapKey, key, ttlSeconds, hexOptions); // Set expiration for the tag + } + + for (const table of tables) { + pipeline.sadd(this.addTablePrefix(table), compositeKey); + } + + await pipeline.exec(); + } + + override async onMutate(params: MutationOption) { + const tags = Array.isArray(params.tags) ? params.tags : params.tags ? [params.tags] : []; + const tables = Array.isArray(params.tables) ? params.tables : params.tables ? [params.tables] : []; + const tableNames: string[] = tables.map((table) => is(table, Table) ? 
table[OriginalName] : table as string); + + const compositeTableSets = tableNames.map((table) => this.addTablePrefix(table)); + await this.luaScripts.onMutateScript.exec([UpstashCache.tagsMapKey, ...compositeTableSets], tags); + } + + private addTablePrefix = (table: string) => `${UpstashCache.compositeTableSetPrefix}${table}`; + private getCompositeKey = (tables: string[]) => `${UpstashCache.compositeTablePrefix}${tables.sort().join(',')}`; +} + +export function upstashCache( + { url, token, config, global = false }: { url: string; token: string; config?: CacheConfig; global?: boolean }, +): UpstashCache { + const redis = new Redis({ + url, + token, + }); + + return new UpstashCache(redis, config, global); +} diff --git a/drizzle-orm/src/cache/upstash/index.ts b/drizzle-orm/src/cache/upstash/index.ts new file mode 100644 index 0000000000..ba0d58441d --- /dev/null +++ b/drizzle-orm/src/cache/upstash/index.ts @@ -0,0 +1 @@ +export * from './cache.ts'; diff --git a/drizzle-orm/src/d1/driver.ts b/drizzle-orm/src/d1/driver.ts index 7b4bbdfb6f..1a688a3f04 100644 --- a/drizzle-orm/src/d1/driver.ts +++ b/drizzle-orm/src/d1/driver.ts @@ -66,9 +66,13 @@ export function drizzle< }; } - const session = new SQLiteD1Session(client as D1Database, dialect, schema, { logger }); + const session = new SQLiteD1Session(client as D1Database, dialect, schema, { logger, cache: config.cache }); const db = new DrizzleD1Database('async', dialect, session, schema) as DrizzleD1Database; ( db).$client = client; + ( db).$cache = config.cache; + if (( db).$cache) { + ( db).$cache['invalidate'] = config.cache?.onMutate; + } return db as any; } diff --git a/drizzle-orm/src/d1/session.ts b/drizzle-orm/src/d1/session.ts index 61ef49315a..363ea4a9a7 100644 --- a/drizzle-orm/src/d1/session.ts +++ b/drizzle-orm/src/d1/session.ts @@ -1,6 +1,8 @@ /// import type { BatchItem } from '~/batch.ts'; +import { type Cache, NoopCache } from '~/cache/core/index.ts'; +import type { WithCacheConfig } from 
'~/cache/core/types.ts'; import { entityKind } from '~/entity.ts'; import type { Logger } from '~/logger.ts'; import { NoopLogger } from '~/logger.ts'; @@ -20,6 +22,7 @@ import { mapResultRow } from '~/utils.ts'; export interface SQLiteD1SessionOptions { logger?: Logger; + cache?: Cache; } type PreparedQueryConfig = Omit; @@ -31,6 +34,7 @@ export class SQLiteD1Session< static override readonly [entityKind]: string = 'SQLiteD1Session'; private logger: Logger; + private cache: Cache; constructor( private client: D1Database, @@ -40,6 +44,7 @@ export class SQLiteD1Session< ) { super(dialect); this.logger = options.logger ?? new NoopLogger(); + this.cache = options.cache ?? new NoopCache(); } prepareQuery( @@ -48,12 +53,20 @@ export class SQLiteD1Session< executeMethod: SQLiteExecuteMethod, isResponseInArrayMode: boolean, customResultMapper?: (rows: unknown[][]) => unknown, + queryMetadata?: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + }, + cacheConfig?: WithCacheConfig, ): D1PreparedQuery { const stmt = this.client.prepare(query.sql); return new D1PreparedQuery( stmt, query, this.logger, + this.cache, + queryMetadata, + cacheConfig, fields, executeMethod, isResponseInArrayMode, @@ -166,21 +179,29 @@ export class D1PreparedQuery unknown, ) { - super('async', executeMethod, query); + super('async', executeMethod, query, cache, queryMetadata, cacheConfig); this.customResultMapper = customResultMapper; this.fields = fields; this.stmt = stmt; } - run(placeholderValues?: Record): Promise { + async run(placeholderValues?: Record): Promise { const params = fillPlaceholders(this.query.params, placeholderValues ?? 
{}); this.logger.logQuery(this.query.sql, params); - return this.stmt.bind(...params).run(); + return await this.queryWithCache(this.query.sql, params, async () => { + return this.stmt.bind(...params).run(); + }); } async all(placeholderValues?: Record): Promise { @@ -188,7 +209,9 @@ export class D1PreparedQuery this.mapAllResult(results!)); + return await this.queryWithCache(query.sql, params, async () => { + return stmt.bind(...params).all().then(({ results }) => this.mapAllResult(results!)); + }); } const rows = await this.values(placeholderValues); @@ -217,7 +240,9 @@ export class D1PreparedQuery results![0]); + return await this.queryWithCache(query.sql, params, async () => { + return stmt.bind(...params).all().then(({ results }) => results![0]); + }); } const rows = await this.values(placeholderValues); @@ -249,10 +274,12 @@ export class D1PreparedQuery(placeholderValues?: Record): Promise { + async values(placeholderValues?: Record): Promise { const params = fillPlaceholders(this.query.params, placeholderValues ?? {}); this.logger.logQuery(this.query.sql, params); - return this.stmt.bind(...params).raw(); + return await this.queryWithCache(this.query.sql, params, async () => { + return this.stmt.bind(...params).raw(); + }); } /** @internal */ diff --git a/drizzle-orm/src/durable-sqlite/session.ts b/drizzle-orm/src/durable-sqlite/session.ts index dca5ce7cf2..28d6e3162f 100644 --- a/drizzle-orm/src/durable-sqlite/session.ts +++ b/drizzle-orm/src/durable-sqlite/session.ts @@ -111,7 +111,8 @@ export class SQLiteDOPreparedQuery unknown, ) { - super('sync', executeMethod, query); + // 3-6 params are for cache. 
As long as we don't support sync cache - it will be skipped here + super('sync', executeMethod, query, {} as any, undefined, undefined); } run(placeholderValues?: Record): void { diff --git a/drizzle-orm/src/errors/index.ts b/drizzle-orm/src/errors/index.ts new file mode 100644 index 0000000000..55f1cd949f --- /dev/null +++ b/drizzle-orm/src/errors/index.ts @@ -0,0 +1,13 @@ +export class DrizzleQueryError extends Error { + constructor( + public query: string, + public params: any[], + public override cause?: Error, + ) { + super(`Failed query: ${query}\nparams: ${params}`); + Error.captureStackTrace(this, DrizzleQueryError); + + // ES2022+: preserves original error on `.cause` + if (cause) (this as any).cause = cause; + } +} diff --git a/drizzle-orm/src/gel-core/db.ts b/drizzle-orm/src/gel-core/db.ts index 0c222a9ddd..dadcc3cc31 100644 --- a/drizzle-orm/src/gel-core/db.ts +++ b/drizzle-orm/src/gel-core/db.ts @@ -1,3 +1,4 @@ +import type { Cache } from '~/cache/core/cache.ts'; import { entityKind } from '~/entity.ts'; import type { GelDialect } from '~/gel-core/dialect.ts'; import { @@ -77,6 +78,8 @@ export class GelDatabase< ); } } + + this.$cache = { invalidate: async (_params: any) => {} }; } /** @@ -497,6 +500,8 @@ export class GelDatabase< }); } + $cache: { invalidate: Cache['onMutate'] }; + /** * Creates an update query. 
* diff --git a/drizzle-orm/src/gel-core/query-builders/delete.ts b/drizzle-orm/src/gel-core/query-builders/delete.ts index 3f5f77a0aa..cbb0ddacab 100644 --- a/drizzle-orm/src/gel-core/query-builders/delete.ts +++ b/drizzle-orm/src/gel-core/query-builders/delete.ts @@ -17,6 +17,7 @@ import { Table } from '~/table.ts'; import { tracer } from '~/tracing.ts'; import { orderSelectedFields } from '~/utils.ts'; import type { GelColumn } from '../columns/common.ts'; +import { extractUsedTable } from '../utils.ts'; import type { SelectedFieldsFlat, SelectedFieldsOrdered } from './select.types.ts'; export type GelDeleteWithout< @@ -224,7 +225,10 @@ export class GelDeleteBase< PreparedQueryConfig & { execute: TReturning extends undefined ? GelQueryResultKind : TReturning[]; } - >(this.dialect.sqlToQuery(this.getSQL()), this.config.returning, name, true); + >(this.dialect.sqlToQuery(this.getSQL()), this.config.returning, name, true, undefined, { + type: 'delete', + tables: extractUsedTable(this.config.table), + }); }); } diff --git a/drizzle-orm/src/gel-core/query-builders/insert.ts b/drizzle-orm/src/gel-core/query-builders/insert.ts index a13bcbc495..754dbd2b10 100644 --- a/drizzle-orm/src/gel-core/query-builders/insert.ts +++ b/drizzle-orm/src/gel-core/query-builders/insert.ts @@ -21,6 +21,7 @@ import { Columns, Table } from '~/table.ts'; import { tracer } from '~/tracing.ts'; import { haveSameKeys, type NeonAuthToken, orderSelectedFields } from '~/utils.ts'; import type { AnyGelColumn, GelColumn } from '../columns/common.ts'; +import { extractUsedTable } from '../utils.ts'; import { QueryBuilder } from './query-builder.ts'; import type { SelectedFieldsFlat, SelectedFieldsOrdered } from './select.types.ts'; import type { GelUpdateSetSource } from './update.ts'; @@ -386,7 +387,10 @@ export class GelInsertBase< PreparedQueryConfig & { execute: TReturning extends undefined ? 
GelQueryResultKind : TReturning[]; } - >(this.dialect.sqlToQuery(this.getSQL()), this.config.returning, name, true); + >(this.dialect.sqlToQuery(this.getSQL()), this.config.returning, name, true, undefined, { + type: 'insert', + tables: extractUsedTable(this.config.table), + }); }); } diff --git a/drizzle-orm/src/gel-core/query-builders/select.ts b/drizzle-orm/src/gel-core/query-builders/select.ts index 93da341ed3..204f95d7d5 100644 --- a/drizzle-orm/src/gel-core/query-builders/select.ts +++ b/drizzle-orm/src/gel-core/query-builders/select.ts @@ -1,3 +1,4 @@ +import type { CacheConfig, WithCacheConfig } from '~/cache/core/types.ts'; import { entityKind, is } from '~/entity.ts'; import type { GelColumn } from '~/gel-core/columns/index.ts'; import type { GelDialect } from '~/gel-core/dialect.ts'; @@ -34,6 +35,7 @@ import { } from '~/utils.ts'; import { orderSelectedFields } from '~/utils.ts'; import { ViewBaseConfig } from '~/view-common.ts'; +import { extractUsedTable } from '../utils.ts'; import type { AnyGelSelect, CreateGelSelectFromBuilderMode, @@ -166,6 +168,7 @@ export abstract class GelSelectQueryBuilderBase< readonly excludedMethods: TExcludedMethods; readonly result: TResult; readonly selectedFields: TSelectedFields; + readonly config: GelSelectConfig; }; protected config: GelSelectConfig; @@ -174,6 +177,8 @@ export abstract class GelSelectQueryBuilderBase< private isPartialSelect: boolean; protected session: GelSession | undefined; protected dialect: GelDialect; + protected cacheConfig?: WithCacheConfig = undefined; + protected usedTables: Set = new Set(); constructor( { table, fields, isPartialSelect, session, dialect, withList, distinct }: { @@ -201,9 +206,16 @@ export abstract class GelSelectQueryBuilderBase< this.dialect = dialect; this._ = { selectedFields: fields as TSelectedFields, + config: this.config, } as this['_']; this.tableName = getTableLikeName(table); this.joinsNotNullableMap = typeof this.tableName === 'string' ? 
{ [this.tableName]: true } : {}; + for (const item of extractUsedTable(table)) this.usedTables.add(item); + } + + /** @internal */ + getUsedTables() { + return [...this.usedTables]; } private createJoin< @@ -224,6 +236,9 @@ export abstract class GelSelectQueryBuilderBase< throw new Error(`Alias "${tableName}" is already used in this query`); } + // store all tables used in a query + for (const item of extractUsedTable(table)) this.usedTables.add(item); + if (!this.isPartialSelect) { // If this is the first join and this is not a partial select and we're not selecting from raw SQL, "move" the fields from the main table to the nested object if (Object.keys(this.joinsNotNullableMap).length === 1 && typeof baseTableName === 'string') { @@ -969,8 +984,12 @@ export abstract class GelSelectQueryBuilderBase< as( alias: TAlias, ): SubqueryWithSelection { + const usedTables: string[] = []; + usedTables.push(...extractUsedTable(this.config.table)); + if (this.config.joins) { for (const it of this.config.joins) usedTables.push(...extractUsedTable(it.table)); } + return new Proxy( - new Subquery(this.getSQL(), this.config.fields, alias), + new Subquery(this.getSQL(), this.config.fields, alias, false, [...new Set(usedTables)]), new SelectionProxyHandler({ alias, sqlAliasedBehavior: 'alias', sqlBehavior: 'error' }), ) as SubqueryWithSelection; } @@ -1039,7 +1058,7 @@ export class GelSelectBase< /** @internal */ _prepare(name?: string): GelSelectPrepare { - const { session, config, dialect, joinsNotNullableMap } = this; + const { session, config, dialect, joinsNotNullableMap, cacheConfig, usedTables } = this; if (!session) { throw new Error('Cannot execute a query on a query builder. 
Please use a database instance instead.'); } @@ -1047,13 +1066,25 @@ export class GelSelectBase< const fieldsList = orderSelectedFields(config.fields); const query = session.prepareQuery< PreparedQueryConfig & { execute: TResult } - >(dialect.sqlToQuery(this.getSQL()), fieldsList, name, true); + >(dialect.sqlToQuery(this.getSQL()), fieldsList, name, true, undefined, { + type: 'select', + tables: [...usedTables], + }, cacheConfig); query.joinsNotNullableMap = joinsNotNullableMap; return query; }); } + $withCache(config?: { config?: CacheConfig; tag?: string; autoInvalidate?: boolean } | false) { + this.cacheConfig = config === undefined + ? { config: {}, enable: true, autoInvalidate: true } + : config === false + ? { enable: false } + : { enable: true, autoInvalidate: true, ...config }; + return this; + } + /** * Create a prepared statement for this query. This allows * the database to remember this query for the given session diff --git a/drizzle-orm/src/gel-core/query-builders/update.ts b/drizzle-orm/src/gel-core/query-builders/update.ts index 9a8057997d..3cd8d98b24 100644 --- a/drizzle-orm/src/gel-core/query-builders/update.ts +++ b/drizzle-orm/src/gel-core/query-builders/update.ts @@ -35,6 +35,7 @@ import { } from '~/utils.ts'; import { ViewBaseConfig } from '~/view-common.ts'; import type { GelColumn } from '../columns/common.ts'; +import { extractUsedTable } from '../utils.ts'; import type { GelViewBase } from '../view-base.ts'; import type { GelSelectJoinConfig, SelectedFields, SelectedFieldsOrdered } from './select.types.ts'; @@ -538,7 +539,10 @@ export class GelUpdateBase< _prepare(name?: string): GelUpdatePrepare { const query = this.session.prepareQuery< PreparedQueryConfig & { execute: TReturning[] } - >(this.dialect.sqlToQuery(this.getSQL()), this.config.returning, name, true); + >(this.dialect.sqlToQuery(this.getSQL()), this.config.returning, name, true, undefined, { + type: 'update', + tables: extractUsedTable(this.config.table), + }); 
query.joinsNotNullableMap = this.joinsNotNullableMap; return query; } diff --git a/drizzle-orm/src/gel-core/session.ts b/drizzle-orm/src/gel-core/session.ts index 4033bb580b..4057e9478b 100644 --- a/drizzle-orm/src/gel-core/session.ts +++ b/drizzle-orm/src/gel-core/session.ts @@ -1,5 +1,8 @@ -import { entityKind } from '~/entity.ts'; +import { type Cache, hashQuery, NoopCache } from '~/cache/core/cache.ts'; +import type { WithCacheConfig } from '~/cache/core/types.ts'; +import { entityKind, is } from '~/entity.ts'; import { TransactionRollbackError } from '~/errors.ts'; +import { DrizzleQueryError } from '~/errors/index.ts'; import type { TablesRelationalConfig } from '~/relations.ts'; import type { PreparedQuery } from '~/session.ts'; import type { Query, SQL } from '~/sql/index.ts'; @@ -16,7 +19,112 @@ export interface PreparedQueryConfig { } export abstract class GelPreparedQuery implements PreparedQuery { - constructor(protected query: Query) {} + constructor( + protected query: Query, + private cache?: Cache, + // per query related metadata + private queryMetadata?: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + } | undefined, + // config that was passed through $withCache + private cacheConfig?: WithCacheConfig, + ) { + // it means that no $withCache options were passed and it should be just enabled + if (cache && cache.strategy() === 'all' && cacheConfig === undefined) { + this.cacheConfig = { enable: true, autoInvalidate: true }; + } + if (!this.cacheConfig?.enable) { + this.cacheConfig = undefined; + } + } + + /** @internal */ + protected async queryWithCache( + queryString: string, + params: any[], + query: () => Promise, + ): Promise { + if (this.cache === undefined || is(this.cache, NoopCache) || this.queryMetadata === undefined) { + try { + return await query(); + } catch (e) { + throw new DrizzleQueryError(queryString, params, e as Error); + } + } + + // don't do any mutations, if globally is false + if (this.cacheConfig && 
!this.cacheConfig.enable) { + try { + return await query(); + } catch (e) { + throw new DrizzleQueryError(queryString, params, e as Error); + } + } + + // For mutate queries, we should query the database, wait for a response, and then perform invalidation + if ( + ( + this.queryMetadata.type === 'insert' || this.queryMetadata.type === 'update' + || this.queryMetadata.type === 'delete' + ) && this.queryMetadata.tables.length > 0 + ) { + try { + const [res] = await Promise.all([ + query(), + this.cache.onMutate({ tables: this.queryMetadata.tables }), + ]); + return res; + } catch (e) { + throw new DrizzleQueryError(queryString, params, e as Error); + } + } + + // don't do any reads if globally disabled + if (!this.cacheConfig) { + try { + return await query(); + } catch (e) { + throw new DrizzleQueryError(queryString, params, e as Error); + } + } + + if (this.queryMetadata.type === 'select') { + const fromCache = await this.cache.get( + this.cacheConfig.tag ?? await hashQuery(queryString, params), + this.queryMetadata.tables, + this.cacheConfig.tag !== undefined, + this.cacheConfig.autoInvalidate, + ); + if (fromCache === undefined) { + let result; + try { + result = await query(); + } catch (e) { + throw new DrizzleQueryError(queryString, params, e as Error); + } + + // put actual key + await this.cache.put( + this.cacheConfig.tag ?? await hashQuery(queryString, params), + result, + // make sure we send tables that were used in a query only if user wants to invalidate it on each write + this.cacheConfig.autoInvalidate ? 
this.queryMetadata.tables : [], + this.cacheConfig.tag !== undefined, + this.cacheConfig.config, + ); + // put flag if we should invalidate or not + return result; + } + + return fromCache as unknown as T; + } + try { + return await query(); + } catch (e) { + throw new DrizzleQueryError(queryString, params, e as Error); + } + } protected authToken?: NeonAuthToken; @@ -57,6 +165,11 @@ export abstract class GelSession< name: string | undefined, isResponseInArrayMode: boolean, customResultMapper?: (rows: unknown[][], mapColumnValue?: (value: unknown) => unknown) => T['execute'], + queryMetadata?: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + }, + cacheConfig?: WithCacheConfig, ): GelPreparedQuery; execute(query: SQL): Promise { diff --git a/drizzle-orm/src/gel-core/utils.ts b/drizzle-orm/src/gel-core/utils.ts index 2f5b7be4b7..c638de8ba8 100644 --- a/drizzle-orm/src/gel-core/utils.ts +++ b/drizzle-orm/src/gel-core/utils.ts @@ -1,4 +1,6 @@ import { is } from '~/entity.ts'; +import { SQL } from '~/sql/sql.ts'; +import { Subquery } from '~/subquery.ts'; import { Table } from '~/table.ts'; import { ViewBaseConfig } from '~/view-common.ts'; import { type Check, CheckBuilder } from './checks.ts'; @@ -10,6 +12,7 @@ import { GelPolicy } from './policies.ts'; import { type PrimaryKey, PrimaryKeyBuilder } from './primary-keys.ts'; import { GelTable } from './table.ts'; import { type UniqueConstraint, UniqueConstraintBuilder } from './unique-constraint.ts'; +import type { GelViewBase } from './view-base.ts'; import { GelViewConfig } from './view-common.ts'; import { type GelMaterializedView, GelMaterializedViewConfig, type GelView } from './view.ts'; @@ -61,6 +64,19 @@ export function getTableConfig(table: TTable) { }; } +export function extractUsedTable(table: GelTable | Subquery | GelViewBase | SQL): string[] { + if (is(table, GelTable)) { + return [`${table[Table.Symbol.BaseName]}`]; + } + if (is(table, Subquery)) { + return table._.usedTables ?? 
[]; + } + if (is(table, SQL)) { + return table.usedTables ?? []; + } + return []; +} + export function getViewConfig< TName extends string = string, TExisting extends boolean = boolean, diff --git a/drizzle-orm/src/gel/driver.ts b/drizzle-orm/src/gel/driver.ts index 1d5d2baa50..74952f6e32 100644 --- a/drizzle-orm/src/gel/driver.ts +++ b/drizzle-orm/src/gel/driver.ts @@ -1,4 +1,5 @@ import { type Client, type ConnectOptions, createClient } from 'gel'; +import type { Cache } from '~/cache/core/index.ts'; import { entityKind } from '~/entity.ts'; import { GelDatabase } from '~/gel-core/db.ts'; import { GelDialect } from '~/gel-core/dialect.ts'; @@ -17,6 +18,7 @@ import { GelDbSession } from './session.ts'; export interface GelDriverOptions { logger?: Logger; + cache?: Cache; } export class GelDriver { @@ -31,7 +33,10 @@ export class GelDriver { createSession( schema: RelationalSchemaConfig | undefined, ): GelDbSession, TablesRelationalConfig> { - return new GelDbSession(this.client, this.dialect, schema, { logger: this.options.logger }); + return new GelDbSession(this.client, this.dialect, schema, { + logger: this.options.logger, + cache: this.options.cache, + }); } } @@ -68,10 +73,14 @@ function construct< }; } - const driver = new GelDriver(client, dialect, { logger }); + const driver = new GelDriver(client, dialect, { logger, cache: config.cache }); const session = driver.createSession(schema); const db = new GelJsDatabase(dialect, session, schema as any) as GelJsDatabase; ( db).$client = client; + ( db).$cache = config.cache; + if (( db).$cache) { + ( db).$cache['invalidate'] = config.cache?.onMutate; + } return db as any; } diff --git a/drizzle-orm/src/gel/session.ts b/drizzle-orm/src/gel/session.ts index db2b377737..d902541e63 100644 --- a/drizzle-orm/src/gel/session.ts +++ b/drizzle-orm/src/gel/session.ts @@ -1,5 +1,7 @@ import type { Client } from 'gel'; import type { Transaction } from 'gel/dist/transaction'; +import { type Cache, NoopCache } from 
'~/cache/core/index.ts'; +import type { WithCacheConfig } from '~/cache/core/types.ts'; import { entityKind } from '~/entity.ts'; import type { GelDialect } from '~/gel-core/dialect.ts'; import type { SelectedFieldsOrdered } from '~/gel-core/query-builders/select.types.ts'; @@ -20,12 +22,18 @@ export class GelDbPreparedQuery extends GelPrepar private queryString: string, private params: unknown[], private logger: Logger, + cache: Cache, + queryMetadata: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + } | undefined, + cacheConfig: WithCacheConfig | undefined, private fields: SelectedFieldsOrdered | undefined, private _isResponseInArrayMode: boolean, private customResultMapper?: (rows: unknown[][]) => T['execute'], private transaction: boolean = false, ) { - super({ sql: queryString, params }); + super({ sql: queryString, params }, cache, queryMetadata, cacheConfig); } async execute(placeholderValues: Record | undefined = {}): Promise { @@ -41,17 +49,21 @@ export class GelDbPreparedQuery extends GelPrepar 'drizzle.query.params': JSON.stringify(params), }); - return client.querySQL(query, params.length ? params : undefined); + return await this.queryWithCache(query, params, async () => { + return await client.querySQL(query, params.length ? params : undefined); + }); }); } - const result = (await tracer.startActiveSpan('drizzle.driver.execute', (span) => { + const result = (await tracer.startActiveSpan('drizzle.driver.execute', async (span) => { span?.setAttributes({ 'drizzle.query.text': query, 'drizzle.query.params': JSON.stringify(params), }); - return client.withSQLRowMode('array').querySQL(query, params.length ? params : undefined); + return await this.queryWithCache(query, params, async () => { + return await client.withSQLRowMode('array').querySQL(query, params.length ? 
params : undefined); + }); })) as unknown[][]; return tracer.startActiveSpan('drizzle.mapResponse', () => { @@ -62,18 +74,23 @@ export class GelDbPreparedQuery extends GelPrepar }); } - all(placeholderValues: Record | undefined = {}): Promise { - return tracer.startActiveSpan('drizzle.execute', () => { + async all(placeholderValues: Record | undefined = {}): Promise { + return await tracer.startActiveSpan('drizzle.execute', async () => { const params = fillPlaceholders(this.params, placeholderValues); this.logger.logQuery(this.queryString, params); - return tracer.startActiveSpan('drizzle.driver.execute', (span) => { + return await tracer.startActiveSpan('drizzle.driver.execute', async (span) => { span?.setAttributes({ 'drizzle.query.text': this.queryString, 'drizzle.query.params': JSON.stringify(params), }); - return this.client.withSQLRowMode('array').querySQL(this.queryString, params.length ? params : undefined).then(( - result, - ) => result); + return await this.queryWithCache(this.queryString, params, async () => { + return await this.client.withSQLRowMode('array').querySQL( + this.queryString, + params.length ? params : undefined, + ).then(( + result, + ) => result); + }); }); }); } @@ -86,6 +103,7 @@ export class GelDbPreparedQuery extends GelPrepar export interface GelSessionOptions { logger?: Logger; + cache?: Cache; } export class GelDbSession, TSchema extends TablesRelationalConfig> @@ -94,6 +112,7 @@ export class GelDbSession, TSchema e static override readonly [entityKind]: string = 'GelDbSession'; private logger: Logger; + private cache: Cache; constructor( private client: GelClient, @@ -103,6 +122,7 @@ export class GelDbSession, TSchema e ) { super(dialect); this.logger = options.logger ?? new NoopLogger(); + this.cache = options.cache ?? 
new NoopCache(); } prepareQuery( @@ -111,12 +131,20 @@ export class GelDbSession, TSchema e name: string | undefined, isResponseInArrayMode: boolean, customResultMapper?: (rows: unknown[][]) => T['execute'], + queryMetadata?: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + }, + cacheConfig?: WithCacheConfig, ): GelDbPreparedQuery { return new GelDbPreparedQuery( this.client, query.sql, query.params, this.logger, + this.cache, + queryMetadata, + cacheConfig, fields, isResponseInArrayMode, customResultMapper, diff --git a/drizzle-orm/src/libsql/driver-core.ts b/drizzle-orm/src/libsql/driver-core.ts index 1bee47d7b4..c427637e82 100644 --- a/drizzle-orm/src/libsql/driver-core.ts +++ b/drizzle-orm/src/libsql/driver-core.ts @@ -56,9 +56,12 @@ export function construct< }; } - const session = new LibSQLSession(client, dialect, schema, { logger }, undefined); + const session = new LibSQLSession(client, dialect, schema, { logger, cache: config.cache }, undefined); const db = new LibSQLDatabase('async', dialect, session, schema) as LibSQLDatabase; ( db).$client = client; - + ( db).$cache = config.cache; + if (( db).$cache) { + ( db).$cache['invalidate'] = config.cache?.onMutate; + } return db as any; } diff --git a/drizzle-orm/src/libsql/session.ts b/drizzle-orm/src/libsql/session.ts index 617ebe342e..b4c331068b 100644 --- a/drizzle-orm/src/libsql/session.ts +++ b/drizzle-orm/src/libsql/session.ts @@ -1,5 +1,7 @@ import type { Client, InArgs, InStatement, ResultSet, Transaction } from '@libsql/client'; import type { BatchItem as BatchItem } from '~/batch.ts'; +import { type Cache, NoopCache } from '~/cache/core/index.ts'; +import type { WithCacheConfig } from '~/cache/core/types.ts'; import { entityKind } from '~/entity.ts'; import type { Logger } from '~/logger.ts'; import { NoopLogger } from '~/logger.ts'; @@ -19,6 +21,7 @@ import { mapResultRow } from '~/utils.ts'; export interface LibSQLSessionOptions { logger?: Logger; + cache?: Cache; } type 
PreparedQueryConfig = Omit; @@ -30,6 +33,7 @@ export class LibSQLSession< static override readonly [entityKind]: string = 'LibSQLSession'; private logger: Logger; + private cache: Cache; constructor( private client: Client, @@ -40,6 +44,7 @@ export class LibSQLSession< ) { super(dialect); this.logger = options.logger ?? new NoopLogger(); + this.cache = options.cache ?? new NoopCache(); } prepareQuery>( @@ -48,11 +53,19 @@ export class LibSQLSession< executeMethod: SQLiteExecuteMethod, isResponseInArrayMode: boolean, customResultMapper?: (rows: unknown[][]) => unknown, + queryMetadata?: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + }, + cacheConfig?: WithCacheConfig, ): LibSQLPreparedQuery { return new LibSQLPreparedQuery( this.client, query, this.logger, + this.cache, + queryMetadata, + cacheConfig, fields, this.tx, executeMethod, @@ -158,6 +171,12 @@ export class LibSQLPreparedQuery unknown, ) => unknown, ) { - super('async', executeMethod, query); + super('async', executeMethod, query, cache, queryMetadata, cacheConfig); this.customResultMapper = customResultMapper; this.fields = fields; } - run(placeholderValues?: Record): Promise { + async run(placeholderValues?: Record): Promise { const params = fillPlaceholders(this.query.params, placeholderValues ?? {}); this.logger.logQuery(this.query.sql, params); - const stmt: InStatement = { sql: this.query.sql, args: params as InArgs }; - return this.tx ? this.tx.execute(stmt) : this.client.execute(stmt); + return await this.queryWithCache(this.query.sql, params, async () => { + const stmt: InStatement = { sql: this.query.sql, args: params as InArgs }; + return this.tx ? 
this.tx.execute(stmt) : this.client.execute(stmt); + }); } async all(placeholderValues?: Record): Promise { @@ -184,8 +205,10 @@ export class LibSQLPreparedQuery this.mapAllResult(rows)); + return await this.queryWithCache(query.sql, params, async () => { + const stmt: InStatement = { sql: query.sql, args: params as InArgs }; + return (tx ? tx.execute(stmt) : client.execute(stmt)).then(({ rows }) => this.mapAllResult(rows)); + }); } const rows = await this.values(placeholderValues) as unknown[][]; @@ -220,8 +243,10 @@ export class LibSQLPreparedQuery this.mapGetResult(rows)); + return await this.queryWithCache(query.sql, params, async () => { + const stmt: InStatement = { sql: query.sql, args: params as InArgs }; + return (tx ? tx.execute(stmt) : client.execute(stmt)).then(({ rows }) => this.mapGetResult(rows)); + }); } const rows = await this.values(placeholderValues) as unknown[][]; @@ -255,13 +280,15 @@ export class LibSQLPreparedQuery): Promise { + async values(placeholderValues?: Record): Promise { const params = fillPlaceholders(this.query.params, placeholderValues ?? {}); this.logger.logQuery(this.query.sql, params); - const stmt: InStatement = { sql: this.query.sql, args: params as InArgs }; - return (this.tx ? this.tx.execute(stmt) : this.client.execute(stmt)).then(({ rows }) => rows) as Promise< - T['values'] - >; + return await this.queryWithCache(this.query.sql, params, async () => { + const stmt: InStatement = { sql: this.query.sql, args: params as InArgs }; + return (this.tx ? 
this.tx.execute(stmt) : this.client.execute(stmt)).then(({ rows }) => rows) as Promise< + T['values'] + >; + }); } /** @internal */ diff --git a/drizzle-orm/src/mysql-core/db.ts b/drizzle-orm/src/mysql-core/db.ts index 6f79488383..78cd86a599 100644 --- a/drizzle-orm/src/mysql-core/db.ts +++ b/drizzle-orm/src/mysql-core/db.ts @@ -1,4 +1,5 @@ import type { ResultSetHeader } from 'mysql2/promise'; +import type { Cache } from '~/cache/core/cache.ts'; import { entityKind } from '~/entity.ts'; import type { TypedQueryBuilder } from '~/query-builders/query-builder.ts'; import type { ExtractTablesWithRelations, RelationalSchemaConfig, TablesRelationalConfig } from '~/relations.ts'; @@ -85,6 +86,7 @@ export class MySqlDatabase< ); } } + this.$cache = { invalidate: async (_params: any) => {} }; } /** @@ -151,6 +153,8 @@ export class MySqlDatabase< return new MySqlCountBuilder({ source, filters, session: this.session }); } + $cache: { invalidate: Cache['onMutate'] }; + /** * Incorporates a previously defined CTE (using `$with`) into the main query. 
* diff --git a/drizzle-orm/src/mysql-core/query-builders/delete.ts b/drizzle-orm/src/mysql-core/query-builders/delete.ts index 22a3e1be36..8d00a0f643 100644 --- a/drizzle-orm/src/mysql-core/query-builders/delete.ts +++ b/drizzle-orm/src/mysql-core/query-builders/delete.ts @@ -17,6 +17,7 @@ import type { Subquery } from '~/subquery.ts'; import { Table } from '~/table.ts'; import type { ValueOrArray } from '~/utils.ts'; import type { MySqlColumn } from '../columns/common.ts'; +import { extractUsedTable } from '../utils.ts'; import type { SelectedFieldsOrdered } from './select.types.ts'; export type MySqlDeleteWithout< @@ -185,6 +186,13 @@ export class MySqlDeleteBase< return this.session.prepareQuery( this.dialect.sqlToQuery(this.getSQL()), this.config.returning, + undefined, + undefined, + undefined, + { + type: 'delete', + tables: extractUsedTable(this.config.table), + }, ) as MySqlDeletePrepare; } diff --git a/drizzle-orm/src/mysql-core/query-builders/insert.ts b/drizzle-orm/src/mysql-core/query-builders/insert.ts index f943d03229..f2b8153ac7 100644 --- a/drizzle-orm/src/mysql-core/query-builders/insert.ts +++ b/drizzle-orm/src/mysql-core/query-builders/insert.ts @@ -1,3 +1,4 @@ +import type { WithCacheConfig } from '~/cache/core/types.ts'; import { entityKind, is } from '~/entity.ts'; import type { MySqlDialect } from '~/mysql-core/dialect.ts'; import type { @@ -19,6 +20,7 @@ import type { InferModelFromColumns } from '~/table.ts'; import { Columns, Table } from '~/table.ts'; import { haveSameKeys, mapUpdateSet } from '~/utils.ts'; import type { AnyMySqlColumn } from '../columns/common.ts'; +import { extractUsedTable } from '../utils.ts'; import { QueryBuilder } from './query-builder.ts'; import type { SelectedFieldsOrdered } from './select.types.ts'; import type { MySqlUpdateSetSource } from './update.ts'; @@ -228,6 +230,7 @@ export class MySqlInsertBase< declare protected $table: TTable; private config: MySqlInsertConfig; + protected cacheConfig?: 
WithCacheConfig; constructor( table: TTable, @@ -308,6 +311,11 @@ export class MySqlInsertBase< undefined, generatedIds, this.config.returning, + { + type: 'insert', + tables: extractUsedTable(this.config.table), + }, + this.cacheConfig, ) as MySqlInsertPrepare; } diff --git a/drizzle-orm/src/mysql-core/query-builders/select.ts b/drizzle-orm/src/mysql-core/query-builders/select.ts index 79c4e1cdbf..478c9c1822 100644 --- a/drizzle-orm/src/mysql-core/query-builders/select.ts +++ b/drizzle-orm/src/mysql-core/query-builders/select.ts @@ -1,3 +1,4 @@ +import type { CacheConfig, WithCacheConfig } from '~/cache/core/types.ts'; import { entityKind, is } from '~/entity.ts'; import type { MySqlColumn } from '~/mysql-core/columns/index.ts'; import type { MySqlDialect } from '~/mysql-core/dialect.ts'; @@ -24,7 +25,7 @@ import type { ValueOrArray } from '~/utils.ts'; import { applyMixins, getTableColumns, getTableLikeName, haveSameKeys, orderSelectedFields } from '~/utils.ts'; import { ViewBaseConfig } from '~/view-common.ts'; import type { IndexBuilder } from '../indexes.ts'; -import { convertIndexToString, toArray } from '../utils.ts'; +import { convertIndexToString, extractUsedTable, toArray } from '../utils.ts'; import { MySqlViewBase } from '../view-base.ts'; import type { AnyMySqlSelect, @@ -175,6 +176,7 @@ export abstract class MySqlSelectQueryBuilderBase< readonly excludedMethods: TExcludedMethods; readonly result: TResult; readonly selectedFields: TSelectedFields; + readonly config: MySqlSelectConfig; }; protected config: MySqlSelectConfig; @@ -184,6 +186,8 @@ export abstract class MySqlSelectQueryBuilderBase< /** @internal */ readonly session: MySqlSession | undefined; protected dialect: MySqlDialect; + protected cacheConfig?: WithCacheConfig = undefined; + protected usedTables: Set = new Set(); constructor( { table, fields, isPartialSelect, session, dialect, withList, distinct, useIndex, forceIndex, ignoreIndex }: { @@ -215,9 +219,16 @@ export abstract class 
MySqlSelectQueryBuilderBase< this.dialect = dialect; this._ = { selectedFields: fields as TSelectedFields, + config: this.config, } as this['_']; this.tableName = getTableLikeName(table); this.joinsNotNullableMap = typeof this.tableName === 'string' ? { [this.tableName]: true } : {}; + for (const item of extractUsedTable(table)) this.usedTables.add(item); + } + + /** @internal */ + getUsedTables() { + return [...this.usedTables]; } private createJoin< @@ -252,6 +263,9 @@ export abstract class MySqlSelectQueryBuilderBase< const baseTableName = this.tableName; const tableName = getTableLikeName(table); + // store all tables used in a query + for (const item of extractUsedTable(table)) this.usedTables.add(item); + if (typeof tableName === 'string' && this.config.joins?.some((join) => join.alias === tableName)) { throw new Error(`Alias "${tableName}" is already used in this query`); } @@ -1024,8 +1038,12 @@ export abstract class MySqlSelectQueryBuilderBase< as( alias: TAlias, ): SubqueryWithSelection { + const usedTables: string[] = []; + usedTables.push(...extractUsedTable(this.config.table)); + if (this.config.joins) { for (const it of this.config.joins) usedTables.push(...extractUsedTable(it.table)); } + return new Proxy( - new Subquery(this.getSQL(), this.config.fields, alias), + new Subquery(this.getSQL(), this.config.fields, alias, false, [...new Set(usedTables)]), new SelectionProxyHandler({ alias, sqlAliasedBehavior: 'alias', sqlBehavior: 'error' }), ) as SubqueryWithSelection; } @@ -1041,6 +1059,15 @@ export abstract class MySqlSelectQueryBuilderBase< $dynamic(): MySqlSelectDynamic { return this as any; } + + $withCache(config?: { config?: CacheConfig; tag?: string; autoInvalidate?: boolean } | false) { + this.cacheConfig = config === undefined + ? { config: {}, enable: true, autoInvalidate: true } + : config === false + ? 
{ enable: false } + : { enable: true, autoInvalidate: true, ...config }; + return this; + } } export interface MySqlSelectBase< @@ -1103,7 +1130,10 @@ export class MySqlSelectBase< const query = this.session.prepareQuery< MySqlPreparedQueryConfig & { execute: SelectResult[] }, TPreparedQueryHKT - >(this.dialect.sqlToQuery(this.getSQL()), fieldsList); + >(this.dialect.sqlToQuery(this.getSQL()), fieldsList, undefined, undefined, undefined, { + type: 'select', + tables: [...this.usedTables], + }, this.cacheConfig); query.joinsNotNullableMap = this.joinsNotNullableMap; return query as MySqlSelectPrepare; } diff --git a/drizzle-orm/src/mysql-core/query-builders/update.ts b/drizzle-orm/src/mysql-core/query-builders/update.ts index 7c6fd40abb..7da5bbb74a 100644 --- a/drizzle-orm/src/mysql-core/query-builders/update.ts +++ b/drizzle-orm/src/mysql-core/query-builders/update.ts @@ -1,3 +1,4 @@ +import type { WithCacheConfig } from '~/cache/core/types.ts'; import type { GetColumnData } from '~/column.ts'; import { entityKind } from '~/entity.ts'; import type { MySqlDialect } from '~/mysql-core/dialect.ts'; @@ -18,6 +19,7 @@ import type { Subquery } from '~/subquery.ts'; import { Table } from '~/table.ts'; import { mapUpdateSet, type UpdateSet, type ValueOrArray } from '~/utils.ts'; import type { MySqlColumn } from '../columns/common.ts'; +import { extractUsedTable } from '../utils.ts'; import type { SelectedFieldsOrdered } from './select.types.ts'; export interface MySqlUpdateConfig { @@ -129,6 +131,7 @@ export class MySqlUpdateBase< static override readonly [entityKind]: string = 'MySqlUpdate'; private config: MySqlUpdateConfig; + protected cacheConfig?: WithCacheConfig; constructor( table: TTable, @@ -223,7 +226,15 @@ export class MySqlUpdateBase< prepare(): MySqlUpdatePrepare { return this.session.prepareQuery( this.dialect.sqlToQuery(this.getSQL()), + undefined, + undefined, + undefined, this.config.returning, + { + type: 'insert', + tables: 
extractUsedTable(this.config.table), + }, + this.cacheConfig, ) as MySqlUpdatePrepare; } diff --git a/drizzle-orm/src/mysql-core/session.ts b/drizzle-orm/src/mysql-core/session.ts index 326b0ad61e..9a825c3bff 100644 --- a/drizzle-orm/src/mysql-core/session.ts +++ b/drizzle-orm/src/mysql-core/session.ts @@ -1,5 +1,8 @@ -import { entityKind } from '~/entity.ts'; +import { type Cache, hashQuery, NoopCache } from '~/cache/core/cache.ts'; +import type { WithCacheConfig } from '~/cache/core/types.ts'; +import { entityKind, is } from '~/entity.ts'; import { TransactionRollbackError } from '~/errors.ts'; +import { DrizzleQueryError } from '~/errors/index.ts'; import type { RelationalSchemaConfig, TablesRelationalConfig } from '~/relations.ts'; import { type Query, type SQL, sql } from '~/sql/sql.ts'; import type { Assume, Equal } from '~/utils.ts'; @@ -45,6 +48,112 @@ export type PreparedQueryKind< export abstract class MySqlPreparedQuery { static readonly [entityKind]: string = 'MySqlPreparedQuery'; + constructor( // cache instance + private cache: Cache, + // per query related metadata + private queryMetadata: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + } | undefined, + // config that was passed through $withCache + private cacheConfig?: WithCacheConfig, + ) { + // it means that no $withCache options were passed and it should be just enabled + if (cache && cache.strategy() === 'all' && cacheConfig === undefined) { + this.cacheConfig = { enable: true, autoInvalidate: true }; + } + if (!this.cacheConfig?.enable) { + this.cacheConfig = undefined; + } + } + + /** @internal */ + protected async queryWithCache( + queryString: string, + params: any[], + query: () => Promise, + ): Promise { + if (this.cache === undefined || is(this.cache, NoopCache) || this.queryMetadata === undefined) { + try { + return await query(); + } catch (e) { + throw new DrizzleQueryError(queryString, params, e as Error); + } + } + + // don't do any mutations, if globally 
is false + if (this.cacheConfig && !this.cacheConfig.enable) { + try { + return await query(); + } catch (e) { + throw new DrizzleQueryError(queryString, params, e as Error); + } + } + + // For mutate queries, we should query the database, wait for a response, and then perform invalidation + if ( + ( + this.queryMetadata.type === 'insert' || this.queryMetadata.type === 'update' + || this.queryMetadata.type === 'delete' + ) && this.queryMetadata.tables.length > 0 + ) { + try { + const [res] = await Promise.all([ + query(), + this.cache.onMutate({ tables: this.queryMetadata.tables }), + ]); + return res; + } catch (e) { + throw new DrizzleQueryError(queryString, params, e as Error); + } + } + + // don't do any reads if globally disabled + if (!this.cacheConfig) { + try { + return await query(); + } catch (e) { + throw new DrizzleQueryError(queryString, params, e as Error); + } + } + + if (this.queryMetadata.type === 'select') { + const fromCache = await this.cache.get( + this.cacheConfig.tag ?? await hashQuery(queryString, params), + this.queryMetadata.tables, + this.cacheConfig.tag !== undefined, + this.cacheConfig.autoInvalidate, + ); + if (fromCache === undefined) { + let result; + try { + result = await query(); + } catch (e) { + throw new DrizzleQueryError(queryString, params, e as Error); + } + + // put actual key + await this.cache.put( + this.cacheConfig.tag ?? await hashQuery(queryString, params), + result, + // make sure we send tables that were used in a query only if user wants to invalidate it on each write + this.cacheConfig.autoInvalidate ? 
this.queryMetadata.tables : [], + this.cacheConfig.tag !== undefined, + this.cacheConfig.config, + ); + // put flag if we should invalidate or not + return result; + } + + return fromCache as unknown as T; + } + try { + return await query(); + } catch (e) { + throw new DrizzleQueryError(queryString, params, e as Error); + } + } + /** @internal */ joinsNotNullableMap?: Record; @@ -75,6 +184,11 @@ export abstract class MySqlSession< customResultMapper?: (rows: unknown[][]) => T['execute'], generatedIds?: Record[], returningIds?: SelectedFieldsOrdered, + queryMetadata?: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + }, + cacheConfig?: WithCacheConfig, ): PreparedQueryKind; execute(query: SQL): Promise { diff --git a/drizzle-orm/src/mysql-core/utils.ts b/drizzle-orm/src/mysql-core/utils.ts index b49dd00433..db454c793a 100644 --- a/drizzle-orm/src/mysql-core/utils.ts +++ b/drizzle-orm/src/mysql-core/utils.ts @@ -1,4 +1,6 @@ import { is } from '~/entity.ts'; +import { SQL } from '~/index.ts'; +import { Subquery } from '~/subquery.ts'; import { Table } from '~/table.ts'; import { ViewBaseConfig } from '~/view-common.ts'; import type { Check } from './checks.ts'; @@ -12,9 +14,23 @@ import { PrimaryKeyBuilder } from './primary-keys.ts'; import type { IndexForHint } from './query-builders/select.ts'; import { MySqlTable } from './table.ts'; import { type UniqueConstraint, UniqueConstraintBuilder } from './unique-constraint.ts'; +import type { MySqlViewBase } from './view-base.ts'; import { MySqlViewConfig } from './view-common.ts'; import type { MySqlView } from './view.ts'; +export function extractUsedTable(table: MySqlTable | Subquery | MySqlViewBase | SQL): string[] { + if (is(table, MySqlTable)) { + return [`${table[Table.Symbol.BaseName]}`]; + } + if (is(table, Subquery)) { + return table._.usedTables ?? []; + } + if (is(table, SQL)) { + return table.usedTables ?? 
[]; + } + return []; +} + export function getTableConfig(table: MySqlTable) { const columns = Object.values(table[MySqlTable.Symbol.Columns]); const indexes: Index[] = []; diff --git a/drizzle-orm/src/mysql-proxy/session.ts b/drizzle-orm/src/mysql-proxy/session.ts index e72875e794..452402d8c2 100644 --- a/drizzle-orm/src/mysql-proxy/session.ts +++ b/drizzle-orm/src/mysql-proxy/session.ts @@ -1,4 +1,6 @@ import type { FieldPacket, ResultSetHeader } from 'mysql2/promise'; +import { type Cache, NoopCache } from '~/cache/core/index.ts'; +import type { WithCacheConfig } from '~/cache/core/types.ts'; import { Column } from '~/column.ts'; import { entityKind, is } from '~/entity.ts'; import type { Logger } from '~/logger.ts'; @@ -24,6 +26,7 @@ export type MySqlRawQueryResult = [ResultSetHeader, FieldPacket[]]; export interface MySqlRemoteSessionOptions { logger?: Logger; + cache?: Cache; } export class MySqlRemoteSession< @@ -33,6 +36,7 @@ export class MySqlRemoteSession< static override readonly [entityKind]: string = 'MySqlRemoteSession'; private logger: Logger; + private cache: Cache; constructor( private client: RemoteCallback, @@ -42,6 +46,7 @@ export class MySqlRemoteSession< ) { super(dialect); this.logger = options.logger ?? new NoopLogger(); + this.cache = options.cache ?? 
new NoopCache(); } prepareQuery( @@ -50,12 +55,20 @@ export class MySqlRemoteSession< customResultMapper?: (rows: unknown[][]) => T['execute'], generatedIds?: Record[], returningIds?: SelectedFieldsOrdered, + queryMetadata?: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + }, + cacheConfig?: WithCacheConfig, ): PreparedQueryKind { return new PreparedQuery( this.client, query.sql, query.params, this.logger, + this.cache, + queryMetadata, + cacheConfig, fields, customResultMapper, generatedIds, @@ -98,6 +111,12 @@ export class PreparedQuery extends PreparedQ private queryString: string, private params: unknown[], private logger: Logger, + cache: Cache, + queryMetadata: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + } | undefined, + cacheConfig: WithCacheConfig | undefined, private fields: SelectedFieldsOrdered | undefined, private customResultMapper?: (rows: unknown[][]) => T['execute'], // Keys that were used in $default and the value that was generated for them @@ -105,7 +124,7 @@ export class PreparedQuery extends PreparedQ // Keys that should be returned, it has the column with all properries + key from object private returningIds?: SelectedFieldsOrdered, ) { - super(); + super(cache, queryMetadata, cacheConfig); } async execute(placeholderValues: Record | undefined = {}): Promise { @@ -117,7 +136,9 @@ export class PreparedQuery extends PreparedQ logger.logQuery(queryString, params); if (!fields && !customResultMapper) { - const { rows: data } = await client(queryString, params, 'execute'); + const { rows: data } = await this.queryWithCache(queryString, params, async () => { + return await client(queryString, params, 'execute'); + }); const insertId = data[0].insertId as number; const affectedRows = data[0].affectedRows; @@ -148,7 +169,9 @@ export class PreparedQuery extends PreparedQ return data; } - const { rows } = await client(queryString, params, 'all'); + const { rows } = await 
this.queryWithCache(queryString, params, async () => { + return await client(queryString, params, 'all'); + }); if (customResultMapper) { return customResultMapper(rows); diff --git a/drizzle-orm/src/mysql2/driver.ts b/drizzle-orm/src/mysql2/driver.ts index 4ef5d25de5..3db9ea369a 100644 --- a/drizzle-orm/src/mysql2/driver.ts +++ b/drizzle-orm/src/mysql2/driver.ts @@ -1,5 +1,6 @@ import { type Connection as CallbackConnection, createPool, type Pool as CallbackPool, type PoolOptions } from 'mysql2'; import type { Connection, Pool } from 'mysql2/promise'; +import type { Cache } from '~/cache/core/index.ts'; import { entityKind } from '~/entity.ts'; import type { Logger } from '~/logger.ts'; import { DefaultLogger } from '~/logger.ts'; @@ -19,6 +20,7 @@ import { MySql2Session } from './session.ts'; export interface MySqlDriverOptions { logger?: Logger; + cache?: Cache; } export class MySql2Driver { @@ -35,7 +37,11 @@ export class MySql2Driver { schema: RelationalSchemaConfig | undefined, mode: Mode, ): MySql2Session, TablesRelationalConfig> { - return new MySql2Session(this.client, this.dialect, schema, { logger: this.options.logger, mode }); + return new MySql2Session(this.client, this.dialect, schema, { + logger: this.options.logger, + mode, + cache: this.options.cache, + }); } } @@ -92,10 +98,14 @@ function construct< const mode = config.mode ?? 
'default'; - const driver = new MySql2Driver(clientForInstance as MySql2Client, dialect, { logger }); + const driver = new MySql2Driver(clientForInstance as MySql2Client, dialect, { logger, cache: config.cache }); const session = driver.createSession(schema, mode); const db = new MySql2Database(dialect, session, schema as any, mode) as MySql2Database; ( db).$client = client; + ( db).$cache = config.cache; + if (( db).$cache) { + ( db).$cache['invalidate'] = config.cache?.onMutate; + } return db as any; } diff --git a/drizzle-orm/src/mysql2/session.ts b/drizzle-orm/src/mysql2/session.ts index 7ca21c4a63..7c0b35c9ce 100644 --- a/drizzle-orm/src/mysql2/session.ts +++ b/drizzle-orm/src/mysql2/session.ts @@ -10,6 +10,8 @@ import type { RowDataPacket, } from 'mysql2/promise'; import { once } from 'node:events'; +import { type Cache, NoopCache } from '~/cache/core/index.ts'; +import type { WithCacheConfig } from '~/cache/core/types.ts'; import { Column } from '~/column.ts'; import { entityKind, is } from '~/entity.ts'; import type { Logger } from '~/logger.ts'; @@ -51,6 +53,12 @@ export class MySql2PreparedQuery extends MyS queryString: string, private params: unknown[], private logger: Logger, + cache: Cache, + queryMetadata: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + } | undefined, + cacheConfig: WithCacheConfig | undefined, private fields: SelectedFieldsOrdered | undefined, private customResultMapper?: (rows: unknown[][]) => T['execute'], // Keys that were used in $default and the value that was generated for them @@ -58,7 +66,7 @@ export class MySql2PreparedQuery extends MyS // Keys that should be returned, it has the column with all properries + key from object private returningIds?: SelectedFieldsOrdered, ) { - super(); + super(cache, queryMetadata, cacheConfig); this.rawQuery = { sql: queryString, // rowsAsArray: true, @@ -89,7 +97,10 @@ export class MySql2PreparedQuery extends MyS const { fields, client, rawQuery, query, 
joinsNotNullableMap, customResultMapper, returningIds, generatedIds } = this; if (!fields && !customResultMapper) { - const res = await client.query(rawQuery, params); + const res = await this.queryWithCache(rawQuery.sql, params, async () => { + return await client.query(rawQuery, params); + }); + const insertId = res[0].insertId; const affectedRows = res[0].affectedRows; // for each row, I need to check keys from @@ -118,7 +129,10 @@ export class MySql2PreparedQuery extends MyS return res; } - const result = await client.query(query, params); + const result = await this.queryWithCache(query.sql, params, async () => { + return await client.query(query, params); + }); + const rows = result[0]; if (customResultMapper) { @@ -183,6 +197,7 @@ export class MySql2PreparedQuery extends MyS export interface MySql2SessionOptions { logger?: Logger; + cache?: Cache; mode: Mode; } @@ -194,6 +209,7 @@ export class MySql2Session< private logger: Logger; private mode: Mode; + private cache: Cache; constructor( private client: MySql2Client, @@ -203,6 +219,7 @@ export class MySql2Session< ) { super(dialect); this.logger = options.logger ?? new NoopLogger(); + this.cache = options.cache ?? 
new NoopCache(); this.mode = options.mode; } @@ -212,6 +229,11 @@ export class MySql2Session< customResultMapper?: (rows: unknown[][]) => T['execute'], generatedIds?: Record[], returningIds?: SelectedFieldsOrdered, + queryMetadata?: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + }, + cacheConfig?: WithCacheConfig, ): PreparedQueryKind { // Add returningId fields // Each driver gets them from response from database @@ -220,6 +242,9 @@ export class MySql2Session< query.sql, query.params, this.logger, + this.cache, + queryMetadata, + cacheConfig, fields, customResultMapper, generatedIds, diff --git a/drizzle-orm/src/neon-http/driver.ts b/drizzle-orm/src/neon-http/driver.ts index 2efe17c7d4..53ed6e2117 100644 --- a/drizzle-orm/src/neon-http/driver.ts +++ b/drizzle-orm/src/neon-http/driver.ts @@ -1,6 +1,7 @@ import type { HTTPQueryOptions, HTTPTransactionOptions, NeonQueryFunction } from '@neondatabase/serverless'; import { neon, types } from '@neondatabase/serverless'; import type { BatchItem, BatchResponse } from '~/batch.ts'; +import type { Cache } from '~/cache/core/cache.ts'; import { entityKind } from '~/entity.ts'; import type { Logger } from '~/logger.ts'; import { DefaultLogger } from '~/logger.ts'; @@ -13,6 +14,7 @@ import { type NeonHttpClient, type NeonHttpQueryResultHKT, NeonHttpSession } fro export interface NeonDriverOptions { logger?: Logger; + cache?: Cache; } export class NeonHttpDriver { @@ -29,7 +31,10 @@ export class NeonHttpDriver { createSession( schema: RelationalSchemaConfig | undefined, ): NeonHttpSession, TablesRelationalConfig> { - return new NeonHttpSession(this.client, this.dialect, schema, { logger: this.options.logger }); + return new NeonHttpSession(this.client, this.dialect, schema, { + logger: this.options.logger, + cache: this.options.cache, + }); } initMappers() { @@ -146,7 +151,7 @@ function construct< }; } - const driver = new NeonHttpDriver(client, dialect, { logger }); + const driver = new 
NeonHttpDriver(client, dialect, { logger, cache: config.cache }); const session = driver.createSession(schema); const db = new NeonHttpDatabase( @@ -155,6 +160,10 @@ function construct< schema as RelationalSchemaConfig> | undefined, ); ( db).$client = client; + ( db).$cache = config.cache; + if (( db).$cache) { + ( db).$cache['invalidate'] = config.cache?.onMutate; + } return db as any; } diff --git a/drizzle-orm/src/neon-http/session.ts b/drizzle-orm/src/neon-http/session.ts index 0adb85cdd5..567c464710 100644 --- a/drizzle-orm/src/neon-http/session.ts +++ b/drizzle-orm/src/neon-http/session.ts @@ -1,5 +1,7 @@ import type { FullQueryResults, NeonQueryFunction, NeonQueryPromise } from '@neondatabase/serverless'; import type { BatchItem } from '~/batch.ts'; +import { type Cache, NoopCache } from '~/cache/core/index.ts'; +import type { WithCacheConfig } from '~/cache/core/types.ts'; import { entityKind } from '~/entity.ts'; import type { Logger } from '~/logger.ts'; import { NoopLogger } from '~/logger.ts'; @@ -32,11 +34,17 @@ export class NeonHttpPreparedQuery extends PgPrep private client: NeonHttpClient, query: Query, private logger: Logger, + cache: Cache, + queryMetadata: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + } | undefined, + cacheConfig: WithCacheConfig | undefined, private fields: SelectedFieldsOrdered | undefined, private _isResponseInArrayMode: boolean, private customResultMapper?: (rows: unknown[][]) => T['execute'], ) { - super(query); + super(query, cache, queryMetadata, cacheConfig); // `client.query` is for @neondatabase/serverless v1.0.0 and up, where the // root query function `client` is only usable as a template function; // `client` is a fallback for earlier versions @@ -58,28 +66,32 @@ export class NeonHttpPreparedQuery extends PgPrep const { fields, clientQuery, query, customResultMapper } = this; if (!fields && !customResultMapper) { - return clientQuery( + return this.queryWithCache(query.sql, params, async 
() => { + return clientQuery( + query.sql, + params, + token === undefined + ? rawQueryConfig + : { + ...rawQueryConfig, + authToken: token, + }, + ); + }); + } + + const result = await this.queryWithCache(query.sql, params, async () => { + return await clientQuery( query.sql, params, token === undefined - ? rawQueryConfig + ? queryConfig : { - ...rawQueryConfig, + ...queryConfig, authToken: token, }, ); - } - - const result = await clientQuery( - query.sql, - params, - token === undefined - ? queryConfig - : { - ...queryConfig, - authToken: token, - }, - ); + }); return this.mapResult(result); } @@ -131,6 +143,7 @@ export class NeonHttpPreparedQuery extends PgPrep export interface NeonHttpSessionOptions { logger?: Logger; + cache?: Cache; } export class NeonHttpSession< @@ -141,6 +154,7 @@ export class NeonHttpSession< private clientQuery: (sql: string, params: any[], opts: Record) => NeonQueryPromise; private logger: Logger; + private cache: Cache; constructor( private client: NeonHttpClient, @@ -154,6 +168,7 @@ export class NeonHttpSession< // `client` is a fallback for earlier versions this.clientQuery = (client as any).query ?? client as any; this.logger = options.logger ?? new NoopLogger(); + this.cache = options.cache ?? 
new NoopCache(); } prepareQuery( @@ -162,11 +177,19 @@ export class NeonHttpSession< name: string | undefined, isResponseInArrayMode: boolean, customResultMapper?: (rows: unknown[][]) => T['execute'], + queryMetadata?: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + }, + cacheConfig?: WithCacheConfig, ): PgPreparedQuery { return new NeonHttpPreparedQuery( this.client, query, this.logger, + this.cache, + queryMetadata, + cacheConfig, fields, isResponseInArrayMode, customResultMapper, diff --git a/drizzle-orm/src/neon-serverless/driver.ts b/drizzle-orm/src/neon-serverless/driver.ts index 35d24d5681..765732f607 100644 --- a/drizzle-orm/src/neon-serverless/driver.ts +++ b/drizzle-orm/src/neon-serverless/driver.ts @@ -1,4 +1,5 @@ import { neonConfig, Pool, type PoolConfig } from '@neondatabase/serverless'; +import type { Cache } from '~/cache/core/cache.ts'; import { entityKind } from '~/entity.ts'; import type { Logger } from '~/logger.ts'; import { DefaultLogger } from '~/logger.ts'; @@ -16,6 +17,7 @@ import { NeonSession } from './session.ts'; export interface NeonDriverOptions { logger?: Logger; + cache?: Cache; } export class NeonDriver { @@ -31,7 +33,10 @@ export class NeonDriver { createSession( schema: RelationalSchemaConfig | undefined, ): NeonSession, TablesRelationalConfig> { - return new NeonSession(this.client, this.dialect, schema, { logger: this.options.logger }); + return new NeonSession(this.client, this.dialect, schema, { + logger: this.options.logger, + cache: this.options.cache, + }); } } @@ -71,10 +76,14 @@ function construct< }; } - const driver = new NeonDriver(client, dialect, { logger }); + const driver = new NeonDriver(client, dialect, { logger, cache: config.cache }); const session = driver.createSession(schema); const db = new NeonDatabase(dialect, session, schema as any) as NeonDatabase; ( db).$client = client; + ( db).$cache = config.cache; + if (( db).$cache) { + ( db).$cache['invalidate'] = 
config.cache?.onMutate; + } return db as any; } diff --git a/drizzle-orm/src/neon-serverless/session.ts b/drizzle-orm/src/neon-serverless/session.ts index 4b12c7d2d6..bc60642a18 100644 --- a/drizzle-orm/src/neon-serverless/session.ts +++ b/drizzle-orm/src/neon-serverless/session.ts @@ -8,6 +8,8 @@ import { type QueryResultRow, types, } from '@neondatabase/serverless'; +import { type Cache, NoopCache } from '~/cache/core/cache.ts'; +import type { WithCacheConfig } from '~/cache/core/types.ts'; import { entityKind } from '~/entity.ts'; import type { Logger } from '~/logger.ts'; import { NoopLogger } from '~/logger.ts'; @@ -33,12 +35,18 @@ export class NeonPreparedQuery extends PgPrepared queryString: string, private params: unknown[], private logger: Logger, + cache: Cache, + queryMetadata: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + } | undefined, + cacheConfig: WithCacheConfig | undefined, private fields: SelectedFieldsOrdered | undefined, name: string | undefined, private _isResponseInArrayMode: boolean, private customResultMapper?: (rows: unknown[][]) => T['execute'], ) { - super({ sql: queryString, params }); + super({ sql: queryString, params }, cache, queryMetadata, cacheConfig); this.rawQueryConfig = { name, text: queryString, @@ -136,10 +144,14 @@ export class NeonPreparedQuery extends PgPrepared const { fields, client, rawQueryConfig: rawQuery, queryConfig: query, joinsNotNullableMap, customResultMapper } = this; if (!fields && !customResultMapper) { - return client.query(rawQuery, params); + return await this.queryWithCache(rawQuery.text, params, async () => { + return await client.query(rawQuery, params); + }); } - const result = await client.query(query, params); + const result = await this.queryWithCache(query.text, params, async () => { + return await client.query(query, params); + }); return customResultMapper ? 
customResultMapper(result.rows) @@ -149,13 +161,17 @@ export class NeonPreparedQuery extends PgPrepared all(placeholderValues: Record | undefined = {}): Promise { const params = fillPlaceholders(this.params, placeholderValues); this.logger.logQuery(this.rawQueryConfig.text, params); - return this.client.query(this.rawQueryConfig, params).then((result) => result.rows); + return this.queryWithCache(this.rawQueryConfig.text, params, async () => { + return await this.client.query(this.rawQueryConfig, params); + }).then((result) => result.rows); } values(placeholderValues: Record | undefined = {}): Promise { const params = fillPlaceholders(this.params, placeholderValues); this.logger.logQuery(this.rawQueryConfig.text, params); - return this.client.query(this.queryConfig, params).then((result) => result.rows); + return this.queryWithCache(this.queryConfig.text, params, async () => { + return await this.client.query(this.queryConfig, params); + }).then((result) => result.rows); } /** @internal */ @@ -166,6 +182,7 @@ export class NeonPreparedQuery extends PgPrepared export interface NeonSessionOptions { logger?: Logger; + cache?: Cache; } export class NeonSession< @@ -175,6 +192,7 @@ export class NeonSession< static override readonly [entityKind]: string = 'NeonSession'; private logger: Logger; + private cache: Cache; constructor( private client: NeonClient, @@ -184,6 +202,7 @@ export class NeonSession< ) { super(dialect); this.logger = options.logger ?? new NoopLogger(); + this.cache = options.cache ?? 
new NoopCache(); } prepareQuery( @@ -192,12 +211,20 @@ export class NeonSession< name: string | undefined, isResponseInArrayMode: boolean, customResultMapper?: (rows: unknown[][]) => T['execute'], + queryMetadata?: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + }, + cacheConfig?: WithCacheConfig, ): PgPreparedQuery { return new NeonPreparedQuery( this.client, query.sql, query.params, this.logger, + this.cache, + queryMetadata, + cacheConfig, fields, name, isResponseInArrayMode, diff --git a/drizzle-orm/src/node-postgres/driver.ts b/drizzle-orm/src/node-postgres/driver.ts index e48d0e177e..9ea0ae2df8 100644 --- a/drizzle-orm/src/node-postgres/driver.ts +++ b/drizzle-orm/src/node-postgres/driver.ts @@ -1,4 +1,5 @@ import pg, { type Pool, type PoolConfig } from 'pg'; +import type { Cache } from '~/cache/core/cache.ts'; import { entityKind } from '~/entity.ts'; import type { Logger } from '~/logger.ts'; import { DefaultLogger } from '~/logger.ts'; @@ -16,6 +17,7 @@ import { NodePgSession } from './session.ts'; export interface PgDriverOptions { logger?: Logger; + cache?: Cache; } export class NodePgDriver { @@ -31,7 +33,10 @@ export class NodePgDriver { createSession( schema: RelationalSchemaConfig | undefined, ): NodePgSession, TablesRelationalConfig> { - return new NodePgSession(this.client, this.dialect, schema, { logger: this.options.logger }); + return new NodePgSession(this.client, this.dialect, schema, { + logger: this.options.logger, + cache: this.options.cache, + }); } } @@ -71,10 +76,14 @@ function construct< }; } - const driver = new NodePgDriver(client, dialect, { logger }); + const driver = new NodePgDriver(client, dialect, { logger, cache: config.cache }); const session = driver.createSession(schema); const db = new NodePgDatabase(dialect, session, schema as any) as NodePgDatabase; ( db).$client = client; + ( db).$cache = config.cache; + if (( db).$cache) { + ( db).$cache['invalidate'] = config.cache?.onMutate; + } return db as 
any; } diff --git a/drizzle-orm/src/node-postgres/session.ts b/drizzle-orm/src/node-postgres/session.ts index 45e4158d11..e5fb6ba7b7 100644 --- a/drizzle-orm/src/node-postgres/session.ts +++ b/drizzle-orm/src/node-postgres/session.ts @@ -1,5 +1,7 @@ import type { Client, PoolClient, QueryArrayConfig, QueryConfig, QueryResult, QueryResultRow } from 'pg'; import pg from 'pg'; +import { type Cache, NoopCache } from '~/cache/core/index.ts'; +import type { WithCacheConfig } from '~/cache/core/types.ts'; import { entityKind } from '~/entity.ts'; import { type Logger, NoopLogger } from '~/logger.ts'; import type { PgDialect } from '~/pg-core/dialect.ts'; @@ -24,15 +26,21 @@ export class NodePgPreparedQuery extends PgPrepar constructor( private client: NodePgClient, - queryString: string, + private queryString: string, private params: unknown[], private logger: Logger, + cache: Cache, + queryMetadata: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + } | undefined, + cacheConfig: WithCacheConfig | undefined, private fields: SelectedFieldsOrdered | undefined, name: string | undefined, private _isResponseInArrayMode: boolean, private customResultMapper?: (rows: unknown[][]) => T['execute'], ) { - super({ sql: queryString, params }); + super({ sql: queryString, params }, cache, queryMetadata, cacheConfig); this.rawQueryConfig = { name, text: queryString, @@ -137,7 +145,9 @@ export class NodePgPreparedQuery extends PgPrepar 'drizzle.query.text': rawQuery.text, 'drizzle.query.params': JSON.stringify(params), }); - return client.query(rawQuery, params); + return this.queryWithCache(rawQuery.text, params, async () => { + return await client.query(rawQuery, params); + }); }); } @@ -147,7 +157,9 @@ export class NodePgPreparedQuery extends PgPrepar 'drizzle.query.text': query.text, 'drizzle.query.params': JSON.stringify(params), }); - return client.query(query, params); + return this.queryWithCache(query.text, params, async () => { + return await 
client.query(query, params); + }); }); return tracer.startActiveSpan('drizzle.mapResponse', () => { @@ -168,7 +180,9 @@ export class NodePgPreparedQuery extends PgPrepar 'drizzle.query.text': this.rawQueryConfig.text, 'drizzle.query.params': JSON.stringify(params), }); - return this.client.query(this.rawQueryConfig, params).then((result) => result.rows); + return this.queryWithCache(this.rawQueryConfig.text, params, async () => { + return this.client.query(this.rawQueryConfig, params); + }).then((result) => result.rows); }); }); } @@ -181,6 +195,7 @@ export class NodePgPreparedQuery extends PgPrepar export interface NodePgSessionOptions { logger?: Logger; + cache?: Cache; } export class NodePgSession< @@ -190,6 +205,7 @@ export class NodePgSession< static override readonly [entityKind]: string = 'NodePgSession'; private logger: Logger; + private cache: Cache; constructor( private client: NodePgClient, @@ -199,6 +215,7 @@ export class NodePgSession< ) { super(dialect); this.logger = options.logger ?? new NoopLogger(); + this.cache = options.cache ?? 
new NoopCache(); } prepareQuery( @@ -207,12 +224,20 @@ export class NodePgSession< name: string | undefined, isResponseInArrayMode: boolean, customResultMapper?: (rows: unknown[][]) => T['execute'], + queryMetadata?: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + }, + cacheConfig?: WithCacheConfig, ): PgPreparedQuery { return new NodePgPreparedQuery( this.client, query.sql, query.params, this.logger, + this.cache, + queryMetadata, + cacheConfig, fields, name, isResponseInArrayMode, diff --git a/drizzle-orm/src/op-sqlite/driver.ts b/drizzle-orm/src/op-sqlite/driver.ts index 06b9d57f49..07b21b61ab 100644 --- a/drizzle-orm/src/op-sqlite/driver.ts +++ b/drizzle-orm/src/op-sqlite/driver.ts @@ -45,9 +45,13 @@ export function drizzle = Record; ( db).$client = client; + ( db).$cache = config.cache; + if (( db).$cache) { + ( db).$cache['invalidate'] = config.cache?.onMutate; + } return db as any; } diff --git a/drizzle-orm/src/op-sqlite/session.ts b/drizzle-orm/src/op-sqlite/session.ts index c1ac630715..a8a515c42d 100644 --- a/drizzle-orm/src/op-sqlite/session.ts +++ b/drizzle-orm/src/op-sqlite/session.ts @@ -1,4 +1,6 @@ import type { OPSQLiteConnection, QueryResult } from '@op-engineering/op-sqlite'; +import { type Cache, NoopCache } from '~/cache/core/index.ts'; +import type { WithCacheConfig } from '~/cache/core/types.ts'; import { entityKind } from '~/entity.ts'; import type { Logger } from '~/logger.ts'; import { NoopLogger } from '~/logger.ts'; @@ -18,6 +20,7 @@ import { mapResultRow } from '~/utils.ts'; export interface OPSQLiteSessionOptions { logger?: Logger; + cache?: Cache; } type PreparedQueryConfig = Omit; @@ -29,6 +32,7 @@ export class OPSQLiteSession< static override readonly [entityKind]: string = 'OPSQLiteSession'; private logger: Logger; + private cache: Cache; constructor( private client: OPSQLiteConnection, @@ -38,6 +42,7 @@ export class OPSQLiteSession< ) { super(dialect); this.logger = options.logger ?? 
new NoopLogger(); + this.cache = options.cache ?? new NoopCache(); } prepareQuery>( @@ -46,11 +51,19 @@ export class OPSQLiteSession< executeMethod: SQLiteExecuteMethod, isResponseInArrayMode: boolean, customResultMapper?: (rows: unknown[][]) => unknown, + queryMetadata?: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + }, + cacheConfig?: WithCacheConfig, ): OPSQLitePreparedQuery { return new OPSQLitePreparedQuery( this.client, query, this.logger, + this.cache, + queryMetadata, + cacheConfig, fields, executeMethod, isResponseInArrayMode, @@ -105,19 +118,27 @@ export class OPSQLitePreparedQuery unknown, ) { - super('sync', executeMethod, query); + super('sync', executeMethod, query, cache, queryMetadata, cacheConfig); } - run(placeholderValues?: Record): Promise { + async run(placeholderValues?: Record): Promise { const params = fillPlaceholders(this.query.params, placeholderValues ?? {}); this.logger.logQuery(this.query.sql, params); - return this.client.executeAsync(this.query.sql, params); + return await this.queryWithCache(this.query.sql, params, async () => { + return this.client.executeAsync(this.query.sql, params); + }); } async all(placeholderValues?: Record): Promise { @@ -126,7 +147,9 @@ export class OPSQLitePreparedQuery { + return client.execute(query.sql, params).rows?._array || []; + }); } const rows = await this.values(placeholderValues) as unknown[][]; @@ -141,7 +164,9 @@ export class OPSQLitePreparedQuery { + return client.execute(query.sql, params).rows?._array || []; + }); return rows[0]; } @@ -159,10 +184,12 @@ export class OPSQLitePreparedQuery): Promise { + async values(placeholderValues?: Record): Promise { const params = fillPlaceholders(this.query.params, placeholderValues ?? 
{}); this.logger.logQuery(this.query.sql, params); - return this.client.executeRawAsync(this.query.sql, params); + return await this.queryWithCache(this.query.sql, params, async () => { + return await this.client.executeRawAsync(this.query.sql, params); + }); } /** @internal */ diff --git a/drizzle-orm/src/pg-core/db.ts b/drizzle-orm/src/pg-core/db.ts index 17d8828951..cc9508d2e3 100644 --- a/drizzle-orm/src/pg-core/db.ts +++ b/drizzle-orm/src/pg-core/db.ts @@ -1,3 +1,4 @@ +import type { Cache } from '~/cache/core/cache.ts'; import { entityKind } from '~/entity.ts'; import type { PgDialect } from '~/pg-core/dialect.ts'; import { @@ -86,6 +87,7 @@ export class PgDatabase< ); } } + this.$cache = { invalidate: async (_params: any) => {} }; } /** @@ -152,6 +154,8 @@ export class PgDatabase< return new PgCountBuilder({ source, filters, session: this.session }); } + $cache: { invalidate: Cache['onMutate'] }; + /** * Incorporates a previously defined CTE (using `$with`) into the main query. * diff --git a/drizzle-orm/src/pg-core/query-builders/delete.ts b/drizzle-orm/src/pg-core/query-builders/delete.ts index e37c06038b..d6cc150d56 100644 --- a/drizzle-orm/src/pg-core/query-builders/delete.ts +++ b/drizzle-orm/src/pg-core/query-builders/delete.ts @@ -1,3 +1,4 @@ +import type { WithCacheConfig } from '~/cache/core/types.ts'; import { entityKind } from '~/entity.ts'; import type { PgDialect } from '~/pg-core/dialect.ts'; import type { @@ -8,7 +9,7 @@ import type { PreparedQueryConfig, } from '~/pg-core/session.ts'; import type { PgTable } from '~/pg-core/table.ts'; -import { TypedQueryBuilder } from '~/query-builders/query-builder.ts'; +import type { TypedQueryBuilder } from '~/query-builders/query-builder.ts'; import type { SelectResultFields } from '~/query-builders/select.types.ts'; import { QueryPromise } from '~/query-promise.ts'; import type { RunnableQuery } from '~/runnable-query.ts'; @@ -19,6 +20,7 @@ import { getTableName, Table } from '~/table.ts'; import { 
tracer } from '~/tracing.ts'; import { type NeonAuthToken, orderSelectedFields } from '~/utils.ts'; import type { PgColumn } from '../columns/common.ts'; +import { extractUsedTable } from '../utils.ts'; import type { SelectedFieldsFlat, SelectedFieldsOrdered } from './select.types.ts'; export type PgDeleteWithout< @@ -150,6 +152,7 @@ export class PgDeleteBase< static override readonly [entityKind]: string = 'PgDelete'; private config: PgDeleteConfig; + protected cacheConfig?: WithCacheConfig; constructor( table: TTable, @@ -244,7 +247,10 @@ export class PgDeleteBase< PreparedQueryConfig & { execute: TReturning extends undefined ? PgQueryResultKind : TReturning[]; } - >(this.dialect.sqlToQuery(this.getSQL()), this.config.returning, name, true); + >(this.dialect.sqlToQuery(this.getSQL()), this.config.returning, name, true, undefined, { + type: 'delete', + tables: extractUsedTable(this.config.table), + }, this.cacheConfig); }); } diff --git a/drizzle-orm/src/pg-core/query-builders/insert.ts b/drizzle-orm/src/pg-core/query-builders/insert.ts index 5a61e9ed48..28978cc841 100644 --- a/drizzle-orm/src/pg-core/query-builders/insert.ts +++ b/drizzle-orm/src/pg-core/query-builders/insert.ts @@ -1,3 +1,4 @@ +import type { WithCacheConfig } from '~/cache/core/types.ts'; import { entityKind, is } from '~/entity.ts'; import type { PgDialect } from '~/pg-core/dialect.ts'; import type { IndexColumn } from '~/pg-core/indexes.ts'; @@ -22,6 +23,7 @@ import { Columns, getTableName, Table } from '~/table.ts'; import { tracer } from '~/tracing.ts'; import { haveSameKeys, mapUpdateSet, type NeonAuthToken, orderSelectedFields } from '~/utils.ts'; import type { AnyPgColumn, PgColumn } from '../columns/common.ts'; +import { extractUsedTable } from '../utils.ts'; import { QueryBuilder } from './query-builder.ts'; import type { SelectedFieldsFlat, SelectedFieldsOrdered } from './select.types.ts'; import type { PgUpdateSetSource } from './update.ts'; @@ -249,6 +251,7 @@ export class 
PgInsertBase< static override readonly [entityKind]: string = 'PgInsert'; private config: PgInsertConfig; + protected cacheConfig?: WithCacheConfig; constructor( table: TTable, @@ -402,7 +405,10 @@ export class PgInsertBase< PreparedQueryConfig & { execute: TReturning extends undefined ? PgQueryResultKind : TReturning[]; } - >(this.dialect.sqlToQuery(this.getSQL()), this.config.returning, name, true); + >(this.dialect.sqlToQuery(this.getSQL()), this.config.returning, name, true, undefined, { + type: 'insert', + tables: extractUsedTable(this.config.table), + }, this.cacheConfig); }); } diff --git a/drizzle-orm/src/pg-core/query-builders/select.ts b/drizzle-orm/src/pg-core/query-builders/select.ts index cd8e69bab1..dac94cb5c4 100644 --- a/drizzle-orm/src/pg-core/query-builders/select.ts +++ b/drizzle-orm/src/pg-core/query-builders/select.ts @@ -1,3 +1,4 @@ +import type { CacheConfig, WithCacheConfig } from '~/cache/core/types.ts'; import { entityKind, is } from '~/entity.ts'; import type { PgColumn } from '~/pg-core/columns/index.ts'; import type { PgDialect } from '~/pg-core/dialect.ts'; @@ -35,6 +36,7 @@ import { } from '~/utils.ts'; import { orderSelectedFields } from '~/utils.ts'; import { ViewBaseConfig } from '~/view-common.ts'; +import { extractUsedTable } from '../utils.ts'; import type { AnyPgSelect, CreatePgSelectFromBuilderMode, @@ -172,14 +174,17 @@ export abstract class PgSelectQueryBuilderBase< readonly excludedMethods: TExcludedMethods; readonly result: TResult; readonly selectedFields: TSelectedFields; + readonly config: PgSelectConfig; }; protected config: PgSelectConfig; protected joinsNotNullableMap: Record; - private tableName: string | undefined; + protected tableName: string | undefined; private isPartialSelect: boolean; protected session: PgSession | undefined; protected dialect: PgDialect; + protected cacheConfig?: WithCacheConfig = undefined; + protected usedTables: Set = new Set(); constructor( { table, fields, isPartialSelect, session, 
dialect, withList, distinct }: { @@ -207,9 +212,17 @@ export abstract class PgSelectQueryBuilderBase< this.dialect = dialect; this._ = { selectedFields: fields as TSelectedFields, + config: this.config, } as this['_']; this.tableName = getTableLikeName(table); this.joinsNotNullableMap = typeof this.tableName === 'string' ? { [this.tableName]: true } : {}; + + for (const item of extractUsedTable(table)) this.usedTables.add(item); + } + + /** @internal */ + getUsedTables() { + return [...this.usedTables]; } private createJoin< @@ -226,6 +239,9 @@ export abstract class PgSelectQueryBuilderBase< const baseTableName = this.tableName; const tableName = getTableLikeName(table); + // store all tables used in a query + for (const item of extractUsedTable(table)) this.usedTables.add(item); + if (typeof tableName === 'string' && this.config.joins?.some((join) => join.alias === tableName)) { throw new Error(`Alias "${tableName}" is already used in this query`); } @@ -971,12 +987,15 @@ export abstract class PgSelectQueryBuilderBase< const { typings: _typings, ...rest } = this.dialect.sqlToQuery(this.getSQL()); return rest; } - as( alias: TAlias, ): SubqueryWithSelection { + const usedTables: string[] = []; + usedTables.push(...extractUsedTable(this.config.table)); + if (this.config.joins) { for (const it of this.config.joins) usedTables.push(...extractUsedTable(it.table)); } + return new Proxy( - new Subquery(this.getSQL(), this.config.fields, alias), + new Subquery(this.getSQL(), this.config.fields, alias, false, [...new Set(usedTables)]), new SelectionProxyHandler({ alias, sqlAliasedBehavior: 'alias', sqlBehavior: 'error' }), ) as SubqueryWithSelection; } @@ -992,6 +1011,15 @@ export abstract class PgSelectQueryBuilderBase< $dynamic(): PgSelectDynamic { return this; } + + $withCache(config?: { config?: CacheConfig; tag?: string; autoInvalidate?: boolean } | false) { + this.cacheConfig = config === undefined + ? 
{ config: {}, enable: true, autoInvalidate: true } + : config === false + ? { enable: false } + : { enable: true, autoInvalidate: true, ...config }; + return this; + } } export interface PgSelectBase< @@ -1045,15 +1073,21 @@ export class PgSelectBase< /** @internal */ _prepare(name?: string): PgSelectPrepare { - const { session, config, dialect, joinsNotNullableMap, authToken } = this; + const { session, config, dialect, joinsNotNullableMap, authToken, cacheConfig, usedTables } = this; if (!session) { throw new Error('Cannot execute a query on a query builder. Please use a database instance instead.'); } + + const { fields } = config; + return tracer.startActiveSpan('drizzle.prepareQuery', () => { - const fieldsList = orderSelectedFields(config.fields); + const fieldsList = orderSelectedFields(fields); const query = session.prepareQuery< PreparedQueryConfig & { execute: TResult } - >(dialect.sqlToQuery(this.getSQL()), fieldsList, name, true); + >(dialect.sqlToQuery(this.getSQL()), fieldsList, name, true, undefined, { + type: 'select', + tables: [...usedTables], + }, cacheConfig); query.joinsNotNullableMap = joinsNotNullableMap; return query.setToken(authToken); diff --git a/drizzle-orm/src/pg-core/query-builders/update.ts b/drizzle-orm/src/pg-core/query-builders/update.ts index 419a8aec8b..f9e366fbc6 100644 --- a/drizzle-orm/src/pg-core/query-builders/update.ts +++ b/drizzle-orm/src/pg-core/query-builders/update.ts @@ -1,3 +1,4 @@ +import type { WithCacheConfig } from '~/cache/core/types.ts'; import type { GetColumnData } from '~/column.ts'; import { entityKind, is } from '~/entity.ts'; import type { PgDialect } from '~/pg-core/dialect.ts'; @@ -9,7 +10,7 @@ import type { PreparedQueryConfig, } from '~/pg-core/session.ts'; import { PgTable } from '~/pg-core/table.ts'; -import { TypedQueryBuilder } from '~/query-builders/query-builder.ts'; +import type { TypedQueryBuilder } from '~/query-builders/query-builder.ts'; import type { AppendToNullabilityMap, 
AppendToResult, @@ -28,17 +29,18 @@ import { Subquery } from '~/subquery.ts'; import { getTableName, Table } from '~/table.ts'; import { type Assume, - DrizzleTypeError, - Equal, + type DrizzleTypeError, + type Equal, getTableLikeName, mapUpdateSet, type NeonAuthToken, orderSelectedFields, - Simplify, + type Simplify, type UpdateSet, } from '~/utils.ts'; import { ViewBaseConfig } from '~/view-common.ts'; import type { PgColumn } from '../columns/common.ts'; +import { extractUsedTable } from '../utils.ts'; import type { PgViewBase } from '../view-base.ts'; import type { PgSelectJoinConfig, @@ -340,6 +342,7 @@ export class PgUpdateBase< TTable extends PgTable, TQueryResult extends PgQueryResultHKT, TFrom extends PgTable | Subquery | PgViewBase | SQL | undefined = undefined, + // eslint-disable-next-line @typescript-eslint/no-unused-vars TSelectedFields extends ColumnsSelection | undefined = undefined, TReturning extends Record | undefined = undefined, // eslint-disable-next-line @typescript-eslint/no-unused-vars @@ -360,6 +363,7 @@ export class PgUpdateBase< private config: PgUpdateConfig; private tableName: string | undefined; private joinsNotNullableMap: Record; + protected cacheConfig?: WithCacheConfig; constructor( table: TTable, @@ -576,7 +580,10 @@ export class PgUpdateBase< _prepare(name?: string): PgUpdatePrepare { const query = this.session.prepareQuery< PreparedQueryConfig & { execute: TReturning[] } - >(this.dialect.sqlToQuery(this.getSQL()), this.config.returning, name, true); + >(this.dialect.sqlToQuery(this.getSQL()), this.config.returning, name, true, undefined, { + type: 'insert', + tables: extractUsedTable(this.config.table), + }, this.cacheConfig); query.joinsNotNullableMap = this.joinsNotNullableMap; return query; } diff --git a/drizzle-orm/src/pg-core/session.ts b/drizzle-orm/src/pg-core/session.ts index d77f2c4dbf..339fe75e7d 100644 --- a/drizzle-orm/src/pg-core/session.ts +++ b/drizzle-orm/src/pg-core/session.ts @@ -1,5 +1,8 @@ -import { 
entityKind } from '~/entity.ts'; +import { type Cache, hashQuery, NoopCache } from '~/cache/core/cache.ts'; +import type { WithCacheConfig } from '~/cache/core/types.ts'; +import { entityKind, is } from '~/entity.ts'; import { TransactionRollbackError } from '~/errors.ts'; +import { DrizzleQueryError } from '~/errors/index.ts'; import type { TablesRelationalConfig } from '~/relations.ts'; import type { PreparedQuery } from '~/session.ts'; import { type Query, type SQL, sql } from '~/sql/index.ts'; @@ -16,7 +19,26 @@ export interface PreparedQueryConfig { } export abstract class PgPreparedQuery implements PreparedQuery { - constructor(protected query: Query) {} + constructor( + protected query: Query, + // cache instance + private cache: Cache, + // per query related metadata + private queryMetadata: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + } | undefined, + // config that was passed through $withCache + private cacheConfig?: WithCacheConfig, + ) { + // it means that no $withCache options were passed and it should be just enabled + if (cache && cache.strategy() === 'all' && cacheConfig === undefined) { + this.cacheConfig = { enable: true, autoInvalidate: true }; + } + if (!this.cacheConfig?.enable) { + this.cacheConfig = undefined; + } + } protected authToken?: NeonAuthToken; @@ -39,6 +61,92 @@ export abstract class PgPreparedQuery implements /** @internal */ joinsNotNullableMap?: Record; + /** @internal */ + protected async queryWithCache( + queryString: string, + params: any[], + query: () => Promise, + ): Promise { + if (this.cache === undefined || is(this.cache, NoopCache) || this.queryMetadata === undefined) { + try { + return await query(); + } catch (e) { + throw new DrizzleQueryError(queryString, params, e as Error); + } + } + + // don't do any mutations, if globally is false + if (this.cacheConfig && !this.cacheConfig.enable) { + try { + return await query(); + } catch (e) { + throw new DrizzleQueryError(queryString, params, 
e as Error); + } + } + + // For mutate queries, we should query the database, wait for a response, and then perform invalidation + if ( + ( + this.queryMetadata.type === 'insert' || this.queryMetadata.type === 'update' + || this.queryMetadata.type === 'delete' + ) && this.queryMetadata.tables.length > 0 + ) { + try { + const [res] = await Promise.all([ + query(), + this.cache.onMutate({ tables: this.queryMetadata.tables }), + ]); + return res; + } catch (e) { + throw new DrizzleQueryError(queryString, params, e as Error); + } + } + + // don't do any reads if globally disabled + if (!this.cacheConfig) { + try { + return await query(); + } catch (e) { + throw new DrizzleQueryError(queryString, params, e as Error); + } + } + + if (this.queryMetadata.type === 'select') { + const fromCache = await this.cache.get( + this.cacheConfig.tag ?? await hashQuery(queryString, params), + this.queryMetadata.tables, + this.cacheConfig.tag !== undefined, + this.cacheConfig.autoInvalidate, + ); + if (fromCache === undefined) { + let result; + try { + result = await query(); + } catch (e) { + throw new DrizzleQueryError(queryString, params, e as Error); + } + // put actual key + await this.cache.put( + this.cacheConfig.tag ?? await hashQuery(queryString, params), + result, + // make sure we send tables that were used in a query only if user wants to invalidate it on each write + this.cacheConfig.autoInvalidate ? 
this.queryMetadata.tables : [], + this.cacheConfig.tag !== undefined, + this.cacheConfig.config, + ); + // put flag if we should invalidate or not + return result; + } + + return fromCache as unknown as T; + } + try { + return await query(); + } catch (e) { + throw new DrizzleQueryError(queryString, params, e as Error); + } + } + abstract execute(placeholderValues?: Record): Promise; /** @internal */ abstract execute(placeholderValues?: Record, token?: NeonAuthToken): Promise; @@ -73,6 +181,11 @@ export abstract class PgSession< name: string | undefined, isResponseInArrayMode: boolean, customResultMapper?: (rows: unknown[][], mapColumnValue?: (value: unknown) => unknown) => T['execute'], + queryMetadata?: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + }, + cacheConfig?: WithCacheConfig, ): PgPreparedQuery; execute(query: SQL): Promise; diff --git a/drizzle-orm/src/pg-core/utils.ts b/drizzle-orm/src/pg-core/utils.ts index 0191c2439a..b7547ad664 100644 --- a/drizzle-orm/src/pg-core/utils.ts +++ b/drizzle-orm/src/pg-core/utils.ts @@ -1,6 +1,8 @@ import { is } from '~/entity.ts'; import { PgTable } from '~/pg-core/table.ts'; -import { Table } from '~/table.ts'; +import { SQL } from '~/sql/sql.ts'; +import { Subquery } from '~/subquery.ts'; +import { Schema, Table } from '~/table.ts'; import { ViewBaseConfig } from '~/view-common.ts'; import { type Check, CheckBuilder } from './checks.ts'; import type { AnyPgColumn } from './columns/index.ts'; @@ -10,6 +12,7 @@ import { IndexBuilder } from './indexes.ts'; import { PgPolicy } from './policies.ts'; import { type PrimaryKey, PrimaryKeyBuilder } from './primary-keys.ts'; import { type UniqueConstraint, UniqueConstraintBuilder } from './unique-constraint.ts'; +import type { PgViewBase } from './view-base.ts'; import { PgViewConfig } from './view-common.ts'; import { type PgMaterializedView, PgMaterializedViewConfig, type PgView } from './view.ts'; @@ -61,6 +64,19 @@ export function 
getTableConfig(table: TTable) { }; } +export function extractUsedTable(table: PgTable | Subquery | PgViewBase | SQL): string[] { + if (is(table, PgTable)) { + return [table[Schema] ? `${table[Schema]}.${table[Table.Symbol.BaseName]}` : table[Table.Symbol.BaseName]]; + } + if (is(table, Subquery)) { + return table._.usedTables ?? []; + } + if (is(table, SQL)) { + return table.usedTables ?? []; + } + return []; +} + export function getViewConfig< TName extends string = string, TExisting extends boolean = boolean, diff --git a/drizzle-orm/src/pg-proxy/driver.ts b/drizzle-orm/src/pg-proxy/driver.ts index 955dc2bb46..6016dd34cd 100644 --- a/drizzle-orm/src/pg-proxy/driver.ts +++ b/drizzle-orm/src/pg-proxy/driver.ts @@ -50,6 +50,12 @@ export function drizzle = Record; + const session = new PgRemoteSession(callback, dialect, schema, { logger, cache: config.cache }); + const db = new PgRemoteDatabase(dialect, session, schema as any) as PgRemoteDatabase; + ( db).$cache = config.cache; + if (( db).$cache) { + ( db).$cache['invalidate'] = config.cache?.onMutate; + } + + return db; } diff --git a/drizzle-orm/src/pg-proxy/session.ts b/drizzle-orm/src/pg-proxy/session.ts index 9d433502c0..5f098a8d25 100644 --- a/drizzle-orm/src/pg-proxy/session.ts +++ b/drizzle-orm/src/pg-proxy/session.ts @@ -1,3 +1,5 @@ +import { type Cache, NoopCache } from '~/cache/core/cache.ts'; +import type { WithCacheConfig } from '~/cache/core/types.ts'; import { entityKind } from '~/entity.ts'; import type { Logger } from '~/logger.ts'; import { NoopLogger } from '~/logger.ts'; @@ -15,6 +17,7 @@ import type { RemoteCallback } from './driver.ts'; export interface PgRemoteSessionOptions { logger?: Logger; + cache?: Cache; } export class PgRemoteSession< @@ -24,6 +27,7 @@ export class PgRemoteSession< static override readonly [entityKind]: string = 'PgRemoteSession'; private logger: Logger; + private cache: Cache; constructor( private client: RemoteCallback, @@ -33,6 +37,7 @@ export class PgRemoteSession< 
) { super(dialect); this.logger = options.logger ?? new NoopLogger(); + this.cache = options.cache ?? new NoopCache(); } prepareQuery( @@ -41,6 +46,11 @@ export class PgRemoteSession< name: string | undefined, isResponseInArrayMode: boolean, customResultMapper?: (rows: unknown[][]) => T['execute'], + queryMetadata?: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + }, + cacheConfig?: WithCacheConfig, ): PreparedQuery { return new PreparedQuery( this.client, @@ -48,6 +58,9 @@ export class PgRemoteSession< query.params, query.typings, this.logger, + this.cache, + queryMetadata, + cacheConfig, fields, isResponseInArrayMode, customResultMapper, @@ -84,11 +97,17 @@ export class PreparedQuery extends PreparedQueryB private params: unknown[], private typings: any[] | undefined, private logger: Logger, + cache: Cache, + queryMetadata: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + } | undefined, + cacheConfig: WithCacheConfig | undefined, private fields: SelectedFieldsOrdered | undefined, private _isResponseInArrayMode: boolean, private customResultMapper?: (rows: unknown[][]) => T['execute'], ) { - super({ sql: queryString, params }); + super({ sql: queryString, params }, cache, queryMetadata, cacheConfig); } async execute(placeholderValues: Record | undefined = {}): Promise { @@ -105,7 +124,9 @@ export class PreparedQuery extends PreparedQueryB if (!fields && !customResultMapper) { return tracer.startActiveSpan('drizzle.driver.execute', async () => { - const { rows } = await client(queryString, params as any[], 'execute', typings); + const { rows } = await this.queryWithCache(queryString, params, async () => { + return await client(queryString, params as any[], 'execute', typings); + }); return rows; }); @@ -117,7 +138,9 @@ export class PreparedQuery extends PreparedQueryB 'drizzle.query.params': JSON.stringify(params), }); - const { rows } = await client(queryString, params as any[], 'all', typings); + const { rows } 
= await this.queryWithCache(queryString, params, async () => { + return await client(queryString, params as any[], 'all', typings); + }); return rows; }); diff --git a/drizzle-orm/src/pglite/driver.ts b/drizzle-orm/src/pglite/driver.ts index 7f35b27795..9d5f52a34b 100644 --- a/drizzle-orm/src/pglite/driver.ts +++ b/drizzle-orm/src/pglite/driver.ts @@ -1,4 +1,5 @@ import { PGlite, type PGliteOptions } from '@electric-sql/pglite'; +import type { Cache } from '~/cache/core/cache.ts'; import { entityKind } from '~/entity.ts'; import type { Logger } from '~/logger.ts'; import { DefaultLogger } from '~/logger.ts'; @@ -16,6 +17,7 @@ import { PgliteSession } from './session.ts'; export interface PgDriverOptions { logger?: Logger; + cache?: Cache; } export class PgliteDriver { @@ -31,7 +33,10 @@ export class PgliteDriver { createSession( schema: RelationalSchemaConfig | undefined, ): PgliteSession, TablesRelationalConfig> { - return new PgliteSession(this.client, this.dialect, schema, { logger: this.options.logger }); + return new PgliteSession(this.client, this.dialect, schema, { + logger: this.options.logger, + cache: this.options.cache, + }); } } @@ -68,10 +73,24 @@ function construct = Record; ( db).$client = client; + ( db).$cache = config.cache; + if (( db).$cache) { + ( db).$cache['invalidate'] = config.cache?.onMutate; + } + // ( db).$cache = { invalidate: ( config).cache?.onMutate }; + // if (config.cache) { + // for ( + // const key of Object.getOwnPropertyNames(Object.getPrototypeOf(config.cache)).filter((key) => + // key !== 'constructor' + // ) + // ) { + // ( db).$cache[key as keyof typeof config.cache] = ( config).cache[key]; + // } + // } return db as any; } diff --git a/drizzle-orm/src/pglite/session.ts b/drizzle-orm/src/pglite/session.ts index 72126deb4c..7e15ceff30 100644 --- a/drizzle-orm/src/pglite/session.ts +++ b/drizzle-orm/src/pglite/session.ts @@ -11,6 +11,8 @@ import { fillPlaceholders, type Query, type SQL, sql } from '~/sql/sql.ts'; import { 
type Assume, mapResultRow } from '~/utils.ts'; import { types } from '@electric-sql/pglite'; +import { type Cache, NoopCache } from '~/cache/core/cache.ts'; +import type { WithCacheConfig } from '~/cache/core/types.ts'; export type PgliteClient = PGlite; @@ -25,12 +27,18 @@ export class PglitePreparedQuery extends PgPrepar private queryString: string, private params: unknown[], private logger: Logger, + cache: Cache, + queryMetadata: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + } | undefined, + cacheConfig: WithCacheConfig | undefined, private fields: SelectedFieldsOrdered | undefined, name: string | undefined, private _isResponseInArrayMode: boolean, private customResultMapper?: (rows: unknown[][]) => T['execute'], ) { - super({ sql: queryString, params }); + super({ sql: queryString, params }, cache, queryMetadata, cacheConfig); this.rawQueryConfig = { rowMode: 'object', parsers: { @@ -76,13 +84,17 @@ export class PglitePreparedQuery extends PgPrepar this.logger.logQuery(this.queryString, params); - const { fields, rawQueryConfig, client, queryConfig, joinsNotNullableMap, customResultMapper, queryString } = this; + const { fields, client, queryConfig, joinsNotNullableMap, customResultMapper, queryString, rawQueryConfig } = this; if (!fields && !customResultMapper) { - return client.query(queryString, params, rawQueryConfig); + return this.queryWithCache(queryString, params, async () => { + return await client.query(queryString, params, rawQueryConfig); + }); } - const result = await client.query(queryString, params, queryConfig); + const result = await this.queryWithCache(queryString, params, async () => { + return await client.query(queryString, params, queryConfig); + }); return customResultMapper ? 
customResultMapper(result.rows) @@ -92,7 +104,9 @@ export class PglitePreparedQuery extends PgPrepar all(placeholderValues: Record | undefined = {}): Promise { const params = fillPlaceholders(this.params, placeholderValues); this.logger.logQuery(this.queryString, params); - return this.client.query(this.queryString, params, this.rawQueryConfig).then((result) => result.rows); + return this.queryWithCache(this.queryString, params, async () => { + return await this.client.query(this.queryString, params, this.rawQueryConfig); + }).then((result) => result.rows); } /** @internal */ @@ -103,6 +117,7 @@ export class PglitePreparedQuery extends PgPrepar export interface PgliteSessionOptions { logger?: Logger; + cache?: Cache; } export class PgliteSession< @@ -112,6 +127,7 @@ export class PgliteSession< static override readonly [entityKind]: string = 'PgliteSession'; private logger: Logger; + private cache: Cache; constructor( private client: PgliteClient | Transaction, @@ -121,6 +137,7 @@ export class PgliteSession< ) { super(dialect); this.logger = options.logger ?? new NoopLogger(); + this.cache = options.cache ?? 
new NoopCache(); } prepareQuery( @@ -129,12 +146,20 @@ export class PgliteSession< name: string | undefined, isResponseInArrayMode: boolean, customResultMapper?: (rows: unknown[][]) => T['execute'], + queryMetadata?: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + }, + cacheConfig?: WithCacheConfig, ): PgPreparedQuery { return new PglitePreparedQuery( this.client, query.sql, query.params, this.logger, + this.cache, + queryMetadata, + cacheConfig, fields, name, isResponseInArrayMode, diff --git a/drizzle-orm/src/planetscale-serverless/driver.ts b/drizzle-orm/src/planetscale-serverless/driver.ts index 1ea8825cb9..c03086f91c 100644 --- a/drizzle-orm/src/planetscale-serverless/driver.ts +++ b/drizzle-orm/src/planetscale-serverless/driver.ts @@ -17,6 +17,7 @@ import { PlanetscaleSession } from './session.ts'; export interface PlanetscaleSDriverOptions { logger?: Logger; + cache?: Cache; } export class PlanetScaleDatabase< @@ -72,9 +73,13 @@ const db = drizzle(client); }; } - const session = new PlanetscaleSession(client, dialect, undefined, schema, { logger }); + const session = new PlanetscaleSession(client, dialect, undefined, schema, { logger, cache: config.cache }); const db = new PlanetScaleDatabase(dialect, session, schema as any, 'planetscale') as PlanetScaleDatabase; ( db).$client = client; + ( db).$cache = config.cache; + if (( db).$cache) { + ( db).$cache['invalidate'] = config.cache?.onMutate; + } return db as any; } diff --git a/drizzle-orm/src/planetscale-serverless/session.ts b/drizzle-orm/src/planetscale-serverless/session.ts index 272332a3a2..471aed4a28 100644 --- a/drizzle-orm/src/planetscale-serverless/session.ts +++ b/drizzle-orm/src/planetscale-serverless/session.ts @@ -1,4 +1,6 @@ import type { Client, Connection, ExecutedQuery, Transaction } from '@planetscale/database'; +import { type Cache, NoopCache } from '~/cache/core/index.ts'; +import type { WithCacheConfig } from '~/cache/core/types.ts'; import { Column } from 
'~/column.ts'; import { entityKind, is } from '~/entity.ts'; import type { Logger } from '~/logger.ts'; @@ -28,6 +30,12 @@ export class PlanetScalePreparedQuery extend private queryString: string, private params: unknown[], private logger: Logger, + cache: Cache, + queryMetadata: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + } | undefined, + cacheConfig: WithCacheConfig | undefined, private fields: SelectedFieldsOrdered | undefined, private customResultMapper?: (rows: unknown[][]) => T['execute'], // Keys that were used in $default and the value that was generated for them @@ -35,7 +43,7 @@ export class PlanetScalePreparedQuery extend // Keys that should be returned, it has the column with all properries + key from object private returningIds?: SelectedFieldsOrdered, ) { - super(); + super(cache, queryMetadata, cacheConfig); } async execute(placeholderValues: Record | undefined = {}): Promise { @@ -55,7 +63,9 @@ export class PlanetScalePreparedQuery extend generatedIds, } = this; if (!fields && !customResultMapper) { - const res = await client.execute(queryString, params, rawQuery); + const res = await this.queryWithCache(queryString, params, async () => { + return await client.execute(queryString, params, rawQuery); + }); const insertId = Number.parseFloat(res.insertId); const affectedRows = res.rowsAffected; @@ -84,7 +94,9 @@ export class PlanetScalePreparedQuery extend } return res; } - const { rows } = await client.execute(queryString, params, query); + const { rows } = await this.queryWithCache(queryString, params, async () => { + return await client.execute(queryString, params, query); + }); if (customResultMapper) { return customResultMapper(rows as unknown[][]); @@ -100,6 +112,7 @@ export class PlanetScalePreparedQuery extend export interface PlanetscaleSessionOptions { logger?: Logger; + cache?: Cache; } export class PlanetscaleSession< @@ -110,6 +123,7 @@ export class PlanetscaleSession< private logger: Logger; private client: 
Client | Transaction | Connection; + private cache: Cache; constructor( private baseClient: Client | Connection, @@ -121,6 +135,7 @@ export class PlanetscaleSession< super(dialect); this.client = tx ?? baseClient; this.logger = options.logger ?? new NoopLogger(); + this.cache = options.cache ?? new NoopCache(); } prepareQuery( @@ -129,12 +144,20 @@ export class PlanetscaleSession< customResultMapper?: (rows: unknown[][]) => T['execute'], generatedIds?: Record[], returningIds?: SelectedFieldsOrdered, + queryMetadata?: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + }, + cacheConfig?: WithCacheConfig, ): MySqlPreparedQuery { return new PlanetScalePreparedQuery( this.client, query.sql, query.params, this.logger, + this.cache, + queryMetadata, + cacheConfig, fields, customResultMapper, generatedIds, diff --git a/drizzle-orm/src/postgres-js/driver.ts b/drizzle-orm/src/postgres-js/driver.ts index 77bb815d40..411df042da 100644 --- a/drizzle-orm/src/postgres-js/driver.ts +++ b/drizzle-orm/src/postgres-js/driver.ts @@ -56,9 +56,13 @@ function construct = Record; ( db).$client = client; + ( db).$cache = config.cache; + if (( db).$cache) { + ( db).$cache['invalidate'] = config.cache?.onMutate; + } return db as any; } diff --git a/drizzle-orm/src/postgres-js/session.ts b/drizzle-orm/src/postgres-js/session.ts index 7509e2a002..3673dd8b30 100644 --- a/drizzle-orm/src/postgres-js/session.ts +++ b/drizzle-orm/src/postgres-js/session.ts @@ -1,4 +1,6 @@ import type { Row, RowList, Sql, TransactionSql } from 'postgres'; +import { type Cache, NoopCache } from '~/cache/core/index.ts'; +import type { WithCacheConfig } from '~/cache/core/types.ts'; import { entityKind } from '~/entity.ts'; import type { Logger } from '~/logger.ts'; import { NoopLogger } from '~/logger.ts'; @@ -20,11 +22,17 @@ export class PostgresJsPreparedQuery extends PgPr private queryString: string, private params: unknown[], private logger: Logger, + cache: Cache, + queryMetadata: { + 
type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + } | undefined, + cacheConfig: WithCacheConfig | undefined, private fields: SelectedFieldsOrdered | undefined, private _isResponseInArrayMode: boolean, private customResultMapper?: (rows: unknown[][]) => T['execute'], ) { - super({ sql: queryString, params }); + super({ sql: queryString, params }, cache, queryMetadata, cacheConfig); } async execute(placeholderValues: Record | undefined = {}): Promise { @@ -41,7 +49,9 @@ export class PostgresJsPreparedQuery extends PgPr const { fields, queryString: query, client, joinsNotNullableMap, customResultMapper } = this; if (!fields && !customResultMapper) { return tracer.startActiveSpan('drizzle.driver.execute', () => { - return client.unsafe(query, params as any[]); + return this.queryWithCache(query, params, async () => { + return await client.unsafe(query, params as any[]); + }); }); } @@ -50,8 +60,9 @@ export class PostgresJsPreparedQuery extends PgPr 'drizzle.query.text': query, 'drizzle.query.params': JSON.stringify(params), }); - - return client.unsafe(query, params as any[]).values(); + return this.queryWithCache(query, params, async () => { + return await client.unsafe(query, params as any[]).values(); + }); }); return tracer.startActiveSpan('drizzle.mapResponse', () => { @@ -75,7 +86,9 @@ export class PostgresJsPreparedQuery extends PgPr 'drizzle.query.text': this.queryString, 'drizzle.query.params': JSON.stringify(params), }); - return this.client.unsafe(this.queryString, params as any[]); + return this.queryWithCache(this.queryString, params, async () => { + return this.client.unsafe(this.queryString, params as any[]); + }); }); }); } @@ -88,6 +101,7 @@ export class PostgresJsPreparedQuery extends PgPr export interface PostgresJsSessionOptions { logger?: Logger; + cache?: Cache; } export class PostgresJsSession< @@ -98,6 +112,7 @@ export class PostgresJsSession< static override readonly [entityKind]: string = 'PostgresJsSession'; logger: 
Logger; + private cache: Cache; constructor( public client: TSQL, @@ -108,6 +123,7 @@ export class PostgresJsSession< ) { super(dialect); this.logger = options.logger ?? new NoopLogger(); + this.cache = options.cache ?? new NoopCache(); } prepareQuery( @@ -116,12 +132,20 @@ export class PostgresJsSession< name: string | undefined, isResponseInArrayMode: boolean, customResultMapper?: (rows: unknown[][]) => T['execute'], + queryMetadata?: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + }, + cacheConfig?: WithCacheConfig, ): PgPreparedQuery { return new PostgresJsPreparedQuery( this.client, query.sql, query.params, this.logger, + this.cache, + queryMetadata, + cacheConfig, fields, isResponseInArrayMode, customResultMapper, diff --git a/drizzle-orm/src/prisma/mysql/session.ts b/drizzle-orm/src/prisma/mysql/session.ts index fc3807bc5a..6cb6d1e4dd 100644 --- a/drizzle-orm/src/prisma/mysql/session.ts +++ b/drizzle-orm/src/prisma/mysql/session.ts @@ -26,7 +26,7 @@ export class PrismaMySqlPreparedQuery extends MySqlPreparedQuery): Promise { diff --git a/drizzle-orm/src/prisma/pg/session.ts b/drizzle-orm/src/prisma/pg/session.ts index b93f6f14b1..1a3b79a121 100644 --- a/drizzle-orm/src/prisma/pg/session.ts +++ b/drizzle-orm/src/prisma/pg/session.ts @@ -21,7 +21,7 @@ export class PrismaPgPreparedQuery extends PgPreparedQuery): Promise { diff --git a/drizzle-orm/src/query-builders/query-builder.ts b/drizzle-orm/src/query-builders/query-builder.ts index c23555aa34..cef2767aa5 100644 --- a/drizzle-orm/src/query-builders/query-builder.ts +++ b/drizzle-orm/src/query-builders/query-builder.ts @@ -1,12 +1,13 @@ import { entityKind } from '~/entity.ts'; import type { SQL, SQLWrapper } from '~/sql/index.ts'; -export abstract class TypedQueryBuilder implements SQLWrapper { +export abstract class TypedQueryBuilder implements SQLWrapper { static readonly [entityKind]: string = 'TypedQueryBuilder'; declare _: { selectedFields: TSelection; result: TResult; + 
config?: TConfig; }; /** @internal */ diff --git a/drizzle-orm/src/singlestore-core/db.ts b/drizzle-orm/src/singlestore-core/db.ts index ab8ce7bab2..02eb866586 100644 --- a/drizzle-orm/src/singlestore-core/db.ts +++ b/drizzle-orm/src/singlestore-core/db.ts @@ -1,4 +1,5 @@ import type { ResultSetHeader } from 'mysql2/promise'; +import type { Cache } from '~/cache/core/cache.ts'; import { entityKind } from '~/entity.ts'; import type { TypedQueryBuilder } from '~/query-builders/query-builder.ts'; import type { ExtractTablesWithRelations, RelationalSchemaConfig, TablesRelationalConfig } from '~/relations.ts'; @@ -79,6 +80,7 @@ export class SingleStoreDatabase< // ); // } // } + this.$cache = { invalidate: async (_params: any) => {} }; } /** @@ -475,6 +477,8 @@ export class SingleStoreDatabase< return this.session.execute(typeof query === 'string' ? sql.raw(query) : query.getSQL()); } + $cache: { invalidate: Cache['onMutate'] }; + transaction( transaction: ( tx: SingleStoreTransaction, diff --git a/drizzle-orm/src/singlestore-core/query-builders/delete.ts b/drizzle-orm/src/singlestore-core/query-builders/delete.ts index 1f41d29ba9..91a1222663 100644 --- a/drizzle-orm/src/singlestore-core/query-builders/delete.ts +++ b/drizzle-orm/src/singlestore-core/query-builders/delete.ts @@ -17,6 +17,7 @@ import type { Subquery } from '~/subquery.ts'; import { Table } from '~/table.ts'; import type { ValueOrArray } from '~/utils.ts'; import type { SingleStoreColumn } from '../columns/common.ts'; +import { extractUsedTable } from '../utils.ts'; import type { SelectedFieldsOrdered } from './select.types.ts'; export type SingleStoreDeleteWithout< @@ -185,6 +186,13 @@ export class SingleStoreDeleteBase< return this.session.prepareQuery( this.dialect.sqlToQuery(this.getSQL()), this.config.returning, + undefined, + undefined, + undefined, + { + type: 'delete', + tables: extractUsedTable(this.config.table), + }, ) as SingleStoreDeletePrepare; } diff --git 
a/drizzle-orm/src/singlestore-core/query-builders/insert.ts b/drizzle-orm/src/singlestore-core/query-builders/insert.ts index 84a72fdabe..626c1798f1 100644 --- a/drizzle-orm/src/singlestore-core/query-builders/insert.ts +++ b/drizzle-orm/src/singlestore-core/query-builders/insert.ts @@ -18,6 +18,7 @@ import type { InferModelFromColumns } from '~/table.ts'; import { Table } from '~/table.ts'; import { mapUpdateSet, orderSelectedFields } from '~/utils.ts'; import type { AnySingleStoreColumn, SingleStoreColumn } from '../columns/common.ts'; +import { extractUsedTable } from '../utils.ts'; import type { SelectedFieldsOrdered } from './select.types.ts'; import type { SingleStoreUpdateSetSource } from './update.ts'; @@ -283,6 +284,10 @@ export class SingleStoreInsertBase< undefined, generatedIds, this.config.returning, + { + type: 'delete', + tables: extractUsedTable(this.config.table), + }, ) as SingleStoreInsertPrepare; } diff --git a/drizzle-orm/src/singlestore-core/query-builders/select.ts b/drizzle-orm/src/singlestore-core/query-builders/select.ts index d7e2ad00ee..75906d26bf 100644 --- a/drizzle-orm/src/singlestore-core/query-builders/select.ts +++ b/drizzle-orm/src/singlestore-core/query-builders/select.ts @@ -1,3 +1,4 @@ +import type { CacheConfig, WithCacheConfig } from '~/cache/core/types.ts'; import { entityKind, is } from '~/entity.ts'; import { TypedQueryBuilder } from '~/query-builders/query-builder.ts'; import type { @@ -33,6 +34,7 @@ import { orderSelectedFields, type ValueOrArray, } from '~/utils.ts'; +import { extractUsedTable } from '../utils.ts'; import type { AnySingleStoreSelect, CreateSingleStoreSelectFromBuilderMode, @@ -153,6 +155,7 @@ export abstract class SingleStoreSelectQueryBuilderBase< readonly excludedMethods: TExcludedMethods; readonly result: TResult; readonly selectedFields: TSelectedFields; + readonly config: SingleStoreSelectConfig; }; protected config: SingleStoreSelectConfig; @@ -162,6 +165,8 @@ export abstract class 
SingleStoreSelectQueryBuilderBase< /** @internal */ readonly session: SingleStoreSession | undefined; protected dialect: SingleStoreDialect; + protected cacheConfig?: WithCacheConfig = undefined; + protected usedTables: Set = new Set(); constructor( { table, fields, isPartialSelect, session, dialect, withList, distinct }: { @@ -187,9 +192,16 @@ export abstract class SingleStoreSelectQueryBuilderBase< this.dialect = dialect; this._ = { selectedFields: fields as TSelectedFields, + config: this.config, } as this['_']; this.tableName = getTableLikeName(table); this.joinsNotNullableMap = typeof this.tableName === 'string' ? { [this.tableName]: true } : {}; + for (const item of extractUsedTable(table)) this.usedTables.add(item); + } + + /** @internal */ + getUsedTables() { + return [...this.usedTables]; } private createJoin< @@ -206,6 +218,9 @@ export abstract class SingleStoreSelectQueryBuilderBase< const baseTableName = this.tableName; const tableName = getTableLikeName(table); + // store all tables used in a query + for (const item of extractUsedTable(table)) this.usedTables.add(item); + if (typeof tableName === 'string' && this.config.joins?.some((join) => join.alias === tableName)) { throw new Error(`Alias "${tableName}" is already used in this query`); } @@ -894,8 +909,12 @@ export abstract class SingleStoreSelectQueryBuilderBase< as( alias: TAlias, ): SubqueryWithSelection { + const usedTables: string[] = []; + usedTables.push(...extractUsedTable(this.config.table)); + if (this.config.joins) { for (const it of this.config.joins) usedTables.push(...extractUsedTable(it.table)); } + return new Proxy( - new Subquery(this.getSQL(), this.config.fields, alias), + new Subquery(this.getSQL(), this.config.fields, alias, false, [...new Set(usedTables)]), new SelectionProxyHandler({ alias, sqlAliasedBehavior: 'alias', sqlBehavior: 'error' }), ) as SubqueryWithSelection; } @@ -973,11 +992,23 @@ export class SingleStoreSelectBase< const query = this.session.prepareQuery< 
SingleStorePreparedQueryConfig & { execute: SelectResult[] }, TPreparedQueryHKT - >(this.dialect.sqlToQuery(this.getSQL()), fieldsList); + >(this.dialect.sqlToQuery(this.getSQL()), fieldsList, undefined, undefined, undefined, { + type: 'select', + tables: [...this.usedTables], + }, this.cacheConfig); query.joinsNotNullableMap = this.joinsNotNullableMap; return query as SingleStoreSelectPrepare; } + $withCache(config?: { config?: CacheConfig; tag?: string; autoInvalidate?: boolean } | false) { + this.cacheConfig = config === undefined + ? { config: {}, enable: true, autoInvalidate: true } + : config === false + ? { enable: false } + : { enable: true, autoInvalidate: true, ...config }; + return this; + } + execute = ((placeholderValues) => { return this.prepare().execute(placeholderValues); }) as ReturnType['execute']; diff --git a/drizzle-orm/src/singlestore-core/query-builders/update.ts b/drizzle-orm/src/singlestore-core/query-builders/update.ts index 6a843373c9..90a42880de 100644 --- a/drizzle-orm/src/singlestore-core/query-builders/update.ts +++ b/drizzle-orm/src/singlestore-core/query-builders/update.ts @@ -18,6 +18,7 @@ import type { Subquery } from '~/subquery.ts'; import { Table } from '~/table.ts'; import { mapUpdateSet, type UpdateSet, type ValueOrArray } from '~/utils.ts'; import type { SingleStoreColumn } from '../columns/common.ts'; +import { extractUsedTable } from '../utils.ts'; import type { SelectedFieldsOrdered } from './select.types.ts'; export interface SingleStoreUpdateConfig { @@ -230,6 +231,13 @@ export class SingleStoreUpdateBase< return this.session.prepareQuery( this.dialect.sqlToQuery(this.getSQL()), this.config.returning, + undefined, + undefined, + undefined, + { + type: 'delete', + tables: extractUsedTable(this.config.table), + }, ) as SingleStoreUpdatePrepare; } diff --git a/drizzle-orm/src/singlestore-core/session.ts b/drizzle-orm/src/singlestore-core/session.ts index bc31f3d973..1d8359d8fb 100644 --- 
a/drizzle-orm/src/singlestore-core/session.ts +++ b/drizzle-orm/src/singlestore-core/session.ts @@ -1,5 +1,8 @@ -import { entityKind } from '~/entity.ts'; +import { type Cache, hashQuery, NoopCache } from '~/cache/core/cache.ts'; +import type { WithCacheConfig } from '~/cache/core/types.ts'; +import { entityKind, is } from '~/entity.ts'; import { TransactionRollbackError } from '~/errors.ts'; +import { DrizzleQueryError } from '~/errors/index.ts'; import type { RelationalSchemaConfig, TablesRelationalConfig } from '~/relations.ts'; import { type Query, type SQL, sql } from '~/sql/sql.ts'; import type { Assume, Equal } from '~/utils.ts'; @@ -43,6 +46,112 @@ export type PreparedQueryKind< export abstract class SingleStorePreparedQuery { static readonly [entityKind]: string = 'SingleStorePreparedQuery'; + constructor( + private cache?: Cache, + // per query related metadata + private queryMetadata?: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + } | undefined, + // config that was passed through $withCache + private cacheConfig?: WithCacheConfig, + ) { + // it means that no $withCache options were passed and it should be just enabled + if (cache && cache.strategy() === 'all' && cacheConfig === undefined) { + this.cacheConfig = { enable: true, autoInvalidate: true }; + } + if (!this.cacheConfig?.enable) { + this.cacheConfig = undefined; + } + } + + /** @internal */ + protected async queryWithCache( + queryString: string, + params: any[], + query: () => Promise, + ): Promise { + if (this.cache === undefined || is(this.cache, NoopCache) || this.queryMetadata === undefined) { + try { + return await query(); + } catch (e) { + throw new DrizzleQueryError(queryString, params, e as Error); + } + } + + // don't do any mutations, if globally is false + if (this.cacheConfig && !this.cacheConfig.enable) { + try { + return await query(); + } catch (e) { + throw new DrizzleQueryError(queryString, params, e as Error); + } + } + + // For mutate queries, we 
should query the database, wait for a response, and then perform invalidation + if ( + ( + this.queryMetadata.type === 'insert' || this.queryMetadata.type === 'update' + || this.queryMetadata.type === 'delete' + ) && this.queryMetadata.tables.length > 0 + ) { + try { + const [res] = await Promise.all([ + query(), + this.cache.onMutate({ tables: this.queryMetadata.tables }), + ]); + return res; + } catch (e) { + throw new DrizzleQueryError(queryString, params, e as Error); + } + } + + // don't do any reads if globally disabled + if (!this.cacheConfig) { + try { + return await query(); + } catch (e) { + throw new DrizzleQueryError(queryString, params, e as Error); + } + } + + if (this.queryMetadata.type === 'select') { + const fromCache = await this.cache.get( + this.cacheConfig.tag ?? await hashQuery(queryString, params), + this.queryMetadata.tables, + this.cacheConfig.tag !== undefined, + this.cacheConfig.autoInvalidate, + ); + if (fromCache === undefined) { + let result; + try { + result = await query(); + } catch (e) { + throw new DrizzleQueryError(queryString, params, e as Error); + } + + // put actual key + await this.cache.put( + this.cacheConfig.tag ?? await hashQuery(queryString, params), + result, + // make sure we send tables that were used in a query only if user wants to invalidate it on each write + this.cacheConfig.autoInvalidate ? 
this.queryMetadata.tables : [], + this.cacheConfig.tag !== undefined, + this.cacheConfig.config, + ); + // put flag if we should invalidate or not + return result; + } + + return fromCache as unknown as T; + } + try { + return await query(); + } catch (e) { + throw new DrizzleQueryError(queryString, params, e as Error); + } + } + /** @internal */ joinsNotNullableMap?: Record; @@ -76,6 +185,11 @@ export abstract class SingleStoreSession< customResultMapper?: (rows: unknown[][]) => T['execute'], generatedIds?: Record[], returningIds?: SelectedFieldsOrdered, + queryMetadata?: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + }, + cacheConfig?: WithCacheConfig, ): PreparedQueryKind; execute(query: SQL): Promise { diff --git a/drizzle-orm/src/singlestore-core/utils.ts b/drizzle-orm/src/singlestore-core/utils.ts index 9ec8b7b0d3..b834ee42c4 100644 --- a/drizzle-orm/src/singlestore-core/utils.ts +++ b/drizzle-orm/src/singlestore-core/utils.ts @@ -1,4 +1,6 @@ import { is } from '~/entity.ts'; +import { SQL } from '~/sql/sql.ts'; +import { Subquery } from '~/subquery.ts'; import { Table } from '~/table.ts'; import type { Index } from './indexes.ts'; import { IndexBuilder } from './indexes.ts'; @@ -9,6 +11,19 @@ import { type UniqueConstraint, UniqueConstraintBuilder } from './unique-constra /* import { SingleStoreViewConfig } from './view-common.ts'; import type { SingleStoreView } from './view.ts'; */ +export function extractUsedTable(table: SingleStoreTable | Subquery | SQL): string[] { + if (is(table, SingleStoreTable)) { + return [`${table[Table.Symbol.BaseName]}`]; + } + if (is(table, Subquery)) { + return table._.usedTables ?? []; + } + if (is(table, SQL)) { + return table.usedTables ?? 
[]; + } + return []; +} + export function getTableConfig(table: SingleStoreTable) { const columns = Object.values(table[SingleStoreTable.Symbol.Columns]); const indexes: Index[] = []; diff --git a/drizzle-orm/src/singlestore/driver.ts b/drizzle-orm/src/singlestore/driver.ts index 8e7db51f2f..a534ea83eb 100644 --- a/drizzle-orm/src/singlestore/driver.ts +++ b/drizzle-orm/src/singlestore/driver.ts @@ -1,5 +1,6 @@ import { type Connection as CallbackConnection, createPool, type Pool as CallbackPool, type PoolOptions } from 'mysql2'; import type { Connection, Pool } from 'mysql2/promise'; +import type { Cache } from '~/cache/core/cache.ts'; import { entityKind } from '~/entity.ts'; import type { Logger } from '~/logger.ts'; import { DefaultLogger } from '~/logger.ts'; @@ -22,6 +23,7 @@ import { SingleStoreDriverSession } from './session.ts'; export interface SingleStoreDriverOptions { logger?: Logger; + cache?: Cache; } export class SingleStoreDriverDriver { @@ -37,7 +39,10 @@ export class SingleStoreDriverDriver { createSession( schema: RelationalSchemaConfig | undefined, ): SingleStoreDriverSession, TablesRelationalConfig> { - return new SingleStoreDriverSession(this.client, this.dialect, schema, { logger: this.options.logger }); + return new SingleStoreDriverSession(this.client, this.dialect, schema, { + logger: this.options.logger, + cache: this.options.cache, + }); } } @@ -85,10 +90,17 @@ function construct< }; } - const driver = new SingleStoreDriverDriver(clientForInstance as SingleStoreDriverClient, dialect, { logger }); + const driver = new SingleStoreDriverDriver(clientForInstance as SingleStoreDriverClient, dialect, { + logger, + cache: config.cache, + }); const session = driver.createSession(schema); const db = new SingleStoreDriverDatabase(dialect, session, schema as any) as SingleStoreDriverDatabase; ( db).$client = client; + ( db).$cache = config.cache; + if (( db).$cache) { + ( db).$cache['invalidate'] = config.cache?.onMutate; + } return db as any; } 
diff --git a/drizzle-orm/src/singlestore/session.ts b/drizzle-orm/src/singlestore/session.ts index fd70a1d526..0b05e1c311 100644 --- a/drizzle-orm/src/singlestore/session.ts +++ b/drizzle-orm/src/singlestore/session.ts @@ -10,6 +10,8 @@ import type { RowDataPacket, } from 'mysql2/promise'; import { once } from 'node:events'; +import { type Cache, NoopCache } from '~/cache/core/index.ts'; +import type { WithCacheConfig } from '~/cache/core/types.ts'; import { Column } from '~/column.ts'; import { entityKind, is } from '~/entity.ts'; import type { Logger } from '~/logger.ts'; @@ -52,14 +54,20 @@ export class SingleStoreDriverPreparedQuery T['execute'], // Keys that were used in $default and the value that was generated for them private generatedIds?: Record[], - // Keys that should be returned, it has the column with all properries + key from object + // Keys that should be returned, it has the column with all properties + key from object private returningIds?: SelectedFieldsOrdered, ) { - super(); + super(cache, queryMetadata, cacheConfig); this.rawQuery = { sql: queryString, // rowsAsArray: true, @@ -90,7 +98,9 @@ export class SingleStoreDriverPreparedQuery(rawQuery, params); + const res = await this.queryWithCache(rawQuery.sql, params, async () => { + return await client.query(rawQuery, params); + }); const insertId = res[0].insertId; const affectedRows = res[0].affectedRows; // for each row, I need to check keys from @@ -119,7 +129,9 @@ export class SingleStoreDriverPreparedQuery(query, params); + const result = await this.queryWithCache(query.sql, params, async () => { + return await client.query(query, params); + }); const rows = result[0]; if (customResultMapper) { @@ -184,6 +196,7 @@ export class SingleStoreDriverPreparedQuery( @@ -210,6 +225,11 @@ export class SingleStoreDriverSession< customResultMapper?: (rows: unknown[][]) => T['execute'], generatedIds?: Record[], returningIds?: SelectedFieldsOrdered, + queryMetadata?: { + type: 'select' | 'update' | 
'delete' | 'insert'; + tables: string[]; + }, + cacheConfig?: WithCacheConfig, ): PreparedQueryKind { // Add returningId fields // Each driver gets them from response from database @@ -218,6 +238,9 @@ export class SingleStoreDriverSession< query.sql, query.params, this.logger, + this.cache, + queryMetadata, + cacheConfig, fields, customResultMapper, generatedIds, diff --git a/drizzle-orm/src/sql/sql.ts b/drizzle-orm/src/sql/sql.ts index ec4feb20c2..994acf4faa 100644 --- a/drizzle-orm/src/sql/sql.ts +++ b/drizzle-orm/src/sql/sql.ts @@ -112,7 +112,22 @@ export class SQL implements SQLWrapper { decoder: DriverValueDecoder = noopDecoder; private shouldInlineParams = false; - constructor(readonly queryChunks: SQLChunk[]) {} + /** @internal */ + usedTables: string[] = []; + + constructor(readonly queryChunks: SQLChunk[]) { + for (const chunk of queryChunks) { + if (is(chunk, Table)) { + const schemaName = chunk[Table.Symbol.Schema]; + + this.usedTables.push( + schemaName === undefined + ? chunk[Table.Symbol.Name] + : schemaName + '.' 
+ chunk[Table.Symbol.Name], + ); + } + } + } append(query: SQL): this { this.queryChunks.push(...query.queryChunks); diff --git a/drizzle-orm/src/sqlite-core/db.ts b/drizzle-orm/src/sqlite-core/db.ts index f5735155fb..9b6a175e81 100644 --- a/drizzle-orm/src/sqlite-core/db.ts +++ b/drizzle-orm/src/sqlite-core/db.ts @@ -1,3 +1,4 @@ +import type { Cache } from '~/cache/core/cache.ts'; import { entityKind } from '~/entity.ts'; import type { TypedQueryBuilder } from '~/query-builders/query-builder.ts'; import type { ExtractTablesWithRelations, RelationalSchemaConfig, TablesRelationalConfig } from '~/relations.ts'; @@ -85,6 +86,7 @@ export class BaseSQLiteDatabase< ) as typeof query[keyof TSchema]; } } + this.$cache = { invalidate: async (_params: any) => {} }; } /** @@ -470,6 +472,8 @@ export class BaseSQLiteDatabase< return new SQLiteUpdateBuilder(table, this.session, this.dialect); } + $cache: { invalidate: Cache['onMutate'] }; + /** * Creates an insert query. * diff --git a/drizzle-orm/src/sqlite-core/query-builders/delete.ts b/drizzle-orm/src/sqlite-core/query-builders/delete.ts index 53e8d6227c..654d640661 100644 --- a/drizzle-orm/src/sqlite-core/query-builders/delete.ts +++ b/drizzle-orm/src/sqlite-core/query-builders/delete.ts @@ -11,6 +11,7 @@ import type { Subquery } from '~/subquery.ts'; import { Table } from '~/table.ts'; import { type DrizzleTypeError, orderSelectedFields, type ValueOrArray } from '~/utils.ts'; import type { SQLiteColumn } from '../columns/common.ts'; +import { extractUsedTable } from '../utils.ts'; import type { SelectedFieldsFlat, SelectedFieldsOrdered } from './select.types.ts'; export type SQLiteDeleteWithout< @@ -268,6 +269,11 @@ export class SQLiteDeleteBase< this.config.returning, this.config.returning ? 
'all' : 'run', true, + undefined, + { + type: 'delete', + tables: extractUsedTable(this.config.table), + }, ) as SQLiteDeletePrepare; } diff --git a/drizzle-orm/src/sqlite-core/query-builders/insert.ts b/drizzle-orm/src/sqlite-core/query-builders/insert.ts index 7609162c3a..96985dc834 100644 --- a/drizzle-orm/src/sqlite-core/query-builders/insert.ts +++ b/drizzle-orm/src/sqlite-core/query-builders/insert.ts @@ -13,6 +13,7 @@ import type { Subquery } from '~/subquery.ts'; import { Columns, Table } from '~/table.ts'; import { type DrizzleTypeError, haveSameKeys, mapUpdateSet, orderSelectedFields, type Simplify } from '~/utils.ts'; import type { AnySQLiteColumn, SQLiteColumn } from '../columns/common.ts'; +import { extractUsedTable } from '../utils.ts'; import { QueryBuilder } from './query-builder.ts'; import type { SelectedFieldsFlat, SelectedFieldsOrdered } from './select.types.ts'; import type { SQLiteUpdateSetSource } from './update.ts'; @@ -381,6 +382,11 @@ export class SQLiteInsertBase< this.config.returning, this.config.returning ? 
'all' : 'run', true, + undefined, + { + type: 'insert', + tables: extractUsedTable(this.config.table), + }, ) as SQLiteInsertPrepare; } diff --git a/drizzle-orm/src/sqlite-core/query-builders/select.ts b/drizzle-orm/src/sqlite-core/query-builders/select.ts index 3b536597db..ce564b4453 100644 --- a/drizzle-orm/src/sqlite-core/query-builders/select.ts +++ b/drizzle-orm/src/sqlite-core/query-builders/select.ts @@ -1,3 +1,4 @@ +import type { CacheConfig, WithCacheConfig } from '~/cache/core/types.ts'; import { entityKind, is } from '~/entity.ts'; import { TypedQueryBuilder } from '~/query-builders/query-builder.ts'; import type { @@ -31,6 +32,7 @@ import { type ValueOrArray, } from '~/utils.ts'; import { ViewBaseConfig } from '~/view-common.ts'; +import { extractUsedTable } from '../utils.ts'; import { SQLiteViewBase } from '../view-base.ts'; import type { AnySQLiteSelect, @@ -152,6 +154,7 @@ export abstract class SQLiteSelectQueryBuilderBase< readonly excludedMethods: TExcludedMethods; readonly result: TResult; readonly selectedFields: TSelectedFields; + readonly config: SQLiteSelectConfig; }; /** @internal */ @@ -161,6 +164,8 @@ export abstract class SQLiteSelectQueryBuilderBase< private isPartialSelect: boolean; protected session: SQLiteSession | undefined; protected dialect: SQLiteDialect; + protected cacheConfig?: WithCacheConfig = undefined; + protected usedTables: Set = new Set(); constructor( { table, fields, isPartialSelect, session, dialect, withList, distinct }: { @@ -186,9 +191,16 @@ export abstract class SQLiteSelectQueryBuilderBase< this.dialect = dialect; this._ = { selectedFields: fields as TSelectedFields, + config: this.config, } as this['_']; this.tableName = getTableLikeName(table); this.joinsNotNullableMap = typeof this.tableName === 'string' ? 
{ [this.tableName]: true } : {}; + for (const item of extractUsedTable(table)) this.usedTables.add(item); + } + + /** @internal */ + getUsedTables() { + return [...this.usedTables]; } private createJoin( @@ -201,6 +213,9 @@ export abstract class SQLiteSelectQueryBuilderBase< const baseTableName = this.tableName; const tableName = getTableLikeName(table); + // store all tables used in a query + for (const item of extractUsedTable(table)) this.usedTables.add(item); + if (typeof tableName === 'string' && this.config.joins?.some((join) => join.alias === tableName)) { throw new Error(`Alias "${tableName}" is already used in this query`); } @@ -809,8 +824,12 @@ export abstract class SQLiteSelectQueryBuilderBase< as( alias: TAlias, ): SubqueryWithSelection { + const usedTables: string[] = []; + usedTables.push(...extractUsedTable(this.config.table)); + if (this.config.joins) { for (const it of this.config.joins) usedTables.push(...extractUsedTable(it.table)); } + return new Proxy( - new Subquery(this.getSQL(), this.config.fields, alias), + new Subquery(this.getSQL(), this.config.fields, alias, false, [...new Set(usedTables)]), new SelectionProxyHandler({ alias, sqlAliasedBehavior: 'alias', sqlBehavior: 'error' }), ) as SubqueryWithSelection; } @@ -896,11 +915,26 @@ export class SQLiteSelectBase< fieldsList, 'all', true, + undefined, + { + type: 'select', + tables: [...this.usedTables], + }, + this.cacheConfig, ); query.joinsNotNullableMap = this.joinsNotNullableMap; return query as ReturnType; } + $withCache(config?: { config?: CacheConfig; tag?: string; autoInvalidate?: boolean } | false) { + this.cacheConfig = config === undefined + ? { config: {}, enable: true, autoInvalidate: true } + : config === false + ? 
{ enable: false } : { enable: true, autoInvalidate: true, ...config }; + return this; + } + prepare(): SQLiteSelectPrepare { return this._prepare(false); } diff --git a/drizzle-orm/src/sqlite-core/query-builders/update.ts b/drizzle-orm/src/sqlite-core/query-builders/update.ts index 6915d60a92..a6b3c4ce3a 100644 --- a/drizzle-orm/src/sqlite-core/query-builders/update.ts +++ b/drizzle-orm/src/sqlite-core/query-builders/update.ts @@ -20,6 +20,7 @@ import { } from '~/utils.ts'; import { ViewBaseConfig } from '~/view-common.ts'; import type { SQLiteColumn } from '../columns/common.ts'; +import { extractUsedTable } from '../utils.ts'; import { SQLiteViewBase } from '../view-base.ts'; import type { SelectedFields, SelectedFieldsOrdered, SQLiteSelectJoinConfig } from './select.types.ts'; @@ -426,6 +427,11 @@ export class SQLiteUpdateBase< this.config.returning, this.config.returning ? 'all' : 'run', true, + undefined, + { + type: 'update', + tables: extractUsedTable(this.config.table), + }, ) as SQLiteUpdatePrepare; } diff --git a/drizzle-orm/src/sqlite-core/session.ts b/drizzle-orm/src/sqlite-core/session.ts index 9e6924ca08..13d002d768 100644 --- a/drizzle-orm/src/sqlite-core/session.ts +++ b/drizzle-orm/src/sqlite-core/session.ts @@ -1,11 +1,13 @@ -import { entityKind } from '~/entity.ts'; +import { type Cache, hashQuery, NoopCache } from '~/cache/core/cache.ts'; +import type { WithCacheConfig } from '~/cache/core/types.ts'; +import { entityKind, is } from '~/entity.ts'; import { DrizzleError, TransactionRollbackError } from '~/errors.ts'; +import { DrizzleQueryError } from '~/errors/index.ts'; +import { QueryPromise } from '~/query-promise.ts'; import type { TablesRelationalConfig } from '~/relations.ts'; import type { PreparedQuery } from '~/session.ts'; import type { Query, SQL } from '~/sql/sql.ts'; import type { SQLiteAsyncDialect, SQLiteSyncDialect } from '~/sqlite-core/dialect.ts'; -// import { QueryPromise } from '../index.ts'; -import { QueryPromise } from 
'~/query-promise.ts'; import { BaseSQLiteDatabase } from './db.ts'; import type { SQLiteRaw } from './query-builders/raw.ts'; import type { SelectedFieldsOrdered } from './query-builders/select.types.ts'; @@ -48,7 +50,110 @@ export abstract class SQLitePreparedQuery impleme private mode: 'sync' | 'async', private executeMethod: SQLiteExecuteMethod, protected query: Query, - ) {} + private cache?: Cache, + // per query related metadata + private queryMetadata?: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + } | undefined, + // config that was passed through $withCache + private cacheConfig?: WithCacheConfig, + ) { + // it means that no $withCache options were passed and it should be just enabled + if (cache && cache.strategy() === 'all' && cacheConfig === undefined) { + this.cacheConfig = { enable: true, autoInvalidate: true }; + } + if (!this.cacheConfig?.enable) { + this.cacheConfig = undefined; + } + } + + /** @internal */ + protected async queryWithCache( + queryString: string, + params: any[], + query: () => Promise, + ): Promise { + if (this.cache === undefined || is(this.cache, NoopCache) || this.queryMetadata === undefined) { + try { + return await query(); + } catch (e) { + throw new DrizzleQueryError(queryString, params, e as Error); + } + } + + // don't do any mutations, if globally is false + if (this.cacheConfig && !this.cacheConfig.enable) { + try { + return await query(); + } catch (e) { + throw new DrizzleQueryError(queryString, params, e as Error); + } + } + + // For mutate queries, we should query the database, wait for a response, and then perform invalidation + if ( + ( + this.queryMetadata.type === 'insert' || this.queryMetadata.type === 'update' + || this.queryMetadata.type === 'delete' + ) && this.queryMetadata.tables.length > 0 + ) { + try { + const [res] = await Promise.all([ + query(), + this.cache.onMutate({ tables: this.queryMetadata.tables }), + ]); + return res; + } catch (e) { + throw new 
DrizzleQueryError(queryString, params, e as Error); + } + } + + // don't do any reads if globally disabled + if (!this.cacheConfig) { + try { + return await query(); + } catch (e) { + throw new DrizzleQueryError(queryString, params, e as Error); + } + } + + if (this.queryMetadata.type === 'select') { + const fromCache = await this.cache.get( + this.cacheConfig.tag ?? await hashQuery(queryString, params), + this.queryMetadata.tables, + this.cacheConfig.tag !== undefined, + this.cacheConfig.autoInvalidate, + ); + if (fromCache === undefined) { + let result; + try { + result = await query(); + } catch (e) { + throw new DrizzleQueryError(queryString, params, e as Error); + } + + // put actual key + await this.cache.put( + this.cacheConfig.tag ?? await hashQuery(queryString, params), + result, + // make sure we send tables that were used in a query only if user wants to invalidate it on each write + this.cacheConfig.autoInvalidate ? this.queryMetadata.tables : [], + this.cacheConfig.tag !== undefined, + this.cacheConfig.config, + ); + // put flag if we should invalidate or not + return result; + } + + return fromCache as unknown as T; + } + try { + return await query(); + } catch (e) { + throw new DrizzleQueryError(queryString, params, e as Error); + } + } getQuery(): Query { return this.query; @@ -124,6 +229,11 @@ export abstract class SQLiteSession< executeMethod: SQLiteExecuteMethod, isResponseInArrayMode: boolean, customResultMapper?: (rows: unknown[][], mapColumnValue?: (value: unknown) => unknown) => unknown, + queryMetadata?: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + }, + cacheConfig?: WithCacheConfig, ): SQLitePreparedQuery; prepareOneTimeQuery( @@ -131,8 +241,22 @@ export abstract class SQLiteSession< fields: SelectedFieldsOrdered | undefined, executeMethod: SQLiteExecuteMethod, isResponseInArrayMode: boolean, + customResultMapper?: (rows: unknown[][], mapColumnValue?: (value: unknown) => unknown) => unknown, + queryMetadata?: { 
+ type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + }, + cacheConfig?: WithCacheConfig, ): SQLitePreparedQuery { - return this.prepareQuery(query, fields, executeMethod, isResponseInArrayMode); + return this.prepareQuery( + query, + fields, + executeMethod, + isResponseInArrayMode, + customResultMapper, + queryMetadata, + cacheConfig, + ); } abstract transaction( diff --git a/drizzle-orm/src/sqlite-core/utils.ts b/drizzle-orm/src/sqlite-core/utils.ts index 7d21483b08..361456c029 100644 --- a/drizzle-orm/src/sqlite-core/utils.ts +++ b/drizzle-orm/src/sqlite-core/utils.ts @@ -1,4 +1,6 @@ import { is } from '~/entity.ts'; +import { SQL } from '~/sql/sql.ts'; +import { Subquery } from '~/subquery.ts'; import { Table } from '~/table.ts'; import { ViewBaseConfig } from '~/view-common.ts'; import type { Check } from './checks.ts'; @@ -11,6 +13,7 @@ import type { PrimaryKey } from './primary-keys.ts'; import { PrimaryKeyBuilder } from './primary-keys.ts'; import { SQLiteTable } from './table.ts'; import { type UniqueConstraint, UniqueConstraintBuilder } from './unique-constraint.ts'; +import type { SQLiteViewBase } from './view-base.ts'; import type { SQLiteView } from './view.ts'; export function getTableConfig(table: TTable) { @@ -53,6 +56,19 @@ export function getTableConfig(table: TTable) { }; } +export function extractUsedTable(table: SQLiteTable | Subquery | SQLiteViewBase | SQL): string[] { + if (is(table, SQLiteTable)) { + return [`${table[Table.Symbol.BaseName]}`]; + } + if (is(table, Subquery)) { + return table._.usedTables ?? []; + } + if (is(table, SQL)) { + return table.usedTables ?? 
[]; + } + return []; +} + export type OnConflict = 'rollback' | 'abort' | 'fail' | 'ignore' | 'replace'; export function getViewConfig< diff --git a/drizzle-orm/src/sqlite-proxy/driver.ts b/drizzle-orm/src/sqlite-proxy/driver.ts index e11e977c1a..2e8663a30a 100644 --- a/drizzle-orm/src/sqlite-proxy/driver.ts +++ b/drizzle-orm/src/sqlite-proxy/driver.ts @@ -57,6 +57,7 @@ export function drizzle = Record { const dialect = new SQLiteAsyncDialect({ casing: config?.casing }); let logger; + let cache; let _batchCallback: AsyncBatchRemoteCallback | undefined; let _config: DrizzleConfig = {}; @@ -73,6 +74,7 @@ export function drizzle = Record = Record; + const session = new SQLiteRemoteSession(callback, dialect, schema, _batchCallback, { logger, cache }); + const db = new SqliteRemoteDatabase('async', dialect, session, schema) as SqliteRemoteDatabase; + ( db).$cache = cache; + if (( db).$cache) { + ( db).$cache['invalidate'] = cache?.onMutate; + } + return db; } diff --git a/drizzle-orm/src/sqlite-proxy/session.ts b/drizzle-orm/src/sqlite-proxy/session.ts index 93d277d694..bdf94cc6cc 100644 --- a/drizzle-orm/src/sqlite-proxy/session.ts +++ b/drizzle-orm/src/sqlite-proxy/session.ts @@ -1,4 +1,6 @@ import type { BatchItem } from '~/batch.ts'; +import { type Cache, NoopCache } from '~/cache/core/index.ts'; +import type { WithCacheConfig } from '~/cache/core/types.ts'; import { entityKind } from '~/entity.ts'; import type { Logger } from '~/logger.ts'; import { NoopLogger } from '~/logger.ts'; @@ -19,6 +21,7 @@ import type { AsyncBatchRemoteCallback, AsyncRemoteCallback, RemoteCallback, Sql export interface SQLiteRemoteSessionOptions { logger?: Logger; + cache?: Cache; } export type PreparedQueryConfig = Omit; @@ -30,6 +33,7 @@ export class SQLiteRemoteSession< static override readonly [entityKind]: string = 'SQLiteRemoteSession'; private logger: Logger; + private cache: Cache; constructor( private client: RemoteCallback, @@ -40,6 +44,7 @@ export class SQLiteRemoteSession< ) { 
super(dialect); this.logger = options.logger ?? new NoopLogger(); + this.cache = options.cache ?? new NoopCache(); } prepareQuery>( @@ -48,11 +53,19 @@ export class SQLiteRemoteSession< executeMethod: SQLiteExecuteMethod, isResponseInArrayMode: boolean, customResultMapper?: (rows: unknown[][]) => unknown, + queryMetadata?: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + }, + cacheConfig?: WithCacheConfig, ): RemotePreparedQuery { return new RemotePreparedQuery( this.client, query, this.logger, + this.cache, + queryMetadata, + cacheConfig, fields, executeMethod, isResponseInArrayMode, @@ -138,6 +151,12 @@ export class RemotePreparedQuery unknown, ) => unknown, ) { - super('async', executeMethod, query); + super('async', executeMethod, query, cache, queryMetadata, cacheConfig); this.customResultMapper = customResultMapper; this.method = executeMethod; } @@ -155,12 +174,12 @@ export class RemotePreparedQuery): Promise { + async run(placeholderValues?: Record): Promise { const params = fillPlaceholders(this.query.params, placeholderValues ?? {}); this.logger.logQuery(this.query.sql, params); - return (this.client as AsyncRemoteCallback)(this.query.sql, params, 'run') as Promise< - SqliteRemoteResult - >; + return await this.queryWithCache(this.query.sql, params, async () => { + return await (this.client as AsyncRemoteCallback)(this.query.sql, params, 'run'); + }); } override mapAllResult(rows: unknown, isFromBatch?: boolean): unknown { @@ -191,7 +210,9 @@ export class RemotePreparedQuery { + return await (client as AsyncRemoteCallback)(query.sql, params, 'all'); + }); return this.mapAllResult(rows); } @@ -201,7 +222,9 @@ export class RemotePreparedQuery { + return await (client as AsyncRemoteCallback)(query.sql, params, 'get'); + }); return this.mapGetResult(clientResult.rows); } @@ -235,7 +258,9 @@ export class RemotePreparedQuery(placeholderValues?: Record): Promise { const params = fillPlaceholders(this.query.params, placeholderValues ?? 
{}); this.logger.logQuery(this.query.sql, params); - const clientResult = await (this.client as AsyncRemoteCallback)(this.query.sql, params, 'values'); + const clientResult = await this.queryWithCache(this.query.sql, params, async () => { + return await (this.client as AsyncRemoteCallback)(this.query.sql, params, 'values'); + }); return clientResult.rows as T[]; } diff --git a/drizzle-orm/src/subquery.ts b/drizzle-orm/src/subquery.ts index c2303cc710..6e2393f4e3 100644 --- a/drizzle-orm/src/subquery.ts +++ b/drizzle-orm/src/subquery.ts @@ -21,15 +21,17 @@ export class Subquery< selectedFields: TSelectedFields; alias: TAlias; isWith: boolean; + usedTables?: string[]; }; - constructor(sql: SQL, selection: Record, alias: string, isWith = false) { + constructor(sql: SQL, fields: TSelectedFields, alias: string, isWith = false, usedTables: string[] = []) { this._ = { brand: 'Subquery', sql, - selectedFields: selection as TSelectedFields, + selectedFields: fields as TSelectedFields, alias: alias as TAlias, isWith, + usedTables, }; } diff --git a/drizzle-orm/src/tidb-serverless/driver.ts b/drizzle-orm/src/tidb-serverless/driver.ts index 69b9a0e44f..4ec77c6983 100644 --- a/drizzle-orm/src/tidb-serverless/driver.ts +++ b/drizzle-orm/src/tidb-serverless/driver.ts @@ -16,6 +16,7 @@ import { TiDBServerlessSession } from './session.ts'; export interface TiDBServerlessSDriverOptions { logger?: Logger; + cache?: Cache; } export class TiDBServerlessDatabase< @@ -51,9 +52,13 @@ function construct = Record; ( db).$client = client; + ( db).$cache = config.cache; + if (( db).$cache) { + ( db).$cache['invalidate'] = config.cache?.onMutate; + } return db as any; } diff --git a/drizzle-orm/src/tidb-serverless/session.ts b/drizzle-orm/src/tidb-serverless/session.ts index 279c60f3b9..bf555a5c7f 100644 --- a/drizzle-orm/src/tidb-serverless/session.ts +++ b/drizzle-orm/src/tidb-serverless/session.ts @@ -1,4 +1,6 @@ import type { Connection, ExecuteOptions, FullResult, Tx } from 
'@tidbcloud/serverless'; +import { type Cache, NoopCache } from '~/cache/core/index.ts'; +import type { WithCacheConfig } from '~/cache/core/types.ts'; import { Column } from '~/column.ts'; import { entityKind, is } from '~/entity.ts'; @@ -29,6 +31,12 @@ export class TiDBServerlessPreparedQuery ext private queryString: string, private params: unknown[], private logger: Logger, + cache: Cache, + queryMetadata: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + } | undefined, + cacheConfig: WithCacheConfig | undefined, private fields: SelectedFieldsOrdered | undefined, private customResultMapper?: (rows: unknown[][]) => T['execute'], // Keys that were used in $default and the value that was generated for them @@ -36,7 +44,7 @@ export class TiDBServerlessPreparedQuery ext // Keys that should be returned, it has the column with all properries + key from object private returningIds?: SelectedFieldsOrdered, ) { - super(); + super(cache, queryMetadata, cacheConfig); } async execute(placeholderValues: Record | undefined = {}): Promise { @@ -46,7 +54,9 @@ export class TiDBServerlessPreparedQuery ext const { fields, client, queryString, joinsNotNullableMap, customResultMapper, returningIds, generatedIds } = this; if (!fields && !customResultMapper) { - const res = await client.execute(queryString, params, executeRawConfig) as FullResult; + const res = await this.queryWithCache(queryString, params, async () => { + return await client.execute(queryString, params, executeRawConfig) as FullResult; + }); const insertId = res.lastInsertId ?? 0; const affectedRows = res.rowsAffected ?? 
0; // for each row, I need to check keys from @@ -75,7 +85,9 @@ export class TiDBServerlessPreparedQuery ext return res; } - const rows = await client.execute(queryString, params, queryConfig) as unknown[][]; + const rows = await this.queryWithCache(queryString, params, async () => { + return await client.execute(queryString, params, queryConfig) as unknown[][]; + }); if (customResultMapper) { return customResultMapper(rows); @@ -91,6 +103,7 @@ export class TiDBServerlessPreparedQuery ext export interface TiDBServerlessSessionOptions { logger?: Logger; + cache?: Cache; } export class TiDBServerlessSession< @@ -101,6 +114,7 @@ export class TiDBServerlessSession< private logger: Logger; private client: Tx | Connection; + private cache: Cache; constructor( private baseClient: Connection, @@ -112,6 +126,7 @@ export class TiDBServerlessSession< super(dialect); this.client = tx ?? baseClient; this.logger = options.logger ?? new NoopLogger(); + this.cache = options.cache ?? new NoopCache(); } prepareQuery( @@ -120,12 +135,20 @@ export class TiDBServerlessSession< customResultMapper?: (rows: unknown[][]) => T['execute'], generatedIds?: Record[], returningIds?: SelectedFieldsOrdered, + queryMetadata?: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + }, + cacheConfig?: WithCacheConfig, ): MySqlPreparedQuery { return new TiDBServerlessPreparedQuery( this.client, query.sql, query.params, this.logger, + this.cache, + queryMetadata, + cacheConfig, fields, customResultMapper, generatedIds, diff --git a/drizzle-orm/src/utils.ts b/drizzle-orm/src/utils.ts index 4645c5517c..fb17dd6a85 100644 --- a/drizzle-orm/src/utils.ts +++ b/drizzle-orm/src/utils.ts @@ -1,3 +1,4 @@ +import type { Cache } from './cache/core/cache.ts'; import type { AnyColumn } from './column.ts'; import { Column } from './column.ts'; import { is } from './entity.ts'; @@ -219,6 +220,7 @@ export interface DrizzleConfig = Record< logger?: boolean | Logger; schema?: TSchema; casing?: Casing; 
+ cache?: Cache; } export type ValidateShape = T extends ValidShape ? Exclude extends never ? TResult diff --git a/drizzle-orm/src/vercel-postgres/driver.ts b/drizzle-orm/src/vercel-postgres/driver.ts index 3a77887833..c7f18c7739 100644 --- a/drizzle-orm/src/vercel-postgres/driver.ts +++ b/drizzle-orm/src/vercel-postgres/driver.ts @@ -1,4 +1,5 @@ import { sql } from '@vercel/postgres'; +import type { Cache } from '~/cache/core/cache.ts'; import { entityKind } from '~/entity.ts'; import type { Logger } from '~/logger.ts'; import { DefaultLogger } from '~/logger.ts'; @@ -15,6 +16,7 @@ import { type VercelPgClient, type VercelPgQueryResultHKT, VercelPgSession } fro export interface VercelPgDriverOptions { logger?: Logger; + cache?: Cache; } export class VercelPgDriver { @@ -30,7 +32,10 @@ export class VercelPgDriver { createSession( schema: RelationalSchemaConfig | undefined, ): VercelPgSession, TablesRelationalConfig> { - return new VercelPgSession(this.client, this.dialect, schema, { logger: this.options.logger }); + return new VercelPgSession(this.client, this.dialect, schema, { + logger: this.options.logger, + cache: this.options.cache, + }); } } @@ -67,10 +72,14 @@ function construct = Record; ( db).$client = client; + ( db).$cache = config.cache; + if (( db).$cache) { + ( db).$cache['invalidate'] = config.cache?.onMutate; + } return db as any; } diff --git a/drizzle-orm/src/vercel-postgres/session.ts b/drizzle-orm/src/vercel-postgres/session.ts index 547e3b4cff..c20c0b223a 100644 --- a/drizzle-orm/src/vercel-postgres/session.ts +++ b/drizzle-orm/src/vercel-postgres/session.ts @@ -8,6 +8,9 @@ import { VercelPool, type VercelPoolClient, } from '@vercel/postgres'; +import type { Cache } from '~/cache/core/cache.ts'; +import { NoopCache } from '~/cache/core/cache.ts'; +import type { WithCacheConfig } from '~/cache/core/types.ts'; import { entityKind } from '~/entity.ts'; import { type Logger, NoopLogger } from '~/logger.ts'; import { type PgDialect, PgTransaction } 
from '~/pg-core/index.ts'; @@ -31,12 +34,18 @@ export class VercelPgPreparedQuery extends PgPrep queryString: string, private params: unknown[], private logger: Logger, + cache: Cache, + queryMetadata: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + } | undefined, + cacheConfig: WithCacheConfig | undefined, private fields: SelectedFieldsOrdered | undefined, name: string | undefined, private _isResponseInArrayMode: boolean, private customResultMapper?: (rows: unknown[][]) => T['execute'], ) { - super({ sql: queryString, params }); + super({ sql: queryString, params }, cache, queryMetadata, cacheConfig); this.rawQuery = { name, text: queryString, @@ -133,10 +142,14 @@ export class VercelPgPreparedQuery extends PgPrep const { fields, rawQuery, client, queryConfig: query, joinsNotNullableMap, customResultMapper } = this; if (!fields && !customResultMapper) { - return client.query(rawQuery, params); + return this.queryWithCache(rawQuery.text, params, async () => { + return await client.query(rawQuery, params); + }); } - const { rows } = await client.query(query, params); + const { rows } = await this.queryWithCache(query.text, params, async () => { + return await client.query(query, params); + }); if (customResultMapper) { return customResultMapper(rows); @@ -148,13 +161,17 @@ export class VercelPgPreparedQuery extends PgPrep all(placeholderValues: Record | undefined = {}): Promise { const params = fillPlaceholders(this.params, placeholderValues); this.logger.logQuery(this.rawQuery.text, params); - return this.client.query(this.rawQuery, params).then((result) => result.rows); + return this.queryWithCache(this.rawQuery.text, params, async () => { + return await this.client.query(this.rawQuery, params); + }).then((result) => result.rows); } values(placeholderValues: Record | undefined = {}): Promise { const params = fillPlaceholders(this.params, placeholderValues); this.logger.logQuery(this.rawQuery.text, params); - return 
this.client.query(this.queryConfig, params).then((result) => result.rows); + return this.queryWithCache(this.queryConfig.text, params, async () => { + return await this.client.query(this.queryConfig, params); + }).then((result) => result.rows); } /** @internal */ @@ -165,6 +182,7 @@ export class VercelPgPreparedQuery extends PgPrep export interface VercelPgSessionOptions { logger?: Logger; + cache?: Cache; } export class VercelPgSession< @@ -174,6 +192,7 @@ export class VercelPgSession< static override readonly [entityKind]: string = 'VercelPgSession'; private logger: Logger; + private cache: Cache; constructor( private client: VercelPgClient, @@ -183,6 +202,7 @@ export class VercelPgSession< ) { super(dialect); this.logger = options.logger ?? new NoopLogger(); + this.cache = options.cache ?? new NoopCache(); } prepareQuery( @@ -191,12 +211,20 @@ export class VercelPgSession< name: string | undefined, isResponseInArrayMode: boolean, customResultMapper?: (rows: unknown[][]) => T['execute'], + queryMetadata?: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + }, + cacheConfig?: WithCacheConfig, ): PgPreparedQuery { return new VercelPgPreparedQuery( this.client, query.sql, query.params, this.logger, + this.cache, + queryMetadata, + cacheConfig, fields, name, isResponseInArrayMode, diff --git a/drizzle-orm/src/xata-http/driver.ts b/drizzle-orm/src/xata-http/driver.ts index ce275a88d3..e878ae6a98 100644 --- a/drizzle-orm/src/xata-http/driver.ts +++ b/drizzle-orm/src/xata-http/driver.ts @@ -1,3 +1,4 @@ +import type { Cache } from '~/cache/core/cache.ts'; import { entityKind } from '~/entity.ts'; import type { Logger } from '~/logger.ts'; import { DefaultLogger } from '~/logger.ts'; @@ -11,6 +12,7 @@ import { XataHttpSession } from './session.ts'; export interface XataDriverOptions { logger?: Logger; + cache?: Cache; } export class XataHttpDriver { @@ -29,6 +31,7 @@ export class XataHttpDriver { ): XataHttpSession, TablesRelationalConfig> { return 
new XataHttpSession(this.client, this.dialect, schema, { logger: this.options.logger, + cache: this.options.cache, }); } @@ -70,7 +73,7 @@ export function drizzle = Record = Record> | undefined, ); ( db).$client = client; + ( db).$cache = config.cache; + if (( db).$cache) { + ( db).$cache['invalidate'] = config.cache?.onMutate; + } return db as any; } diff --git a/drizzle-orm/src/xata-http/session.ts b/drizzle-orm/src/xata-http/session.ts index df4cc10030..34806be99e 100644 --- a/drizzle-orm/src/xata-http/session.ts +++ b/drizzle-orm/src/xata-http/session.ts @@ -1,4 +1,7 @@ import type { SQLPluginResult, SQLQueryResult } from '@xata.io/client'; +import type { Cache } from '~/cache/core/index.ts'; +import { NoopCache } from '~/cache/core/index.ts'; +import type { WithCacheConfig } from '~/cache/core/types.ts'; import { entityKind } from '~/entity.ts'; import type { Logger } from '~/logger.ts'; import { NoopLogger } from '~/logger.ts'; @@ -28,11 +31,17 @@ export class XataHttpPreparedQuery extends PgPrep private client: XataHttpClient, query: Query, private logger: Logger, + cache: Cache, + queryMetadata: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + } | undefined, + cacheConfig: WithCacheConfig | undefined, private fields: SelectedFieldsOrdered | undefined, private _isResponseInArrayMode: boolean, private customResultMapper?: (rows: unknown[][]) => T['execute'], ) { - super(query); + super(query, cache, queryMetadata, cacheConfig); } async execute(placeholderValues: Record | undefined = {}): Promise { @@ -43,11 +52,15 @@ export class XataHttpPreparedQuery extends PgPrep const { fields, client, query, customResultMapper, joinsNotNullableMap } = this; if (!fields && !customResultMapper) { - return await client.sql>({ statement: query.sql, params }); - // return { rowCount: result.records.length, rows: result.records, rowAsArray: false }; + return this.queryWithCache(query.sql, params, async () => { + return await client.sql>({ statement: 
query.sql, params }); + }); } - const { rows, warning } = await client.sql({ statement: query.sql, params, responseType: 'array' }); + const { rows, warning } = await this.queryWithCache(query.sql, params, async () => { + return await client.sql({ statement: query.sql, params, responseType: 'array' }); + }); + if (warning) console.warn(warning); return customResultMapper @@ -58,13 +71,17 @@ export class XataHttpPreparedQuery extends PgPrep all(placeholderValues: Record | undefined = {}): Promise { const params = fillPlaceholders(this.query.params, placeholderValues); this.logger.logQuery(this.query.sql, params); - return this.client.sql({ statement: this.query.sql, params, responseType: 'array' }).then((result) => result.rows); + return this.queryWithCache(this.query.sql, params, async () => { + return this.client.sql({ statement: this.query.sql, params, responseType: 'array' }); + }).then((result) => result.rows); } values(placeholderValues: Record | undefined = {}): Promise { const params = fillPlaceholders(this.query.params, placeholderValues); this.logger.logQuery(this.query.sql, params); - return this.client.sql({ statement: this.query.sql, params }).then((result) => result.records); + return this.queryWithCache(this.query.sql, params, async () => { + return this.client.sql({ statement: this.query.sql, params }); + }).then((result) => result.records); } /** @internal */ @@ -75,6 +92,7 @@ export class XataHttpPreparedQuery extends PgPrep export interface XataHttpSessionOptions { logger?: Logger; + cache?: Cache; } export class XataHttpSession, TSchema extends TablesRelationalConfig> @@ -87,6 +105,7 @@ export class XataHttpSession, TSchem static override readonly [entityKind]: string = 'XataHttpSession'; private logger: Logger; + private cache: Cache; constructor( private client: XataHttpClient, @@ -96,6 +115,7 @@ export class XataHttpSession, TSchem ) { super(dialect); this.logger = options.logger ?? new NoopLogger(); + this.cache = options.cache ?? 
new NoopCache(); } prepareQuery( @@ -104,11 +124,19 @@ export class XataHttpSession, TSchem name: string | undefined, isResponseInArrayMode: boolean, customResultMapper?: (rows: unknown[][]) => T['execute'], + queryMetadata?: { + type: 'select' | 'update' | 'delete' | 'insert'; + tables: string[]; + }, + cacheConfig?: WithCacheConfig, ): PgPreparedQuery { return new XataHttpPreparedQuery( this.client, query, this.logger, + this.cache, + queryMetadata, + cacheConfig, fields, isResponseInArrayMode, customResultMapper, diff --git a/integration-tests/package.json b/integration-tests/package.json index 4508712113..13902a358b 100644 --- a/integration-tests/package.json +++ b/integration-tests/package.json @@ -28,9 +28,11 @@ "@types/sql.js": "^1.4.4", "@types/uuid": "^9.0.1", "@types/ws": "^8.5.10", + "@upstash/redis": "^1.34.3", "@vitest/ui": "^1.6.0", "ava": "^5.3.0", "cross-env": "^7.0.3", + "keyv": "^5.2.3", "import-in-the-middle": "^1.13.1", "ts-node": "^10.9.2", "tsx": "^4.14.0", diff --git a/integration-tests/tests/gel/cache.ts b/integration-tests/tests/gel/cache.ts new file mode 100644 index 0000000000..b8436aaaae --- /dev/null +++ b/integration-tests/tests/gel/cache.ts @@ -0,0 +1,73 @@ +import { getTableName, is, Table } from 'drizzle-orm'; +import type { MutationOption } from 'drizzle-orm/cache/core'; +import { Cache } from 'drizzle-orm/cache/core'; +import type { CacheConfig } from 'drizzle-orm/cache/core/types'; +import Keyv from 'keyv'; + +// eslint-disable-next-line drizzle-internal/require-entity-kind +export class TestGlobalCache extends Cache { + private globalTtl: number = 1000; + private usedTablesPerKey: Record = {}; + + constructor(private kv: Keyv = new Keyv()) { + super(); + } + + override strategy(): 'explicit' | 'all' { + return 'all'; + } + override async get(key: string, _tables: string[], _isTag: boolean): Promise { + const res = await this.kv.get(key) ?? 
undefined; + return res; + } + override async put( + key: string, + response: any, + tables: string[], + isTag: boolean, + config?: CacheConfig, + ): Promise { + await this.kv.set(key, response, config ? config.ex : this.globalTtl); + for (const table of tables) { + const keys = this.usedTablesPerKey[table]; + if (keys === undefined) { + this.usedTablesPerKey[table] = [key]; + } else { + keys.push(key); + } + } + } + override async onMutate(params: MutationOption): Promise { + const tagsArray = params.tags ? Array.isArray(params.tags) ? params.tags : [params.tags] : []; + const tablesArray = params.tables ? Array.isArray(params.tables) ? params.tables : [params.tables] : []; + + const keysToDelete = new Set(); + + for (const table of tablesArray) { + const tableName = is(table, Table) ? getTableName(table) : table as string; + const keys = this.usedTablesPerKey[tableName] ?? []; + for (const key of keys) keysToDelete.add(key); + } + + if (keysToDelete.size > 0 || tagsArray.length > 0) { + for (const tag of tagsArray) { + await this.kv.delete(tag); + } + + for (const key of keysToDelete) { + await this.kv.delete(key); + for (const table of tablesArray) { + const tableName = is(table, Table) ? 
getTableName(table) : table as string; + this.usedTablesPerKey[tableName] = []; + } + } + } + } +} + +// eslint-disable-next-line drizzle-internal/require-entity-kind +export class TestCache extends TestGlobalCache { + override strategy(): 'explicit' | 'all' { + return 'explicit'; + } +} diff --git a/integration-tests/tests/gel/gel.test.ts b/integration-tests/tests/gel/gel.test.ts index 31d6cfa53b..858a1d21ce 100644 --- a/integration-tests/tests/gel/gel.test.ts +++ b/integration-tests/tests/gel/gel.test.ts @@ -77,9 +77,10 @@ import createClient, { RelativeDuration, } from 'gel'; import { v4 as uuidV4 } from 'uuid'; -import { afterAll, afterEach, beforeAll, beforeEach, describe, expect, test } from 'vitest'; +import { afterAll, afterEach, beforeAll, beforeEach, describe, expect, test, vi } from 'vitest'; import { Expect } from '~/utils'; import 'zx/globals'; +import { TestCache, TestGlobalCache } from './cache'; import { createDockerDB } from './createInstance'; $.quiet = true; @@ -88,6 +89,8 @@ const ENABLE_LOGGING = false; let client: Client; let db: GelJsDatabase; +let dbGlobalCached: GelJsDatabase; +let cachedDb: GelJsDatabase; const tlsSecurity: string = 'insecure'; let dsn: string; let container: Docker.Container | undefined; @@ -101,6 +104,10 @@ declare module 'vitest' { gel: { db: GelJsDatabase; }; + cachedGel: { + db: GelJsDatabase; + dbGlobalCached: GelJsDatabase; + }; } } @@ -112,6 +119,12 @@ const usersTable = gelTable('users', { createdAt: timestamptz('created_at').notNull().defaultNow(), }); +const postsTable = gelTable('posts', { + id: integer().primaryKey(), + description: text().notNull(), + userId: integer('city_id').references(() => usersTable.id1), +}); + const usersOnUpdate = gelTable('users_on_update', { id1: integer('id1').notNull(), name: text('name').notNull(), @@ -222,6 +235,14 @@ beforeAll(async () => { }, }); db = drizzle(client, { logger: ENABLE_LOGGING }); + cachedDb = drizzle(client, { + logger: ENABLE_LOGGING, + cache: new 
TestCache(), + }); + dbGlobalCached = drizzle(client, { + logger: ENABLE_LOGGING, + cache: new TestGlobalCache(), + }); dsn = connectionString; }); @@ -235,9 +256,17 @@ beforeEach((ctx) => { ctx.gel = { db, }; + ctx.cachedGel = { + db: cachedDb, + dbGlobalCached, + }; }); describe('some', async () => { + beforeEach(async (ctx) => { + await ctx.cachedGel.db.$cache?.invalidate({ tables: 'users' }); + await ctx.cachedGel.dbGlobalCached.$cache?.invalidate({ tables: 'users' }); + }); beforeAll(async () => { await $`gel query "CREATE TYPE default::users { create property id1: int16 { @@ -4976,4 +5005,250 @@ describe('some', async () => { ); expect(inserted).toEqual([{ id1: 1, name: 'John' }]); }); + + test('test force invalidate', async (ctx) => { + const { db } = ctx.cachedGel; + + const spyInvalidate = vi.spyOn(db.$cache, 'invalidate'); + await db.$cache?.invalidate({ tables: 'users' }); + expect(spyInvalidate).toHaveBeenCalledTimes(1); + }); + + test('default global config - no cache should be hit', async (ctx) => { + const { db } = ctx.cachedGel; + + // @ts-expect-error + const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable); + + expect(spyPut).toHaveBeenCalledTimes(0); + expect(spyGet).toHaveBeenCalledTimes(0); + expect(spyInvalidate).toHaveBeenCalledTimes(0); + }); + + test('default global config + enable cache on select: get, put', async (ctx) => { + const { db } = ctx.cachedGel; + + // @ts-expect-error + const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable).$withCache(); + + expect(spyPut).toHaveBeenCalledTimes(1); + expect(spyGet).toHaveBeenCalledTimes(1); + expect(spyInvalidate).toHaveBeenCalledTimes(0); + }); + + 
test('default global config + enable cache on select + write: get, put, onMutate', async (ctx) => { + const { db } = ctx.cachedGel; + + // @ts-expect-error + const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable).$withCache({ config: { ex: 1 } }); + + expect(spyPut).toHaveBeenCalledTimes(1); + expect(spyGet).toHaveBeenCalledTimes(1); + expect(spyInvalidate).toHaveBeenCalledTimes(0); + + spyPut.mockClear(); + spyGet.mockClear(); + spyInvalidate.mockClear(); + + await db.insert(usersTable).values({ id1: 1, name: 'John' }); + + expect(spyPut).toHaveBeenCalledTimes(0); + expect(spyGet).toHaveBeenCalledTimes(0); + expect(spyInvalidate).toHaveBeenCalledTimes(1); + }); + + test('default global config + enable cache on select + disable invalidate: get, put', async (ctx) => { + const { db } = ctx.cachedGel; + + // @ts-expect-error + const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable).$withCache({ tag: 'custom', autoInvalidate: false, config: { ex: 1 } }); + + expect(spyPut).toHaveBeenCalledTimes(1); + expect(spyGet).toHaveBeenCalledTimes(1); + expect(spyInvalidate).toHaveBeenCalledTimes(0); + + await db.insert(usersTable).values({ id1: 1, name: 'John' }); + + // invalidate force + await db.$cache?.invalidate({ tags: ['custom'] }); + }); + + test('global: true + disable cache', async (ctx) => { + const { dbGlobalCached: db } = ctx.cachedGel; + + // @ts-expect-error + const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable).$withCache(false); + + 
expect(spyPut).toHaveBeenCalledTimes(0); + expect(spyGet).toHaveBeenCalledTimes(0); + expect(spyInvalidate).toHaveBeenCalledTimes(0); + }); + + test('global: true - cache should be hit', async (ctx) => { + const { dbGlobalCached: db } = ctx.cachedGel; + + // @ts-expect-error + const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable); + + expect(spyPut).toHaveBeenCalledTimes(1); + expect(spyGet).toHaveBeenCalledTimes(1); + expect(spyInvalidate).toHaveBeenCalledTimes(0); + }); + + test('global: true - cache: false on select - no cache hit', async (ctx) => { + const { dbGlobalCached: db } = ctx.cachedGel; + + // @ts-expect-error + const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable).$withCache(false); + + expect(spyPut).toHaveBeenCalledTimes(0); + expect(spyGet).toHaveBeenCalledTimes(0); + expect(spyInvalidate).toHaveBeenCalledTimes(0); + }); + + test('global: true - disable invalidate - cache hit + no invalidate', async (ctx) => { + const { dbGlobalCached: db } = ctx.cachedGel; + + // @ts-expect-error + const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable).$withCache({ autoInvalidate: false }); + + expect(spyPut).toHaveBeenCalledTimes(1); + expect(spyGet).toHaveBeenCalledTimes(1); + expect(spyInvalidate).toHaveBeenCalledTimes(0); + + spyPut.mockClear(); + spyGet.mockClear(); + spyInvalidate.mockClear(); + + await db.insert(usersTable).values({ id1: 1, name: 'John' }); + + expect(spyPut).toHaveBeenCalledTimes(0); + expect(spyGet).toHaveBeenCalledTimes(0); + 
expect(spyInvalidate).toHaveBeenCalledTimes(1); + }); + + test('global: true - with custom tag', async (ctx) => { + const { dbGlobalCached: db } = ctx.cachedGel; + + // @ts-expect-error + const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable).$withCache({ tag: 'custom', autoInvalidate: false }); + + expect(spyPut).toHaveBeenCalledTimes(1); + expect(spyGet).toHaveBeenCalledTimes(1); + expect(spyInvalidate).toHaveBeenCalledTimes(0); + + await db.insert(usersTable).values({ id1: 1, name: 'John' }); + + // invalidate force + await db.$cache?.invalidate({ tags: ['custom'] }); + }); + + // check select used tables + test('check simple select used tables', (ctx) => { + const { db } = ctx.cachedGel; + + // @ts-expect-error + expect(db.select().from(usersTable).getUsedTables()).toStrictEqual(['users']); + // @ts-expect-error + expect(db.select().from(sql`${usersTable}`).getUsedTables()).toStrictEqual(['users']); + }); + // check select+join used tables + test('select+join', (ctx) => { + const { db } = ctx.cachedGel; + + // @ts-expect-error + expect(db.select().from(usersTable).leftJoin(postsTable, eq(usersTable.id, postsTable.userId)).getUsedTables()) + .toStrictEqual(['users', 'posts']); + expect( + // @ts-expect-error + db.select().from(sql`${usersTable}`).leftJoin(postsTable, eq(usersTable.id, postsTable.userId)).getUsedTables(), + ).toStrictEqual(['users', 'posts']); + }); + // check select+2join used tables + test('select+2joins', (ctx) => { + const { db } = ctx.cachedGel; + + expect( + db.select().from(usersTable).leftJoin( + postsTable, + eq(usersTable.id1, postsTable.userId), + ).leftJoin( + alias(postsTable, 'post2'), + eq(usersTable.id1, postsTable.userId), + ) + // @ts-expect-error + .getUsedTables(), + ) + .toStrictEqual(['users', 'posts']); + expect( + 
db.select().from(sql`${usersTable}`).leftJoin(postsTable, eq(usersTable.id1, postsTable.userId)).leftJoin( + alias(postsTable, 'post2'), + eq(usersTable.id1, postsTable.userId), + // @ts-expect-error + ).getUsedTables(), + ).toStrictEqual(['users', 'posts']); + }); + // select subquery used tables + test('select+join', (ctx) => { + const { db } = ctx.cachedGel; + + const sq = db.select().from(usersTable).where(eq(usersTable.id1, 42)).as('sq'); + + // @ts-expect-error + expect(db.select().from(sq).getUsedTables()).toStrictEqual(['users']); + }); }); diff --git a/integration-tests/tests/mysql/mysql-common-cache.ts b/integration-tests/tests/mysql/mysql-common-cache.ts new file mode 100644 index 0000000000..9a7a2f1d7c --- /dev/null +++ b/integration-tests/tests/mysql/mysql-common-cache.ts @@ -0,0 +1,379 @@ +import { eq, getTableName, is, sql, Table } from 'drizzle-orm'; +import type { MutationOption } from 'drizzle-orm/cache/core'; +import { Cache } from 'drizzle-orm/cache/core'; +import type { CacheConfig } from 'drizzle-orm/cache/core/types'; +import type { MySqlDatabase } from 'drizzle-orm/mysql-core'; +import { alias, boolean, int, json, mysqlTable, serial, text, timestamp } from 'drizzle-orm/mysql-core'; +import Keyv from 'keyv'; +import { beforeEach, describe, expect, test, vi } from 'vitest'; + +// eslint-disable-next-line drizzle-internal/require-entity-kind +export class TestGlobalCache extends Cache { + private globalTtl: number = 1000; + private usedTablesPerKey: Record = {}; + + constructor(private kv: Keyv = new Keyv()) { + super(); + } + + override strategy(): 'explicit' | 'all' { + return 'all'; + } + override async get(key: string, _tables: string[], _isTag: boolean): Promise { + const res = await this.kv.get(key) ?? undefined; + return res; + } + override async put( + key: string, + response: any, + tables: string[], + isTag: boolean, + config?: CacheConfig, + ): Promise { + await this.kv.set(key, response, config ? 
config.ex : this.globalTtl); + for (const table of tables) { + const keys = this.usedTablesPerKey[table]; + if (keys === undefined) { + this.usedTablesPerKey[table] = [key]; + } else { + keys.push(key); + } + } + } + override async onMutate(params: MutationOption): Promise { + const tagsArray = params.tags ? Array.isArray(params.tags) ? params.tags : [params.tags] : []; + const tablesArray = params.tables ? Array.isArray(params.tables) ? params.tables : [params.tables] : []; + + const keysToDelete = new Set(); + + for (const table of tablesArray) { + const tableName = is(table, Table) ? getTableName(table) : table as string; + const keys = this.usedTablesPerKey[tableName] ?? []; + for (const key of keys) keysToDelete.add(key); + } + + if (keysToDelete.size > 0 || tagsArray.length > 0) { + for (const tag of tagsArray) { + await this.kv.delete(tag); + } + + for (const key of keysToDelete) { + await this.kv.delete(key); + for (const table of tablesArray) { + const tableName = is(table, Table) ? 
getTableName(table) : table as string; + this.usedTablesPerKey[tableName] = []; + } + } + } + } +} + +// eslint-disable-next-line drizzle-internal/require-entity-kind +export class TestCache extends TestGlobalCache { + override strategy(): 'explicit' | 'all' { + return 'explicit'; + } +} + +declare module 'vitest' { + interface TestContext { + cachedMySQL: { + db: MySqlDatabase; + dbGlobalCached: MySqlDatabase; + }; + } +} + +const usersTable = mysqlTable('users', { + id: serial('id').primaryKey(), + name: text('name').notNull(), + verified: boolean('verified').notNull().default(false), + jsonb: json('jsonb').$type(), + createdAt: timestamp('created_at', { fsp: 2 }).notNull().defaultNow(), +}); + +const postsTable = mysqlTable('posts', { + id: serial().primaryKey(), + description: text().notNull(), + userId: int('city_id').references(() => usersTable.id), +}); + +export function tests() { + describe('common_cache', () => { + beforeEach(async (ctx) => { + const { db, dbGlobalCached } = ctx.cachedMySQL; + await db.execute(sql`drop table if exists users`); + await db.execute(sql`drop table if exists posts`); + await db.$cache?.invalidate({ tables: 'users' }); + await dbGlobalCached.$cache?.invalidate({ tables: 'users' }); + // public users + await db.execute( + sql` + create table users ( + id serial primary key, + name text not null, + verified boolean not null default false, + jsonb json, + created_at timestamp not null default now() + ) + `, + ); + await db.execute( + sql` + create table posts ( + id serial primary key, + description text not null, + user_id int + ) + `, + ); + }); + + test('test force invalidate', async (ctx) => { + const { db } = ctx.cachedMySQL; + + const spyInvalidate = vi.spyOn(db.$cache, 'invalidate'); + await db.$cache?.invalidate({ tables: 'users' }); + expect(spyInvalidate).toHaveBeenCalledTimes(1); + }); + + test('default global config - no cache should be hit', async (ctx) => { + const { db } = ctx.cachedMySQL; + + // @ts-expect-error + 
const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable); + + expect(spyPut).toHaveBeenCalledTimes(0); + expect(spyGet).toHaveBeenCalledTimes(0); + expect(spyInvalidate).toHaveBeenCalledTimes(0); + }); + + test('default global config + enable cache on select: get, put', async (ctx) => { + const { db } = ctx.cachedMySQL; + + // @ts-expect-error + const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable).$withCache(); + + expect(spyPut).toHaveBeenCalledTimes(1); + expect(spyGet).toHaveBeenCalledTimes(1); + expect(spyInvalidate).toHaveBeenCalledTimes(0); + }); + + test('default global config + enable cache on select + write: get, put, onMutate', async (ctx) => { + const { db } = ctx.cachedMySQL; + + // @ts-expect-error + const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable).$withCache({ config: { ex: 1 } }); + + expect(spyPut).toHaveBeenCalledTimes(1); + expect(spyGet).toHaveBeenCalledTimes(1); + expect(spyInvalidate).toHaveBeenCalledTimes(0); + + spyPut.mockClear(); + spyGet.mockClear(); + spyInvalidate.mockClear(); + + await db.insert(usersTable).values({ name: 'John' }); + + expect(spyPut).toHaveBeenCalledTimes(0); + expect(spyGet).toHaveBeenCalledTimes(0); + expect(spyInvalidate).toHaveBeenCalledTimes(1); + }); + + test('default global config + enable cache on select + disable invalidate: get, put', async (ctx) => { + const { db } = ctx.cachedMySQL; + + // @ts-expect-error + const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = 
vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable).$withCache({ tag: 'custom', autoInvalidate: false, config: { ex: 1 } }); + + expect(spyPut).toHaveBeenCalledTimes(1); + expect(spyGet).toHaveBeenCalledTimes(1); + expect(spyInvalidate).toHaveBeenCalledTimes(0); + + await db.insert(usersTable).values({ name: 'John' }); + + // invalidate force + await db.$cache?.invalidate({ tags: ['custom'] }); + }); + + test('global: true + disable cache', async (ctx) => { + const { dbGlobalCached: db } = ctx.cachedMySQL; + + // @ts-expect-error + const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable).$withCache(false); + + expect(spyPut).toHaveBeenCalledTimes(0); + expect(spyGet).toHaveBeenCalledTimes(0); + expect(spyInvalidate).toHaveBeenCalledTimes(0); + }); + + test('global: true - cache should be hit', async (ctx) => { + const { dbGlobalCached: db } = ctx.cachedMySQL; + + // @ts-expect-error + const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable); + + expect(spyPut).toHaveBeenCalledTimes(1); + expect(spyGet).toHaveBeenCalledTimes(1); + expect(spyInvalidate).toHaveBeenCalledTimes(0); + }); + + test('global: true - cache: false on select - no cache hit', async (ctx) => { + const { dbGlobalCached: db } = ctx.cachedMySQL; + + // @ts-expect-error + const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable).$withCache(false); + + expect(spyPut).toHaveBeenCalledTimes(0); + 
expect(spyGet).toHaveBeenCalledTimes(0); + expect(spyInvalidate).toHaveBeenCalledTimes(0); + }); + + test('global: true - disable invalidate - cache hit + no invalidate', async (ctx) => { + const { dbGlobalCached: db } = ctx.cachedMySQL; + + // @ts-expect-error + const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable).$withCache({ autoInvalidate: false }); + + expect(spyPut).toHaveBeenCalledTimes(1); + expect(spyGet).toHaveBeenCalledTimes(1); + expect(spyInvalidate).toHaveBeenCalledTimes(0); + + spyPut.mockClear(); + spyGet.mockClear(); + spyInvalidate.mockClear(); + + await db.insert(usersTable).values({ name: 'John' }); + + expect(spyPut).toHaveBeenCalledTimes(0); + expect(spyGet).toHaveBeenCalledTimes(0); + expect(spyInvalidate).toHaveBeenCalledTimes(1); + }); + + test('global: true - with custom tag', async (ctx) => { + const { dbGlobalCached: db } = ctx.cachedMySQL; + + // @ts-expect-error + const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable).$withCache({ tag: 'custom', autoInvalidate: false }); + + expect(spyPut).toHaveBeenCalledTimes(1); + expect(spyGet).toHaveBeenCalledTimes(1); + expect(spyInvalidate).toHaveBeenCalledTimes(0); + + await db.insert(usersTable).values({ name: 'John' }); + + // invalidate force + await db.$cache?.invalidate({ tags: ['custom'] }); + }); + + // check select used tables + test('check simple select used tables', (ctx) => { + const { db } = ctx.cachedMySQL; + + // @ts-expect-error + expect(db.select().from(usersTable).getUsedTables()).toStrictEqual(['users']); + // @ts-expect-error + expect(db.select().from(sql`${usersTable}`).getUsedTables()).toStrictEqual(['users']); + }); + // check 
select+join used tables + test('select+join', (ctx) => { + const { db } = ctx.cachedMySQL; + + // @ts-expect-error + expect(db.select().from(usersTable).leftJoin(postsTable, eq(usersTable.id, postsTable.userId)).getUsedTables()) + .toStrictEqual(['users', 'posts']); + expect( + // @ts-expect-error + db.select().from(sql`${usersTable}`).leftJoin(postsTable, eq(usersTable.id, postsTable.userId)).getUsedTables(), + ).toStrictEqual(['users', 'posts']); + }); + // check select+2join used tables + test('select+2joins', (ctx) => { + const { db } = ctx.cachedMySQL; + + expect( + db.select().from(usersTable).leftJoin( + postsTable, + eq(usersTable.id, postsTable.userId), + ).leftJoin( + alias(postsTable, 'post2'), + eq(usersTable.id, postsTable.userId), + ) + // @ts-expect-error + .getUsedTables(), + ) + .toStrictEqual(['users', 'posts']); + expect( + db.select().from(sql`${usersTable}`).leftJoin(postsTable, eq(usersTable.id, postsTable.userId)).leftJoin( + alias(postsTable, 'post2'), + eq(usersTable.id, postsTable.userId), + // @ts-expect-error + ).getUsedTables(), + ).toStrictEqual(['users', 'posts']); + }); + // select subquery used tables + test('select+join', (ctx) => { + const { db } = ctx.cachedMySQL; + + const sq = db.select().from(usersTable).where(eq(usersTable.id, 42)).as('sq'); + db.select().from(sq); + + // @ts-expect-error + expect(db.select().from(sq).getUsedTables()).toStrictEqual(['users']); + }); + }); +} diff --git a/integration-tests/tests/mysql/mysql-common.ts b/integration-tests/tests/mysql/mysql-common.ts index 88050a9bec..9df31a2d57 100644 --- a/integration-tests/tests/mysql/mysql-common.ts +++ b/integration-tests/tests/mysql/mysql-common.ts @@ -281,15 +281,15 @@ export async function createDockerDB(): Promise<{ connectionString: string; cont return { connectionString: `mysql://root:mysql@127.0.0.1:${port}/drizzle`, container: mysqlContainer }; } -// afterAll(async () => { -// await mysqlContainer?.stop().catch(console.error); -// }); +afterAll(async 
() => { + await mysqlContainer?.stop().catch(console.error); +}); export function tests(driver?: string) { describe('common', () => { - afterAll(async () => { - await mysqlContainer?.stop().catch(console.error); - }); + // afterAll(async () => { + // await mysqlContainer?.stop().catch(console.error); + // }); beforeEach(async (ctx) => { const { db } = ctx.mysql; diff --git a/integration-tests/tests/mysql/mysql-planetscale.test.ts b/integration-tests/tests/mysql/mysql-planetscale.test.ts index 763b9c8e6e..fc97ea254a 100644 --- a/integration-tests/tests/mysql/mysql-planetscale.test.ts +++ b/integration-tests/tests/mysql/mysql-planetscale.test.ts @@ -4,19 +4,29 @@ import { drizzle } from 'drizzle-orm/planetscale-serverless'; import { beforeAll, beforeEach } from 'vitest'; import { skipTests } from '~/common'; import { tests } from './mysql-common'; +import { TestCache, TestGlobalCache, tests as cacheTests } from './mysql-common-cache'; const ENABLE_LOGGING = false; let db: PlanetScaleDatabase; +let dbGlobalCached: PlanetScaleDatabase; +let cachedDb: PlanetScaleDatabase; beforeAll(async () => { - db = drizzle(new Client({ url: process.env['PLANETSCALE_CONNECTION_STRING']! }), { logger: ENABLE_LOGGING }); + const client = new Client({ url: process.env['PLANETSCALE_CONNECTION_STRING']! 
}); + db = drizzle(client, { logger: ENABLE_LOGGING }); + cachedDb = drizzle(client, { logger: ENABLE_LOGGING, cache: new TestCache() }); + dbGlobalCached = drizzle(client, { logger: ENABLE_LOGGING, cache: new TestGlobalCache() }); }); beforeEach((ctx) => { ctx.mysql = { db, }; + ctx.cachedMySQL = { + db: cachedDb, + dbGlobalCached, + }; }); skipTests([ @@ -75,3 +85,4 @@ skipTests([ ]); tests('planetscale'); +cacheTests(); diff --git a/integration-tests/tests/mysql/mysql.test.ts b/integration-tests/tests/mysql/mysql.test.ts index c2b73713cc..6641e2d14f 100644 --- a/integration-tests/tests/mysql/mysql.test.ts +++ b/integration-tests/tests/mysql/mysql.test.ts @@ -4,10 +4,13 @@ import { drizzle } from 'drizzle-orm/mysql2'; import * as mysql from 'mysql2/promise'; import { afterAll, beforeAll, beforeEach } from 'vitest'; import { createDockerDB, tests } from './mysql-common'; +import { TestCache, TestGlobalCache, tests as cacheTests } from './mysql-common-cache'; const ENABLE_LOGGING = false; let db: MySql2Database; +let dbGlobalCached: MySql2Database; +let cachedDb: MySql2Database; let client: mysql.Connection; beforeAll(async () => { @@ -36,6 +39,8 @@ beforeAll(async () => { }, }); db = drizzle(client, { logger: ENABLE_LOGGING }); + cachedDb = drizzle(client, { logger: ENABLE_LOGGING, cache: new TestCache() }); + dbGlobalCached = drizzle(client, { logger: ENABLE_LOGGING, cache: new TestGlobalCache() }); }); afterAll(async () => { @@ -46,6 +51,11 @@ beforeEach((ctx) => { ctx.mysql = { db, }; + ctx.cachedMySQL = { + db: cachedDb, + dbGlobalCached, + }; }); +cacheTests(); tests(); diff --git a/integration-tests/tests/pg/neon-http-batch.test.ts b/integration-tests/tests/pg/neon-http-batch.test.ts index 2733ee7ef6..2c5a71bf6c 100644 --- a/integration-tests/tests/pg/neon-http-batch.test.ts +++ b/integration-tests/tests/pg/neon-http-batch.test.ts @@ -14,6 +14,7 @@ import { usersToGroupsConfig, usersToGroupsTable, } from './neon-http-batch'; +import { TestCache, 
TestGlobalCache } from './pg-common-cache'; const ENABLE_LOGGING = false; @@ -33,6 +34,8 @@ export const schema = { let db: NeonHttpDatabase; let client: NeonQueryFunction; +let dbGlobalCached: NeonHttpDatabase; +let cachedDb: NeonHttpDatabase; beforeAll(async () => { const connectionString = process.env['NEON_HTTP_CONNECTION_STRING']; @@ -41,12 +44,24 @@ beforeAll(async () => { } client = neon(connectionString); db = drizzle(client, { schema, logger: ENABLE_LOGGING }); + cachedDb = drizzle(client, { + logger: ENABLE_LOGGING, + cache: new TestCache(), + }); + dbGlobalCached = drizzle(client, { + logger: ENABLE_LOGGING, + cache: new TestGlobalCache(), + }); }); beforeEach((ctx) => { ctx.neonPg = { db, }; + ctx.cachedPg = { + db: cachedDb, + dbGlobalCached, + }; }); test('skip', async () => { diff --git a/integration-tests/tests/pg/neon-http.test.ts b/integration-tests/tests/pg/neon-http.test.ts index 93a7959a5f..863680316e 100644 --- a/integration-tests/tests/pg/neon-http.test.ts +++ b/integration-tests/tests/pg/neon-http.test.ts @@ -7,10 +7,13 @@ import { beforeAll, beforeEach, describe, expect, test, vi } from 'vitest'; import { skipTests } from '~/common'; import { randomString } from '~/utils'; import { tests, usersMigratorTable, usersTable } from './pg-common'; +import { TestCache, TestGlobalCache, tests as cacheTests } from './pg-common-cache'; const ENABLE_LOGGING = false; let db: NeonHttpDatabase; +let dbGlobalCached: NeonHttpDatabase; +let cachedDb: NeonHttpDatabase; beforeAll(async () => { const connectionString = process.env['NEON_HTTP_CONNECTION_STRING']; @@ -22,13 +25,26 @@ beforeAll(async () => { const [protocol, port] = host === 'db.localtest.me' ? 
['http', 4444] : ['https', 443]; return `${protocol}://${host}:${port}/sql`; }; - db = drizzle(neon(connectionString), { logger: ENABLE_LOGGING }); + const client = neon(connectionString); + db = drizzle(client, { logger: ENABLE_LOGGING }); + cachedDb = drizzle(client, { + logger: ENABLE_LOGGING, + cache: new TestCache(), + }); + dbGlobalCached = drizzle(client, { + logger: ENABLE_LOGGING, + cache: new TestGlobalCache(), + }); }); beforeEach((ctx) => { ctx.pg = { db, }; + ctx.cachedPg = { + db: cachedDb, + dbGlobalCached, + }; }); test('migrator : default migration strategy', async () => { @@ -416,6 +432,7 @@ skipTests([ 'test $onUpdateFn and $onUpdate works as $default', ]); tests(); +cacheTests(); beforeEach(async () => { await db.execute(sql`drop schema if exists public cascade`); diff --git a/integration-tests/tests/pg/neon-serverless.test.ts b/integration-tests/tests/pg/neon-serverless.test.ts index 864028177b..703f59696a 100644 --- a/integration-tests/tests/pg/neon-serverless.test.ts +++ b/integration-tests/tests/pg/neon-serverless.test.ts @@ -8,10 +8,13 @@ import ws from 'ws'; import { skipTests } from '~/common'; import { randomString } from '~/utils'; import { mySchema, tests, usersMigratorTable, usersMySchemaTable, usersTable } from './pg-common'; +import { TestCache, TestGlobalCache, tests as cacheTests } from './pg-common-cache'; const ENABLE_LOGGING = false; let db: NeonDatabase; +let dbGlobalCached: NeonDatabase; +let cachedDb: NeonDatabase; let client: Pool; neonConfig.wsProxy = (host) => `${host}:5446/v1`; @@ -28,6 +31,14 @@ beforeAll(async () => { client = new Pool({ connectionString }); db = drizzle(client, { logger: ENABLE_LOGGING }); + cachedDb = drizzle(client, { + logger: ENABLE_LOGGING, + cache: new TestCache(), + }); + dbGlobalCached = drizzle(client, { + logger: ENABLE_LOGGING, + cache: new TestGlobalCache(), + }); }); afterAll(async () => { @@ -38,6 +49,10 @@ beforeEach((ctx) => { ctx.pg = { db, }; + ctx.cachedPg = { + db: cachedDb, + 
dbGlobalCached, + }; }); test('migrator : default migration strategy', async () => { @@ -503,6 +518,7 @@ skipTests([ 'mySchema :: delete with returning all fields', ]); tests(); +cacheTests(); beforeEach(async () => { await db.execute(sql`drop schema if exists public cascade`); diff --git a/integration-tests/tests/pg/node-postgres.test.ts b/integration-tests/tests/pg/node-postgres.test.ts index 076f6ddb45..a7921ad189 100644 --- a/integration-tests/tests/pg/node-postgres.test.ts +++ b/integration-tests/tests/pg/node-postgres.test.ts @@ -9,11 +9,14 @@ import { afterAll, beforeAll, beforeEach, expect, test } from 'vitest'; import { skipTests } from '~/common'; import { randomString } from '~/utils'; import { createDockerDB, tests, usersMigratorTable, usersTable } from './pg-common'; +import { TestCache, TestGlobalCache, tests as cacheTests } from './pg-common-cache'; const ENABLE_LOGGING = false; let db: NodePgDatabase; let client: Client; +let dbGlobalCached: NodePgDatabase; +let cachedDb: NodePgDatabase; beforeAll(async () => { let connectionString; @@ -38,6 +41,8 @@ beforeAll(async () => { }, }); db = drizzle(client, { logger: ENABLE_LOGGING }); + cachedDb = drizzle(client, { logger: ENABLE_LOGGING, cache: new TestCache() }); + dbGlobalCached = drizzle(client, { logger: ENABLE_LOGGING, cache: new TestGlobalCache() }); }); afterAll(async () => { @@ -48,6 +53,10 @@ beforeEach((ctx) => { ctx.pg = { db, }; + ctx.cachedPg = { + db: cachedDb, + dbGlobalCached, + }; }); test('migrator : default migration strategy', async () => { @@ -432,6 +441,7 @@ skipTests([ 'test mode string for timestamp with timezone in different timezone', ]); tests(); +cacheTests(); beforeEach(async () => { await db.execute(sql`drop schema if exists public cascade`); diff --git a/integration-tests/tests/pg/pg-common-cache.ts b/integration-tests/tests/pg/pg-common-cache.ts new file mode 100644 index 0000000000..3d34e43d91 --- /dev/null +++ b/integration-tests/tests/pg/pg-common-cache.ts @@ -0,0 
+1,400 @@ +import type Docker from 'dockerode'; +import { eq, getTableName, is, sql, Table } from 'drizzle-orm'; +import type { MutationOption } from 'drizzle-orm/cache/core'; +import { Cache } from 'drizzle-orm/cache/core'; +import type { CacheConfig } from 'drizzle-orm/cache/core/types'; +import type { PgDatabase, PgQueryResultHKT } from 'drizzle-orm/pg-core'; +import { alias, boolean, integer, jsonb, pgTable, serial, text, timestamp } from 'drizzle-orm/pg-core'; +import Keyv from 'keyv'; +import { afterAll, beforeEach, describe, expect, test, vi } from 'vitest'; + +// eslint-disable-next-line drizzle-internal/require-entity-kind +export class TestGlobalCache extends Cache { + private globalTtl: number = 1000; + private usedTablesPerKey: Record = {}; + + constructor(private kv: Keyv = new Keyv()) { + super(); + } + + override strategy(): 'explicit' | 'all' { + return 'all'; + } + override async get(key: string, _tables: string[], _isTag: boolean): Promise { + const res = await this.kv.get(key) ?? undefined; + return res; + } + override async put( + key: string, + response: any, + tables: string[], + isTag: boolean, + config?: CacheConfig, + ): Promise { + await this.kv.set(key, response, config ? config.ex : this.globalTtl); + for (const table of tables) { + const keys = this.usedTablesPerKey[table]; + if (keys === undefined) { + this.usedTablesPerKey[table] = [key]; + } else { + keys.push(key); + } + } + } + override async onMutate(params: MutationOption): Promise { + const tagsArray = params.tags ? Array.isArray(params.tags) ? params.tags : [params.tags] : []; + const tablesArray = params.tables ? Array.isArray(params.tables) ? params.tables : [params.tables] : []; + + const keysToDelete = new Set(); + + for (const table of tablesArray) { + const tableName = is(table, Table) ? getTableName(table) : table as string; + const keys = this.usedTablesPerKey[tableName] ?? 
[]; + for (const key of keys) keysToDelete.add(key); + } + + if (keysToDelete.size > 0 || tagsArray.length > 0) { + for (const tag of tagsArray) { + await this.kv.delete(tag); + } + + for (const key of keysToDelete) { + await this.kv.delete(key); + for (const table of tablesArray) { + const tableName = is(table, Table) ? getTableName(table) : table as string; + this.usedTablesPerKey[tableName] = []; + } + } + } + } +} + +// eslint-disable-next-line drizzle-internal/require-entity-kind +export class TestCache extends TestGlobalCache { + override strategy(): 'explicit' | 'all' { + return 'explicit'; + } +} + +declare module 'vitest' { + interface TestContext { + cachedPg: { + db: PgDatabase; + dbGlobalCached: PgDatabase; + }; + } +} + +const usersTable = pgTable('users', { + id: serial().primaryKey(), + name: text().notNull(), + verified: boolean().notNull().default(false), + jsonb: jsonb().$type(), + createdAt: timestamp('created_at', { withTimezone: true }).notNull().defaultNow(), +}); + +const postsTable = pgTable('posts', { + id: serial().primaryKey(), + description: text().notNull(), + userId: integer('city_id').references(() => usersTable.id), +}); + +let pgContainer: Docker.Container; + +afterAll(async () => { + await pgContainer?.stop().catch(console.error); +}); + +export function tests() { + describe('common', () => { + beforeEach(async (ctx) => { + const { db, dbGlobalCached } = ctx.cachedPg; + await db.execute(sql`drop schema if exists public cascade`); + await db.$cache?.invalidate({ tables: 'users' }); + await dbGlobalCached.$cache?.invalidate({ tables: 'users' }); + await db.execute(sql`create schema public`); + // public users + await db.execute( + sql` + create table users ( + id serial primary key, + name text not null, + verified boolean not null default false, + jsonb jsonb, + created_at timestamptz not null default now() + ) + `, + ); + }); + + test('test force invalidate', async (ctx) => { + const { db } = ctx.cachedPg; + + const spyInvalidate = 
vi.spyOn(db.$cache, 'invalidate'); + await db.$cache?.invalidate({ tables: 'users' }); + expect(spyInvalidate).toHaveBeenCalledTimes(1); + }); + + test('default global config - no cache should be hit', async (ctx) => { + const { db } = ctx.cachedPg; + + // @ts-expect-error + const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable); + + expect(spyPut).toHaveBeenCalledTimes(0); + expect(spyGet).toHaveBeenCalledTimes(0); + expect(spyInvalidate).toHaveBeenCalledTimes(0); + }); + + test('default global config + enable cache on select: get, put', async (ctx) => { + const { db } = ctx.cachedPg; + + // @ts-expect-error + const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable).$withCache(); + + expect(spyPut).toHaveBeenCalledTimes(1); + expect(spyGet).toHaveBeenCalledTimes(1); + expect(spyInvalidate).toHaveBeenCalledTimes(0); + }); + + test('default global config + enable cache on select + write: get, put, onMutate', async (ctx) => { + const { db } = ctx.cachedPg; + + // @ts-expect-error + const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable).$withCache({ config: { ex: 1 } }); + + expect(spyPut).toHaveBeenCalledTimes(1); + expect(spyGet).toHaveBeenCalledTimes(1); + expect(spyInvalidate).toHaveBeenCalledTimes(0); + + spyPut.mockClear(); + spyGet.mockClear(); + spyInvalidate.mockClear(); + + await db.insert(usersTable).values({ name: 'John' }); + + expect(spyPut).toHaveBeenCalledTimes(0); + expect(spyGet).toHaveBeenCalledTimes(0); + 
expect(spyInvalidate).toHaveBeenCalledTimes(1); + }); + + test('default global config + enable cache on select + disable invalidate: get, put', async (ctx) => { + const { db } = ctx.cachedPg; + + // @ts-expect-error + const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable).$withCache({ tag: 'custom', autoInvalidate: false, config: { ex: 1 } }); + + expect(spyPut).toHaveBeenCalledTimes(1); + expect(spyGet).toHaveBeenCalledTimes(1); + expect(spyInvalidate).toHaveBeenCalledTimes(0); + + await db.insert(usersTable).values({ name: 'John' }); + + // invalidate force + await db.$cache?.invalidate({ tags: ['custom'] }); + }); + + test('global: true + disable cache', async (ctx) => { + const { dbGlobalCached: db } = ctx.cachedPg; + + // @ts-expect-error + const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable).$withCache(false); + + expect(spyPut).toHaveBeenCalledTimes(0); + expect(spyGet).toHaveBeenCalledTimes(0); + expect(spyInvalidate).toHaveBeenCalledTimes(0); + }); + + test('global: true - cache should be hit', async (ctx) => { + const { dbGlobalCached: db } = ctx.cachedPg; + + // @ts-expect-error + const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable); + + expect(spyPut).toHaveBeenCalledTimes(1); + expect(spyGet).toHaveBeenCalledTimes(1); + expect(spyInvalidate).toHaveBeenCalledTimes(0); + }); + + test('global: true - cache: false on select - no cache hit', async (ctx) => { + const { dbGlobalCached: db } = ctx.cachedPg; + + // @ts-expect-error + const spyPut = 
vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable).$withCache(false); + + expect(spyPut).toHaveBeenCalledTimes(0); + expect(spyGet).toHaveBeenCalledTimes(0); + expect(spyInvalidate).toHaveBeenCalledTimes(0); + }); + + test('global: true - disable invalidate - cache hit + no invalidate', async (ctx) => { + const { dbGlobalCached: db } = ctx.cachedPg; + + // @ts-expect-error + const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable).$withCache({ autoInvalidate: false }); + + expect(spyPut).toHaveBeenCalledTimes(1); + expect(spyGet).toHaveBeenCalledTimes(1); + expect(spyInvalidate).toHaveBeenCalledTimes(0); + + spyPut.mockClear(); + spyGet.mockClear(); + spyInvalidate.mockClear(); + + await db.insert(usersTable).values({ name: 'John' }); + + expect(spyPut).toHaveBeenCalledTimes(0); + expect(spyGet).toHaveBeenCalledTimes(0); + expect(spyInvalidate).toHaveBeenCalledTimes(1); + }); + + test('global: true - with custom tag', async (ctx) => { + const { dbGlobalCached: db } = ctx.cachedPg; + + // @ts-expect-error + const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable).$withCache({ tag: 'custom', autoInvalidate: false }); + + expect(spyPut).toHaveBeenCalledTimes(1); + expect(spyGet).toHaveBeenCalledTimes(1); + expect(spyInvalidate).toHaveBeenCalledTimes(0); + + await db.insert(usersTable).values({ name: 'John' }); + + // invalidate force + await db.$cache?.invalidate({ tags: ['custom'] }); + }); + + test('global: true - with custom tag + with autoinvalidate', async (ctx) => { + 
const { dbGlobalCached: db } = ctx.cachedPg; + + // @ts-expect-error + const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable).$withCache({ tag: 'custom' }); + + expect(spyPut).toHaveBeenCalledTimes(1); + expect(spyGet).toHaveBeenCalledTimes(1); + + await db.insert(usersTable).values({ name: 'John' }); + + expect(spyInvalidate).toHaveBeenCalledTimes(1); + + // invalidate force + await db.$cache?.invalidate({ tags: ['custom'] }); + }); + + // check select used tables + test('check simple select used tables', (ctx) => { + const { db } = ctx.cachedPg; + + // @ts-expect-error + expect(db.select().from(usersTable).getUsedTables()).toStrictEqual(['users']); + // @ts-expect-error + expect(db.select().from(sql`${usersTable}`).getUsedTables()).toStrictEqual(['users']); + }); + // check select+join used tables + test('select+join', (ctx) => { + const { db } = ctx.cachedPg; + + // @ts-expect-error + expect(db.select().from(usersTable).leftJoin(postsTable, eq(usersTable.id, postsTable.userId)).getUsedTables()) + .toStrictEqual(['users', 'posts']); + expect( + // @ts-expect-error + db.select().from(sql`${usersTable}`).leftJoin(postsTable, eq(usersTable.id, postsTable.userId)).getUsedTables(), + ).toStrictEqual(['users', 'posts']); + }); + // check select+2join used tables + test('select+2joins', (ctx) => { + const { db } = ctx.cachedPg; + + expect( + db.select().from(usersTable).leftJoin( + postsTable, + eq(usersTable.id, postsTable.userId), + ).leftJoin( + alias(postsTable, 'post2'), + eq(usersTable.id, postsTable.userId), + ) + // @ts-expect-error + .getUsedTables(), + ) + .toStrictEqual(['users', 'posts']); + expect( + db.select().from(sql`${usersTable}`).leftJoin(postsTable, eq(usersTable.id, postsTable.userId)).leftJoin( + alias(postsTable, 'post2'), + eq(usersTable.id, postsTable.userId), + // 
@ts-expect-error + ).getUsedTables(), + ).toStrictEqual(['users', 'posts']); + }); + // select subquery used tables + test('select+join', (ctx) => { + const { db } = ctx.cachedPg; + + const sq = db.select().from(usersTable).where(eq(usersTable.id, 42)).as('sq'); + db.select().from(sq); + + // @ts-expect-error + expect(db.select().from(sq).getUsedTables()).toStrictEqual(['users']); + }); + }); +} diff --git a/integration-tests/tests/pg/pg-proxy.test.ts b/integration-tests/tests/pg/pg-proxy.test.ts index 54f7c57668..19aa41cb75 100644 --- a/integration-tests/tests/pg/pg-proxy.test.ts +++ b/integration-tests/tests/pg/pg-proxy.test.ts @@ -8,6 +8,7 @@ import * as pg from 'pg'; import { afterAll, beforeAll, beforeEach, expect, test } from 'vitest'; import { skipTests } from '~/common'; import { createDockerDB, tests, usersMigratorTable, usersTable } from './pg-common'; +import { TestCache, TestGlobalCache, tests as cacheTests } from './pg-common-cache'; // eslint-disable-next-line drizzle-internal/require-entity-kind class ServerSimulator { @@ -73,6 +74,8 @@ class ServerSimulator { const ENABLE_LOGGING = false; let db: PgRemoteDatabase; +let dbGlobalCached: PgRemoteDatabase; +let cachedDb: PgRemoteDatabase; let client: pg.Client; let serverSimulator: ServerSimulator; @@ -99,7 +102,7 @@ beforeAll(async () => { }, }); serverSimulator = new ServerSimulator(client); - db = proxyDrizzle(async (sql, params, method) => { + const proxyHandler = async (sql: string, params: any[], method: any) => { try { const response = await serverSimulator.query(sql, params, method); @@ -112,9 +115,13 @@ beforeAll(async () => { console.error('Error from pg proxy server:', e.message); throw e; } - }, { + }; + db = proxyDrizzle(proxyHandler, { logger: ENABLE_LOGGING, }); + + cachedDb = proxyDrizzle(proxyHandler, { logger: ENABLE_LOGGING, cache: new TestCache() }); + dbGlobalCached = proxyDrizzle(proxyHandler, { logger: ENABLE_LOGGING, cache: new TestGlobalCache() }); }); afterAll(async () => { @@ 
-125,6 +132,10 @@ beforeEach((ctx) => { ctx.pg = { db, }; + ctx.cachedPg = { + db: cachedDb, + dbGlobalCached, + }; }); test('migrator : default migration strategy', async () => { @@ -498,3 +509,4 @@ test('insert via db.execute w/ query builder', async () => { }); tests(); +cacheTests(); diff --git a/integration-tests/tests/pg/pglite.test.ts b/integration-tests/tests/pg/pglite.test.ts index 2a3d6a2fff..560b24490b 100644 --- a/integration-tests/tests/pg/pglite.test.ts +++ b/integration-tests/tests/pg/pglite.test.ts @@ -5,15 +5,26 @@ import { migrate } from 'drizzle-orm/pglite/migrator'; import { afterAll, beforeAll, beforeEach, expect, test } from 'vitest'; import { skipTests } from '~/common'; import { tests, usersMigratorTable, usersTable } from './pg-common'; +import { TestCache, TestGlobalCache, tests as cacheTests } from './pg-common-cache'; const ENABLE_LOGGING = false; let db: PgliteDatabase; +let dbGlobalCached: PgliteDatabase; +let cachedDb: PgliteDatabase; let client: PGlite; beforeAll(async () => { client = new PGlite(); db = drizzle(client, { logger: ENABLE_LOGGING }); + cachedDb = drizzle(client, { + logger: ENABLE_LOGGING, + cache: new TestCache(), + }); + dbGlobalCached = drizzle(client, { + logger: ENABLE_LOGGING, + cache: new TestGlobalCache(), + }); }); afterAll(async () => { @@ -24,11 +35,17 @@ beforeEach((ctx) => { ctx.pg = { db, }; + ctx.cachedPg = { + db: cachedDb, + dbGlobalCached, + }; }); test('migrator : default migration strategy', async () => { await db.execute(sql`drop table if exists all_columns`); - await db.execute(sql`drop table if exists users12`); + await db.execute( + sql`drop table if exists users12`, + ); await db.execute(sql`drop table if exists "drizzle"."__drizzle_migrations"`); await migrate(db, { migrationsFolder: './drizzle2/pg' }); @@ -91,7 +108,9 @@ skipTests([ 'select with group by as column + sql', 'mySchema :: select with group by as column + sql', ]); + tests(); +cacheTests(); beforeEach(async () => { await 
db.execute(sql`drop schema if exists public cascade`); @@ -101,7 +120,7 @@ beforeEach(async () => { create table users ( id serial primary key, name text not null, - verified boolean not null default false, + verified boolean not null default false, jsonb jsonb, created_at timestamptz not null default now() ) diff --git a/integration-tests/tests/pg/postgres-js.test.ts b/integration-tests/tests/pg/postgres-js.test.ts index 14effc39c5..ee48d3eb42 100644 --- a/integration-tests/tests/pg/postgres-js.test.ts +++ b/integration-tests/tests/pg/postgres-js.test.ts @@ -10,10 +10,13 @@ import { migrate } from 'drizzle-orm/postgres-js/migrator'; import { skipTests } from '~/common'; import { randomString } from '~/utils'; import { createDockerDB, tests, usersMigratorTable, usersTable } from './pg-common'; +import { TestCache, TestGlobalCache, tests as cacheTests } from './pg-common-cache'; const ENABLE_LOGGING = false; let db: PostgresJsDatabase; +let dbGlobalCached: PostgresJsDatabase; +let cachedDb: PostgresJsDatabase; let client: Sql; beforeAll(async () => { @@ -44,6 +47,8 @@ beforeAll(async () => { }, }); db = drizzle(client, { logger: ENABLE_LOGGING }); + cachedDb = drizzle(client, { logger: ENABLE_LOGGING, cache: new TestCache() }); + dbGlobalCached = drizzle(client, { logger: ENABLE_LOGGING, cache: new TestGlobalCache() }); }); afterAll(async () => { @@ -54,6 +59,10 @@ beforeEach((ctx) => { ctx.pg = { db, }; + ctx.cachedPg = { + db: cachedDb, + dbGlobalCached, + }; }); test('migrator : default migration strategy', async () => { @@ -439,6 +448,7 @@ skipTests([ ]); tests(); +cacheTests(); beforeEach(async () => { await db.execute(sql`drop schema if exists public cascade`); diff --git a/integration-tests/tests/pg/vercel-pg.test.ts b/integration-tests/tests/pg/vercel-pg.test.ts index ecf1d22ac3..3499b0eed8 100644 --- a/integration-tests/tests/pg/vercel-pg.test.ts +++ b/integration-tests/tests/pg/vercel-pg.test.ts @@ -6,11 +6,14 @@ import { migrate } from 
'drizzle-orm/vercel-postgres/migrator'; import { afterAll, beforeAll, beforeEach, expect, test } from 'vitest'; import { skipTests } from '~/common'; import { randomString } from '~/utils'; -import { createDockerDB, tests, usersMigratorTable, usersTable } from './pg-common'; +import { createDockerDB, tests, tests as cacheTests, usersMigratorTable, usersTable } from './pg-common'; +import { TestCache, TestGlobalCache } from './pg-common-cache'; const ENABLE_LOGGING = false; let db: VercelPgDatabase; +let dbGlobalCached: VercelPgDatabase; +let cachedDb: VercelPgDatabase; let client: VercelClient; beforeAll(async () => { @@ -46,6 +49,8 @@ beforeAll(async () => { throw lastError; } db = drizzle(client, { logger: ENABLE_LOGGING }); + cachedDb = drizzle(client, { logger: ENABLE_LOGGING, cache: new TestCache() }); + dbGlobalCached = drizzle(client, { logger: ENABLE_LOGGING, cache: new TestGlobalCache() }); }); afterAll(async () => { @@ -56,6 +61,10 @@ beforeEach((ctx) => { ctx.pg = { db, }; + ctx.cachedPg = { + db: cachedDb, + dbGlobalCached, + }; }); test('migrator : default migration strategy', async () => { @@ -440,6 +449,7 @@ skipTests([ 'select from tables with same name from different schema using alias', // ]); tests(); +cacheTests(); beforeEach(async () => { await db.execute(sql`drop schema if exists public cascade`); diff --git a/integration-tests/tests/pg/xata-http.test.ts b/integration-tests/tests/pg/xata-http.test.ts index 80c97e7659..2fd35072d9 100644 --- a/integration-tests/tests/pg/xata-http.test.ts +++ b/integration-tests/tests/pg/xata-http.test.ts @@ -8,11 +8,14 @@ import { beforeAll, beforeEach, expect, test } from 'vitest'; import { skipTests } from '~/common'; import { randomString } from '~/utils'; import { getXataClient } from '../xata/xata.ts'; -import { tests, usersMigratorTable, usersTable } from './pg-common'; +import { tests, tests as cacheTests, usersMigratorTable, usersTable } from './pg-common'; +import { TestCache, TestGlobalCache } from 
'./pg-common-cache.ts'; const ENABLE_LOGGING = false; let db: XataHttpDatabase; +let dbGlobalCached: XataHttpDatabase; +let cachedDb: XataHttpDatabase; let client: XataHttpClient; beforeAll(async () => { @@ -32,12 +35,18 @@ beforeAll(async () => { randomize: false, }); db = drizzle(client, { logger: ENABLE_LOGGING }); + cachedDb = drizzle(client, { logger: ENABLE_LOGGING, cache: new TestCache() }); + dbGlobalCached = drizzle(client, { logger: ENABLE_LOGGING, cache: new TestGlobalCache() }); }); beforeEach((ctx) => { ctx.pg = { db, }; + ctx.cachedPg = { + db: cachedDb, + dbGlobalCached, + }; }); test('migrator : default migration strategy', async () => { @@ -375,6 +384,7 @@ skipTests([ 'subquery with view', ]); tests(); +cacheTests(); beforeEach(async () => { await db.execute(sql`drop schema if exists public cascade`); diff --git a/integration-tests/tests/singlestore/singlestore-cache.ts b/integration-tests/tests/singlestore/singlestore-cache.ts new file mode 100644 index 0000000000..992849aa8a --- /dev/null +++ b/integration-tests/tests/singlestore/singlestore-cache.ts @@ -0,0 +1,390 @@ +import { eq, getTableName, is, sql, Table } from 'drizzle-orm'; +import type { MutationOption } from 'drizzle-orm/cache/core'; +import { Cache } from 'drizzle-orm/cache/core'; +import type { CacheConfig } from 'drizzle-orm/cache/core/types'; +import { + alias, + boolean, + int, + json, + serial, + type SingleStoreDatabase, + singlestoreTable, + text, + timestamp, +} from 'drizzle-orm/singlestore-core'; +import Keyv from 'keyv'; +import { beforeEach, describe, expect, test, vi } from 'vitest'; + +// eslint-disable-next-line drizzle-internal/require-entity-kind +export class TestGlobalCache extends Cache { + private globalTtl: number = 1000; + private usedTablesPerKey: Record = {}; + + constructor(private kv: Keyv = new Keyv()) { + super(); + } + + override strategy(): 'explicit' | 'all' { + return 'all'; + } + override async get(key: string, _tables: string[], _isTag: boolean): 
Promise { + const res = await this.kv.get(key) ?? undefined; + return res; + } + override async put( + key: string, + response: any, + tables: string[], + isTag: boolean, + config?: CacheConfig, + ): Promise { + await this.kv.set(key, response, config ? config.ex : this.globalTtl); + for (const table of tables) { + const keys = this.usedTablesPerKey[table]; + if (keys === undefined) { + this.usedTablesPerKey[table] = [key]; + } else { + keys.push(key); + } + } + } + override async onMutate(params: MutationOption): Promise { + const tagsArray = params.tags ? Array.isArray(params.tags) ? params.tags : [params.tags] : []; + const tablesArray = params.tables ? Array.isArray(params.tables) ? params.tables : [params.tables] : []; + + const keysToDelete = new Set(); + + for (const table of tablesArray) { + const tableName = is(table, Table) ? getTableName(table) : table as string; + const keys = this.usedTablesPerKey[tableName] ?? []; + for (const key of keys) keysToDelete.add(key); + } + + if (keysToDelete.size > 0 || tagsArray.length > 0) { + for (const tag of tagsArray) { + await this.kv.delete(tag); + } + + for (const key of keysToDelete) { + await this.kv.delete(key); + for (const table of tablesArray) { + const tableName = is(table, Table) ? 
getTableName(table) : table as string; + this.usedTablesPerKey[tableName] = []; + } + } + } + } +} + +// eslint-disable-next-line drizzle-internal/require-entity-kind +export class TestCache extends TestGlobalCache { + override strategy(): 'explicit' | 'all' { + return 'explicit'; + } +} + +type TestSingleStoreDB = SingleStoreDatabase; + +declare module 'vitest' { + interface TestContext { + cachedSingleStore: { + db: TestSingleStoreDB; + dbGlobalCached: TestSingleStoreDB; + }; + } +} + +const usersTable = singlestoreTable('users', { + id: serial('id').primaryKey(), + name: text('name').notNull(), + verified: boolean('verified').notNull().default(false), + jsonb: json('jsonb').$type(), + createdAt: timestamp('created_at').notNull().defaultNow(), +}); + +const postsTable = singlestoreTable('posts', { + id: serial().primaryKey(), + description: text().notNull(), + userId: int('city_id'), +}); + +export function tests() { + describe('common_cache', () => { + beforeEach(async (ctx) => { + const { db, dbGlobalCached } = ctx.cachedSingleStore; + await db.execute(sql`drop table if exists users`); + await db.execute(sql`drop table if exists posts`); + await db.$cache?.invalidate({ tables: 'users' }); + await dbGlobalCached.$cache?.invalidate({ tables: 'users' }); + // public users + await db.execute( + sql` + create table users ( + id serial primary key, + name text not null, + verified boolean not null default false, + jsonb json, + created_at timestamp not null default now() + ) + `, + ); + await db.execute( + sql` + create table posts ( + id serial primary key, + description text not null, + user_id int + ) + `, + ); + }); + + test('test force invalidate', async (ctx) => { + const { db } = ctx.cachedSingleStore; + + const spyInvalidate = vi.spyOn(db.$cache, 'invalidate'); + await db.$cache?.invalidate({ tables: 'users' }); + expect(spyInvalidate).toHaveBeenCalledTimes(1); + }); + + test('default global config - no cache should be hit', async (ctx) => { + const { db } = 
ctx.cachedSingleStore; + + // @ts-expect-error + const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable); + + expect(spyPut).toHaveBeenCalledTimes(0); + expect(spyGet).toHaveBeenCalledTimes(0); + expect(spyInvalidate).toHaveBeenCalledTimes(0); + }); + + test('default global config + enable cache on select: get, put', async (ctx) => { + const { db } = ctx.cachedSingleStore; + + // @ts-expect-error + const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable).$withCache(); + + expect(spyPut).toHaveBeenCalledTimes(1); + expect(spyGet).toHaveBeenCalledTimes(1); + expect(spyInvalidate).toHaveBeenCalledTimes(0); + }); + + test('default global config + enable cache on select + write: get, put, onMutate', async (ctx) => { + const { db } = ctx.cachedSingleStore; + + // @ts-expect-error + const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable).$withCache({ config: { ex: 1 } }); + + expect(spyPut).toHaveBeenCalledTimes(1); + expect(spyGet).toHaveBeenCalledTimes(1); + expect(spyInvalidate).toHaveBeenCalledTimes(0); + + spyPut.mockClear(); + spyGet.mockClear(); + spyInvalidate.mockClear(); + + await db.insert(usersTable).values({ name: 'John' }); + + expect(spyPut).toHaveBeenCalledTimes(0); + expect(spyGet).toHaveBeenCalledTimes(0); + expect(spyInvalidate).toHaveBeenCalledTimes(1); + }); + + test('default global config + enable cache on select + disable invalidate: get, put', async (ctx) => { + const { db } = ctx.cachedSingleStore; + + // @ts-expect-error + const spyPut = 
vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable).$withCache({ tag: 'custom', autoInvalidate: false, config: { ex: 1 } }); + + expect(spyPut).toHaveBeenCalledTimes(1); + expect(spyGet).toHaveBeenCalledTimes(1); + expect(spyInvalidate).toHaveBeenCalledTimes(0); + + await db.insert(usersTable).values({ name: 'John' }); + + // invalidate force + await db.$cache?.invalidate({ tags: ['custom'] }); + }); + + test('global: true + disable cache', async (ctx) => { + const { dbGlobalCached: db } = ctx.cachedSingleStore; + + // @ts-expect-error + const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable).$withCache(false); + + expect(spyPut).toHaveBeenCalledTimes(0); + expect(spyGet).toHaveBeenCalledTimes(0); + expect(spyInvalidate).toHaveBeenCalledTimes(0); + }); + + test('global: true - cache should be hit', async (ctx) => { + const { dbGlobalCached: db } = ctx.cachedSingleStore; + + // @ts-expect-error + const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable); + + expect(spyPut).toHaveBeenCalledTimes(1); + expect(spyGet).toHaveBeenCalledTimes(1); + expect(spyInvalidate).toHaveBeenCalledTimes(0); + }); + + test('global: true - cache: false on select - no cache hit', async (ctx) => { + const { dbGlobalCached: db } = ctx.cachedSingleStore; + + // @ts-expect-error + const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await 
db.select().from(usersTable).$withCache(false); + + expect(spyPut).toHaveBeenCalledTimes(0); + expect(spyGet).toHaveBeenCalledTimes(0); + expect(spyInvalidate).toHaveBeenCalledTimes(0); + }); + + test('global: true - disable invalidate - cache hit + no invalidate', async (ctx) => { + const { dbGlobalCached: db } = ctx.cachedSingleStore; + + // @ts-expect-error + const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable).$withCache({ autoInvalidate: false }); + + expect(spyPut).toHaveBeenCalledTimes(1); + expect(spyGet).toHaveBeenCalledTimes(1); + expect(spyInvalidate).toHaveBeenCalledTimes(0); + + spyPut.mockClear(); + spyGet.mockClear(); + spyInvalidate.mockClear(); + + await db.insert(usersTable).values({ name: 'John' }); + + expect(spyPut).toHaveBeenCalledTimes(0); + expect(spyGet).toHaveBeenCalledTimes(0); + expect(spyInvalidate).toHaveBeenCalledTimes(1); + }); + + test('global: true - with custom tag', async (ctx) => { + const { dbGlobalCached: db } = ctx.cachedSingleStore; + + // @ts-expect-error + const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable).$withCache({ tag: 'custom', autoInvalidate: false }); + + expect(spyPut).toHaveBeenCalledTimes(1); + expect(spyGet).toHaveBeenCalledTimes(1); + expect(spyInvalidate).toHaveBeenCalledTimes(0); + + await db.insert(usersTable).values({ name: 'John' }); + + // invalidate force + await db.$cache?.invalidate({ tags: ['custom'] }); + }); + + // check select used tables + test('check simple select used tables', (ctx) => { + const { db } = ctx.cachedSingleStore; + + // @ts-expect-error + expect(db.select().from(usersTable).getUsedTables()).toStrictEqual(['users']); + // 
@ts-expect-error + expect(db.select().from(sql`${usersTable}`).getUsedTables()).toStrictEqual(['users']); + }); + // check select+join used tables + test('select+join', (ctx) => { + const { db } = ctx.cachedSingleStore; + + // @ts-expect-error + expect(db.select().from(usersTable).leftJoin(postsTable, eq(usersTable.id, postsTable.userId)).getUsedTables()) + .toStrictEqual(['users', 'posts']); + expect( + // @ts-expect-error + db.select().from(sql`${usersTable}`).leftJoin(postsTable, eq(usersTable.id, postsTable.userId)).getUsedTables(), + ).toStrictEqual(['users', 'posts']); + }); + // check select+2join used tables + test('select+2joins', (ctx) => { + const { db } = ctx.cachedSingleStore; + + expect( + db.select().from(usersTable).leftJoin( + postsTable, + eq(usersTable.id, postsTable.userId), + ).leftJoin( + alias(postsTable, 'post2'), + eq(usersTable.id, postsTable.userId), + ) + // @ts-expect-error + .getUsedTables(), + ) + .toStrictEqual(['users', 'posts']); + expect( + db.select().from(sql`${usersTable}`).leftJoin(postsTable, eq(usersTable.id, postsTable.userId)).leftJoin( + alias(postsTable, 'post2'), + eq(usersTable.id, postsTable.userId), + // @ts-expect-error + ).getUsedTables(), + ).toStrictEqual(['users', 'posts']); + }); + // select subquery used tables + test('select+subquery', (ctx) => { + const { db } = ctx.cachedSingleStore; + + const sq = db.select().from(usersTable).where(eq(usersTable.id, 42)).as('sq'); + db.select().from(sq); + + // @ts-expect-error + expect(db.select().from(sq).getUsedTables()).toStrictEqual(['users']); + }); + }); +} diff --git a/integration-tests/tests/singlestore/singlestore.test.ts b/integration-tests/tests/singlestore/singlestore.test.ts index 36ac1989cb..8a14bb17c5 100644 --- a/integration-tests/tests/singlestore/singlestore.test.ts +++ b/integration-tests/tests/singlestore/singlestore.test.ts @@ -3,11 +3,14 @@ import { drizzle } from 'drizzle-orm/singlestore'; import type { SingleStoreDriverDatabase } from 
'drizzle-orm/singlestore'; import * as mysql2 from 'mysql2/promise'; import { afterAll, beforeAll, beforeEach } from 'vitest'; +import { TestCache, TestGlobalCache, tests as cacheTests } from './singlestore-cache'; import { createDockerDB, tests } from './singlestore-common'; const ENABLE_LOGGING = false; let db: SingleStoreDriverDatabase; +let dbGlobalCached: SingleStoreDriverDatabase; +let cachedDb: SingleStoreDriverDatabase; let client: mysql2.Connection; beforeAll(async () => { @@ -36,6 +39,8 @@ beforeAll(async () => { await client.query(`CREATE DATABASE IF NOT EXISTS drizzle;`); await client.changeUser({ database: 'drizzle' }); db = drizzle(client, { logger: ENABLE_LOGGING }); + cachedDb = drizzle(client, { logger: ENABLE_LOGGING, cache: new TestCache() }); + dbGlobalCached = drizzle(client, { logger: ENABLE_LOGGING, cache: new TestGlobalCache() }); }); afterAll(async () => { @@ -46,6 +51,11 @@ beforeEach((ctx) => { ctx.singlestore = { db, }; + ctx.cachedSingleStore = { + db: cachedDb, + dbGlobalCached, + }; }); +cacheTests(); tests(); diff --git a/integration-tests/tests/sqlite/d1.test.ts b/integration-tests/tests/sqlite/d1.test.ts index 7928b74245..6878b127f6 100644 --- a/integration-tests/tests/sqlite/d1.test.ts +++ b/integration-tests/tests/sqlite/d1.test.ts @@ -8,21 +8,30 @@ import { beforeAll, beforeEach, expect, test } from 'vitest'; import { skipTests } from '~/common'; import { randomString } from '~/utils'; import { anotherUsersMigratorTable, tests, usersMigratorTable } from './sqlite-common'; +import { TestCache, TestGlobalCache, tests as cacheTests } from './sqlite-common-cache'; const ENABLE_LOGGING = false; let db: DrizzleD1Database; +let dbGlobalCached: DrizzleD1Database; +let cachedDb: DrizzleD1Database; beforeAll(async () => { const sqliteDb = await createSQLiteDB(':memory:'); const d1db = new D1Database(new D1DatabaseAPI(sqliteDb)); db = drizzle(d1db, { logger: ENABLE_LOGGING }); + cachedDb = drizzle(d1db, { logger: ENABLE_LOGGING, cache: new 
TestCache() }); + dbGlobalCached = drizzle(d1db, { logger: ENABLE_LOGGING, cache: new TestGlobalCache() }); }); beforeEach((ctx) => { ctx.sqlite = { db, }; + ctx.cachedSqlite = { + db: cachedDb, + dbGlobalCached, + }; }); test('migrator', async () => { @@ -87,4 +96,5 @@ skipTests([ 'join view as subquery', 'cross join', ]); +cacheTests(); tests(); diff --git a/integration-tests/tests/sqlite/libsql.test.ts b/integration-tests/tests/sqlite/libsql.test.ts index b99d7e9bf4..8d68496584 100644 --- a/integration-tests/tests/sqlite/libsql.test.ts +++ b/integration-tests/tests/sqlite/libsql.test.ts @@ -7,10 +7,13 @@ import { afterAll, beforeAll, beforeEach, expect, test } from 'vitest'; import { skipTests } from '~/common'; import { randomString } from '~/utils'; import { anotherUsersMigratorTable, tests, usersMigratorTable } from './sqlite-common'; +import { TestCache, TestGlobalCache, tests as cacheTests } from './sqlite-common-cache'; const ENABLE_LOGGING = false; let db: LibSQLDatabase; +let dbGlobalCached: LibSQLDatabase; +let cachedDb: LibSQLDatabase; let client: Client; beforeAll(async () => { @@ -33,6 +36,8 @@ beforeAll(async () => { }, }); db = drizzle(client, { logger: ENABLE_LOGGING }); + cachedDb = drizzle(client, { logger: ENABLE_LOGGING, cache: new TestCache() }); + dbGlobalCached = drizzle(client, { logger: ENABLE_LOGGING, cache: new TestGlobalCache() }); }); afterAll(async () => { @@ -43,6 +48,10 @@ beforeEach((ctx) => { ctx.sqlite = { db, }; + ctx.cachedSqlite = { + db: cachedDb, + dbGlobalCached, + }; }); test('migrator', async () => { @@ -93,4 +102,5 @@ skipTests([ 'update with limit and order by', ]); +cacheTests(); tests(); diff --git a/integration-tests/tests/sqlite/sqlite-common-cache.ts b/integration-tests/tests/sqlite/sqlite-common-cache.ts new file mode 100644 index 0000000000..05582055f6 --- /dev/null +++ b/integration-tests/tests/sqlite/sqlite-common-cache.ts @@ -0,0 +1,380 @@ +import { eq, getTableName, is, sql, Table } from 'drizzle-orm'; 
+import type { MutationOption } from 'drizzle-orm/cache/core'; +import { Cache } from 'drizzle-orm/cache/core'; +import type { CacheConfig } from 'drizzle-orm/cache/core/types'; +import { alias, type BaseSQLiteDatabase, integer, sqliteTable, text } from 'drizzle-orm/sqlite-core'; +import Keyv from 'keyv'; +import { beforeEach, describe, expect, test, vi } from 'vitest'; + +// eslint-disable-next-line drizzle-internal/require-entity-kind +export class TestGlobalCache extends Cache { + private globalTtl: number = 1000; + private usedTablesPerKey: Record = {}; + + constructor(private kv: Keyv = new Keyv()) { + super(); + } + + override strategy(): 'explicit' | 'all' { + return 'all'; + } + override async get(key: string, _tables: string[], _isTag: boolean): Promise { + const res = await this.kv.get(key) ?? undefined; + return res; + } + override async put( + key: string, + response: any, + tables: string[], + isTag: boolean, + config?: CacheConfig, + ): Promise { + await this.kv.set(key, response, config ? config.ex : this.globalTtl); + for (const table of tables) { + const keys = this.usedTablesPerKey[table]; + if (keys === undefined) { + this.usedTablesPerKey[table] = [key]; + } else { + keys.push(key); + } + } + } + override async onMutate(params: MutationOption): Promise { + const tagsArray = params.tags ? Array.isArray(params.tags) ? params.tags : [params.tags] : []; + const tablesArray = params.tables ? Array.isArray(params.tables) ? params.tables : [params.tables] : []; + + const keysToDelete = new Set(); + + for (const table of tablesArray) { + const tableName = is(table, Table) ? getTableName(table) : table as string; + const keys = this.usedTablesPerKey[tableName] ?? 
[]; + for (const key of keys) keysToDelete.add(key); + } + + if (keysToDelete.size > 0 || tagsArray.length > 0) { + for (const tag of tagsArray) { + await this.kv.delete(tag); + } + + for (const key of keysToDelete) { + await this.kv.delete(key); + for (const table of tablesArray) { + const tableName = is(table, Table) ? getTableName(table) : table as string; + this.usedTablesPerKey[tableName] = []; + } + } + } + } +} + +// eslint-disable-next-line drizzle-internal/require-entity-kind +export class TestCache extends TestGlobalCache { + override strategy(): 'explicit' | 'all' { + return 'explicit'; + } +} + +declare module 'vitest' { + interface TestContext { + cachedSqlite: { + db: BaseSQLiteDatabase; + dbGlobalCached: BaseSQLiteDatabase; + }; + sqlite: { + db: BaseSQLiteDatabase<'async' | 'sync', any, Record>; + }; + } +} + +const usersTable = sqliteTable('users', { + id: integer('id').primaryKey({ autoIncrement: true }), + name: text('name').notNull(), + verified: integer('verified', { mode: 'boolean' }).notNull().default(false), + jsonb: text('jsonb', { mode: 'json' }).$type(), + createdAt: integer('created_at', { mode: 'timestamp' }), +}); + +const postsTable = sqliteTable('posts', { + id: integer().primaryKey({ autoIncrement: true }), + description: text().notNull(), + userId: integer('user_id').references(() => usersTable.id), +}); + +export function tests() { + describe('common_cache', () => { + beforeEach(async (ctx) => { + const { db, dbGlobalCached } = ctx.cachedSqlite; + await db.run(sql`drop table if exists users`); + await db.run(sql`drop table if exists posts`); + await db.$cache?.invalidate({ tables: 'users' }); + await dbGlobalCached.$cache?.invalidate({ tables: 'users' }); + // public users + await db.run( + sql` + create table users ( + id integer primary key AUTOINCREMENT, + name text not null, + verified integer not null default 0, + jsonb text, + created_at integer + ) + `, + ); + await db.run( + sql` + create table posts ( + id integer primary 
key AUTOINCREMENT, + description text not null, + user_id int + ) + `, + ); + }); + + test('test force invalidate', async (ctx) => { + const { db } = ctx.cachedSqlite; + + const spyInvalidate = vi.spyOn(db.$cache, 'invalidate'); + await db.$cache?.invalidate({ tables: 'users' }); + expect(spyInvalidate).toHaveBeenCalledTimes(1); + }); + + test('default global config - no cache should be hit', async (ctx) => { + const { db } = ctx.cachedSqlite; + + // @ts-expect-error + const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable); + + expect(spyPut).toHaveBeenCalledTimes(0); + expect(spyGet).toHaveBeenCalledTimes(0); + expect(spyInvalidate).toHaveBeenCalledTimes(0); + }); + + test('default global config + enable cache on select: get, put', async (ctx) => { + const { db } = ctx.cachedSqlite; + + // @ts-expect-error + const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable).$withCache(); + + expect(spyPut).toHaveBeenCalledTimes(1); + expect(spyGet).toHaveBeenCalledTimes(1); + expect(spyInvalidate).toHaveBeenCalledTimes(0); + }); + + test('default global config + enable cache on select + write: get, put, onMutate', async (ctx) => { + const { db } = ctx.cachedSqlite; + + // @ts-expect-error + const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable).$withCache({ config: { ex: 1 } }); + + expect(spyPut).toHaveBeenCalledTimes(1); + expect(spyGet).toHaveBeenCalledTimes(1); + expect(spyInvalidate).toHaveBeenCalledTimes(0); + + spyPut.mockClear(); + spyGet.mockClear(); + 
spyInvalidate.mockClear(); + + await db.insert(usersTable).values({ name: 'John' }); + + expect(spyPut).toHaveBeenCalledTimes(0); + expect(spyGet).toHaveBeenCalledTimes(0); + expect(spyInvalidate).toHaveBeenCalledTimes(1); + }); + + test('default global config + enable cache on select + disable invalidate: get, put', async (ctx) => { + const { db } = ctx.cachedSqlite; + + // @ts-expect-error + const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable).$withCache({ tag: 'custom', autoInvalidate: false, config: { ex: 1 } }); + + expect(spyPut).toHaveBeenCalledTimes(1); + expect(spyGet).toHaveBeenCalledTimes(1); + expect(spyInvalidate).toHaveBeenCalledTimes(0); + + await db.insert(usersTable).values({ name: 'John' }); + + // invalidate force + await db.$cache?.invalidate({ tags: ['custom'] }); + }); + + test('global: true + disable cache', async (ctx) => { + const { dbGlobalCached: db } = ctx.cachedSqlite; + + // @ts-expect-error + const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable).$withCache(false); + + expect(spyPut).toHaveBeenCalledTimes(0); + expect(spyGet).toHaveBeenCalledTimes(0); + expect(spyInvalidate).toHaveBeenCalledTimes(0); + }); + + test('global: true - cache should be hit', async (ctx) => { + const { dbGlobalCached: db } = ctx.cachedSqlite; + + // @ts-expect-error + const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable); + + expect(spyPut).toHaveBeenCalledTimes(1); + expect(spyGet).toHaveBeenCalledTimes(1); + 
expect(spyInvalidate).toHaveBeenCalledTimes(0); + }); + + test('global: true - cache: false on select - no cache hit', async (ctx) => { + const { dbGlobalCached: db } = ctx.cachedSqlite; + + // @ts-expect-error + const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable).$withCache(false); + + expect(spyPut).toHaveBeenCalledTimes(0); + expect(spyGet).toHaveBeenCalledTimes(0); + expect(spyInvalidate).toHaveBeenCalledTimes(0); + }); + + test('global: true - disable invalidate - cache hit + no invalidate', async (ctx) => { + const { dbGlobalCached: db } = ctx.cachedSqlite; + + // @ts-expect-error + const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable).$withCache({ autoInvalidate: false }); + + expect(spyPut).toHaveBeenCalledTimes(1); + expect(spyGet).toHaveBeenCalledTimes(1); + expect(spyInvalidate).toHaveBeenCalledTimes(0); + + spyPut.mockClear(); + spyGet.mockClear(); + spyInvalidate.mockClear(); + + await db.insert(usersTable).values({ name: 'John' }); + + expect(spyPut).toHaveBeenCalledTimes(0); + expect(spyGet).toHaveBeenCalledTimes(0); + expect(spyInvalidate).toHaveBeenCalledTimes(1); + }); + + test('global: true - with custom tag', async (ctx) => { + const { dbGlobalCached: db } = ctx.cachedSqlite; + + // @ts-expect-error + const spyPut = vi.spyOn(db.$cache, 'put'); + // @ts-expect-error + const spyGet = vi.spyOn(db.$cache, 'get'); + // @ts-expect-error + const spyInvalidate = vi.spyOn(db.$cache, 'onMutate'); + + await db.select().from(usersTable).$withCache({ tag: 'custom', autoInvalidate: false }); + + expect(spyPut).toHaveBeenCalledTimes(1); + expect(spyGet).toHaveBeenCalledTimes(1); + 
expect(spyInvalidate).toHaveBeenCalledTimes(0); + + await db.insert(usersTable).values({ name: 'John' }); + + // invalidate force + await db.$cache?.invalidate({ tags: ['custom'] }); + }); + + // check select used tables + test('check simple select used tables', (ctx) => { + const { db } = ctx.cachedSqlite; + + // @ts-expect-error + expect(db.select().from(usersTable).getUsedTables()).toStrictEqual(['users']); + // @ts-expect-error + expect(db.select().from(sql`${usersTable}`).getUsedTables()).toStrictEqual(['users']); + }); + // check select+join used tables + test('select+join', (ctx) => { + const { db } = ctx.cachedSqlite; + + // @ts-expect-error + expect(db.select().from(usersTable).leftJoin(postsTable, eq(usersTable.id, postsTable.userId)).getUsedTables()) + .toStrictEqual(['users', 'posts']); + expect( + // @ts-expect-error + db.select().from(sql`${usersTable}`).leftJoin(postsTable, eq(usersTable.id, postsTable.userId)).getUsedTables(), + ).toStrictEqual(['users', 'posts']); + }); + // check select+2join used tables + test('select+2joins', (ctx) => { + const { db } = ctx.cachedSqlite; + + expect( + db.select().from(usersTable).leftJoin( + postsTable, + eq(usersTable.id, postsTable.userId), + ).leftJoin( + alias(postsTable, 'post2'), + eq(usersTable.id, postsTable.userId), + ) + // @ts-expect-error + .getUsedTables(), + ) + .toStrictEqual(['users', 'posts']); + expect( + db.select().from(sql`${usersTable}`).leftJoin(postsTable, eq(usersTable.id, postsTable.userId)).leftJoin( + alias(postsTable, 'post2'), + eq(usersTable.id, postsTable.userId), + // @ts-expect-error + ).getUsedTables(), + ).toStrictEqual(['users', 'posts']); + }); + // select subquery used tables + test('select+subquery', (ctx) => { + const { db } = ctx.cachedSqlite; + + const sq = db.select().from(usersTable).where(eq(usersTable.id, 42)).as('sq'); + + // @ts-expect-error + expect(db.select().from(sq).getUsedTables()).toStrictEqual(['users']); + }); + }); +} diff --git 
a/integration-tests/tests/sqlite/sqlite-proxy.test.ts b/integration-tests/tests/sqlite/sqlite-proxy.test.ts index 2aec14be5c..23024cf9af 100644 --- a/integration-tests/tests/sqlite/sqlite-proxy.test.ts +++ b/integration-tests/tests/sqlite/sqlite-proxy.test.ts @@ -7,6 +7,7 @@ import { drizzle as proxyDrizzle } from 'drizzle-orm/sqlite-proxy'; import { afterAll, beforeAll, beforeEach, expect, test } from 'vitest'; import { skipTests } from '~/common'; import { tests, usersTable } from './sqlite-common'; +import { TestCache, TestGlobalCache, tests as cacheTests } from './sqlite-common-cache'; class ServerSimulator { constructor(private db: BetterSqlite3.Database) {} @@ -54,6 +55,8 @@ class ServerSimulator { } let db: SqliteRemoteDatabase; +let dbGlobalCached: SqliteRemoteDatabase; +let cachedDb: SqliteRemoteDatabase; let client: Database.Database; let serverSimulator: ServerSimulator; @@ -62,7 +65,7 @@ beforeAll(async () => { client = new Database(dbPath); serverSimulator = new ServerSimulator(client); - db = proxyDrizzle(async (sql, params, method) => { + const callback = async (sql: string, params: any[], method: string) => { try { const rows = await serverSimulator.query(sql, params, method); @@ -75,13 +78,20 @@ beforeAll(async () => { console.error('Error from sqlite proxy server:', e.response?.data ?? 
e.message); throw e; } - }); + }; + db = proxyDrizzle(callback); + cachedDb = proxyDrizzle(callback, { cache: new TestCache() }); + dbGlobalCached = proxyDrizzle(callback, { cache: new TestGlobalCache() }); }); beforeEach((ctx) => { ctx.sqlite = { db, }; + ctx.cachedSqlite = { + db: cachedDb, + dbGlobalCached, + }; }); afterAll(async () => { @@ -95,6 +105,7 @@ skipTests([ 'insert via db.get', 'insert via db.run + select via db.all', ]); +cacheTests(); tests(); beforeEach(async () => { diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 32577e6ba0..c8e4aca999 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -382,8 +382,8 @@ importers: specifier: ^0.1.1 version: 0.1.1 '@types/better-sqlite3': - specifier: ^7.6.4 - version: 7.6.10 + specifier: ^7.6.12 + version: 7.6.13 '@types/node': specifier: ^20.2.5 version: 20.12.12 @@ -396,6 +396,9 @@ importers: '@types/sql.js': specifier: ^1.4.4 version: 1.4.9 + '@upstash/redis': + specifier: ^1.34.3 + version: 1.34.9 '@vercel/postgres': specifier: ^0.8.0 version: 0.8.0 @@ -833,6 +836,9 @@ importers: '@types/ws': specifier: ^8.5.10 version: 8.5.11 + '@upstash/redis': + specifier: ^1.34.3 + version: 1.34.9 '@vitest/ui': specifier: ^1.6.0 version: 1.6.0(vitest@3.1.3) @@ -845,6 +851,9 @@ importers: import-in-the-middle: specifier: ^1.13.1 version: 1.13.1 + keyv: + specifier: ^5.2.3 + version: 5.3.3 ts-node: specifier: ^10.9.2 version: 10.9.2(@types/node@20.12.12)(typescript@5.6.3) @@ -3526,6 +3535,9 @@ packages: peerDependencies: jsep: ^0.4.0||^1.0.0 + '@keyv/serialize@1.0.3': + resolution: {integrity: sha512-qnEovoOp5Np2JDGonIDL6Ayihw0RhnRh6vxPuHo4RDn1UOzwEo4AeIfpL6UGIrsceWrCMiVPgwRjbHu4vYFc3g==} + '@libsql/client-wasm@0.10.0': resolution: {integrity: sha512-xSlpGdBGEr4mRtjCnDejTqtDpct2ng8cqHUQs+S4xG1yv0h+hLdzOtQJSY9JV9T/2MWWDfdCiEntPs2SdErSJA==} bundledDependencies: @@ -4771,6 +4783,9 @@ packages: '@ungap/structured-clone@1.2.0': resolution: {integrity: 
sha512-zuVdFrMJiuCDQUMCzQaD6KL28MjnqqN8XnAqiEq9PNm/hCPTSGfrXCOfwj1ow4LFb/tNymJPwsNbVePc1xFqrQ==} + '@upstash/redis@1.34.9': + resolution: {integrity: sha512-7qzzF2FQP5VxR2YUNjemWs+hl/8VzJJ6fOkT7O7kt9Ct8olEVzb1g6/ik6B8Pb8W7ZmYv81SdlVV9F6O8bh/gw==} + '@urql/core@2.3.6': resolution: {integrity: sha512-PUxhtBh7/8167HJK6WqBv6Z0piuiaZHQGYbhwpNL9aIQmLROPEdaUYkY4wh45wPQXcTpnd11l0q3Pw+TI11pdw==} peerDependencies: @@ -5224,6 +5239,9 @@ packages: buffer@5.7.1: resolution: {integrity: sha512-EHcyIPBQ4BSGlvjB16k5KgAJ27CIsHY/2JBmCRReo48y9rQ3MaUzWX3KVlBa4U7MyX02HdVj0K7C3WaB3ju7FQ==} + buffer@6.0.3: + resolution: {integrity: sha512-FTiCpNxtwiZZHEZbcbTIcZjERVICn9yq/pDFkTl95/AxzD1naBctN7YO68riM/gLSDY7sdrMby8hofADYuuqOA==} + bufferutil@4.0.8: resolution: {integrity: sha512-4T53u4PdgsXqKaIctwF8ifXlRTTmEPJ8iEPWFdGZvcf7sbwYo6FKFEX9eNNAnzFZ7EzJAQ3CJeOtCRA4rDp7Pw==} engines: {node: '>=6.14.2'} @@ -5670,6 +5688,9 @@ packages: crypt@0.0.2: resolution: {integrity: sha512-mCxBlsHFYh9C+HVpiEacem8FEBnMXgU9gy4zmNC+SXAZNB/1idgp/aulFJ4FgCi7GPEVbfyng092GqL2k2rmow==} + crypto-js@4.2.0: + resolution: {integrity: sha512-KALDyEYgpY+Rlob/iriUtjV6d5Eq+Y191A5g4UqLAi8CyGP9N1+FdVbkc1SxKc2r4YAYqG8JzO2KGL+AizD70Q==} + crypto-random-string@1.0.0: resolution: {integrity: sha512-GsVpkFPlycH7/fRR7Dhcmnoii54gV1nz7y4CWyeFS14N+JVBBhY+r8amRHE4BwSYal7BPTDp8isvAlCxyFt3Hg==} engines: {node: '>=4'} @@ -7637,6 +7658,9 @@ packages: keyv@4.5.3: resolution: {integrity: sha512-QCiSav9WaX1PgETJ+SpNnx2PRRapJ/oRSXM4VO5OGYGSjrxbKPVFVhB3l2OCbLCk329N8qyAtsJjSjvVBWzEug==} + keyv@5.3.3: + resolution: {integrity: sha512-Rwu4+nXI9fqcxiEHtbkvoes2X+QfkTRo1TMkPfwzipGsJlJO/z69vqB4FNl9xJ3xCpAcbkvmEabZfPzrwN3+gQ==} + kind-of@6.0.3: resolution: {integrity: sha512-dcS1ul+9tmeD95T+x28/ehLgd9mENa3LsvDTtzm3vyBEO7RPptvAD+t44WVXaUjTBRcrpFeFlC8WCruUR456hw==} engines: {node: '>=0.10.0'} @@ -13984,6 +14008,10 @@ snapshots: dependencies: jsep: 1.4.0 + '@keyv/serialize@1.0.3': + dependencies: + buffer: 6.0.3 + '@libsql/client-wasm@0.10.0': dependencies: 
'@libsql/core': 0.10.0 @@ -15677,6 +15705,10 @@ snapshots: '@ungap/structured-clone@1.2.0': {} + '@upstash/redis@1.34.9': + dependencies: + crypto-js: 4.2.0 + '@urql/core@2.3.6(graphql@15.8.0)': dependencies: '@graphql-typed-document-node/core': 3.2.0(graphql@15.8.0) @@ -16276,6 +16308,11 @@ snapshots: base64-js: 1.5.1 ieee754: 1.2.1 + buffer@6.0.3: + dependencies: + base64-js: 1.5.1 + ieee754: 1.2.1 + bufferutil@4.0.8: dependencies: node-gyp-build: 4.8.1 @@ -16772,6 +16809,8 @@ snapshots: crypt@0.0.2: {} + crypto-js@4.2.0: {} + crypto-random-string@1.0.0: {} crypto-random-string@2.0.0: {} @@ -19183,6 +19222,10 @@ snapshots: dependencies: json-buffer: 3.0.1 + keyv@5.3.3: + dependencies: + '@keyv/serialize': 1.0.3 + kind-of@6.0.3: {} kleur@3.0.3: {}