diff --git a/milvus/MilvusClient.ts b/milvus/MilvusClient.ts index 91975158..45cdf855 100644 --- a/milvus/MilvusClient.ts +++ b/milvus/MilvusClient.ts @@ -12,6 +12,15 @@ import { CreateCollectionReq, ERROR_REASONS, checkCreateCollectionCompatibility, + SearchReq, + SearchSimpleReq, + HybridSearchReq, + SearchIteratorReq, + QueryReq, + QueryIteratorReq, + GetReq, + SearchResults, + QueryResults, DEFAULT_PRIMARY_KEY_FIELD, DEFAULT_METRIC_TYPE, DEFAULT_VECTOR_FIELD, @@ -74,6 +83,19 @@ export class MilvusClient extends GRPCClient { } } + /** + * Creates a lightweight DQL session pinned to a target cluster. + * The session reuses this client's connection and injects `cluster_id` into + * search/query/get/iterator request parameters. + * @param clusterId The target cluster id. + */ + session(clusterId: string): MilvusClientSession { + if (typeof clusterId !== 'string' || clusterId.length === 0) { + throw new Error('clusterId must be a non-empty string'); + } + return new MilvusClientSession(this, clusterId); + } + // High level API: align with python MilvusClient /** * Creates a new collection with the given parameters. @@ -236,3 +258,67 @@ export class MilvusClient extends GRPCClient { return result; } } + +/** + * Lightweight DQL session bound to a cluster id and backed by a parent client. + * Closing the session only prevents future session calls; it does not close the + * parent client or its gRPC channel pool. + */ +export class MilvusClientSession { + private readonly parent: MilvusClient; + private readonly clusterId: string; + private closed = false; + + constructor(parent: MilvusClient, clusterId: string) { + this.parent = parent; + this.clusterId = clusterId; + } + + private ensureOpen() { + if (this.closed) { + throw new Error('MilvusClient session is closed'); + } + } + + private withClusterId(params: T): T { + return { ...params, cluster_id: this.clusterId }; + } + + close(): void { + this.closed = true; + } + + search( + params: T + ): Promise> { + this.ensureOpen(); + return this.parent.search(this.withClusterId(params)); + } + + hybridSearch( + params: T + ): Promise> { + this.ensureOpen(); + return this.parent.hybridSearch(this.withClusterId(params)); + } + + searchIterator(param: SearchIteratorReq): Promise { + this.ensureOpen(); + return this.parent.searchIterator(this.withClusterId(param)); + } + + query(data: QueryReq): Promise { + this.ensureOpen(); + return this.parent.query(this.withClusterId(data)); + } + + queryIterator(data: QueryIteratorReq): Promise { + this.ensureOpen(); + return this.parent.queryIterator(this.withClusterId(data)); + } + + get(data: GetReq): Promise { + this.ensureOpen(); + return this.parent.get(this.withClusterId(data)); + } +} diff --git a/milvus/const/defaults.ts b/milvus/const/defaults.ts index 7978bd05..5b178be2 100644 --- a/milvus/const/defaults.ts +++ b/milvus/const/defaults.ts @@ -12,6 +12,7 @@ export const DEFAULT_RESOURCE_GROUP = '__default_resource_group'; // default res export const DEFAULT_DB = 'default'; // default database name export const DEFAULT_DYNAMIC_FIELD = '$meta'; // default dynamic field name export const DEFAULT_COUNT_QUERY_STRING = 'count(*)'; // default count query string +export const CLUSTER_ID = 'cluster_id'; // cluster id routing parameter for DQL requests export const DEFAULT_HTTP_TIMEOUT = 60000; // default http timeout, 60s export const DEFAULT_HTTP_ENDPOINT_VERSION = 'v2'; // api version, default v1 diff --git a/milvus/grpc/Data.ts b/milvus/grpc/Data.ts index d11e1493..2237af44 100644 --- a/milvus/grpc/Data.ts +++ b/milvus/grpc/Data.ts @@ -81,6 +81,7 @@ import { FloatVector, FieldPartialUpdateOpType, FieldPartialUpdateOp, + CLUSTER_ID, } from '../'; import { Collection } from './Collection'; @@ -842,11 +843,14 @@ export class Data extends Collection { const count = await client.count({ collection_name: param.collection_name, expr: param.expr || param.filter || '', + db_name: param.db_name, + cluster_id: param.cluster_id, }); // get collection Info const collectionInfo = await this.describeCollection({ collection_name: param.collection_name, + db_name: param.db_name, }); // if limit not set, set it to count @@ -874,6 +878,7 @@ export class Data extends Collection { // search iterator special params const params: any = { ...param.params, + ...(param.cluster_id ? { [CLUSTER_ID]: param.cluster_id } : {}), [ITERATOR_FIELD]: true, [ITER_SEARCH_V2_KEY]: true, [ITER_SEARCH_BATCH_SIZE_KEY]: batchSize, @@ -956,6 +961,8 @@ export class Data extends Collection { const count = await client.count({ collection_name: data.collection_name, expr: userExpr, + db_name: data.db_name, + cluster_id: data.cluster_id, }); // remove filter field to avoid conflict with expr in query method const queryData = { ...data }; @@ -1186,6 +1193,10 @@ export class Data extends Collection { } // Execute the query and get the results + if (data.cluster_id) { + queryParams[CLUSTER_ID] = data.cluster_id; + } + const promise: QueryRes = await promisify( this.channelPool, 'Query', @@ -1264,13 +1275,16 @@ export class Data extends Collection { async count(data: CountReq): Promise { const req: any = { collection_name: data.collection_name, - expr: data.expr || '', + expr: data.filter || data.expr || '', output_fields: [DEFAULT_COUNT_QUERY_STRING], }; if (data.db_name) { req.db_name = data.db_name; } + if (data.cluster_id) { + req.cluster_id = data.cluster_id; + } const queryResult = await this.query(req); return { diff --git a/milvus/proto-json/milvus.base.ts b/milvus/proto-json/milvus.base.ts index 4de14459..725e6a55 100644 --- a/milvus/proto-json/milvus.base.ts +++ b/milvus/proto-json/milvus.base.ts @@ -6513,6 +6513,10 @@ export default { "endTime": { "type": "int64", "id": 8 + }, + "externalSpec": { + "type": "string", + "id": 9 } } }, diff --git a/milvus/proto-json/milvus.ts b/milvus/proto-json/milvus.ts index ef968ff3..6014ae36 100644 --- a/milvus/proto-json/milvus.ts +++ b/milvus/proto-json/milvus.ts @@ -6513,6 +6513,10 @@ export default { "end_time": { "type": "int64", "id": 8 + }, + "external_spec": { + "type": "string", + "id": 9 } } }, diff --git a/milvus/types/Common.ts b/milvus/types/Common.ts index fb29f9b4..efa45a9b 100644 --- a/milvus/types/Common.ts +++ b/milvus/types/Common.ts @@ -34,6 +34,7 @@ export interface GrpcTimeOut { timeout?: number; client_request_id?: string; // optional, trace id for request tracking 'client-request-id'?: string; // optional, trace id for request tracking (alternative format) + cluster_id?: string; // optional, route DQL request to a specific cluster } export type PrivilegesTypes = | CollectionPrivileges diff --git a/milvus/types/Data.ts b/milvus/types/Data.ts index c4a35078..ce863f0e 100644 --- a/milvus/types/Data.ts +++ b/milvus/types/Data.ts @@ -13,6 +13,7 @@ import { export interface CountReq extends collectionNameReq { expr?: string; // filter expression + filter?: string; // alias for expr } interface BaseDeleteReq extends collectionNameReq { @@ -114,6 +115,8 @@ export interface GetReq extends collectionNameReq { offset?: number; // skip how many results limit?: number; // how many results you want consistency_level?: ConsistencyLevelEnum; // consistency level + transformers?: OutputTransformers; // provide custom data transformer for specific data type like bf16 or f16 vectors + exprValues?: keyValueObj; // template values for filter expression, eg: {key: 'value'} } export interface QueryRes extends resStatusResponse { diff --git a/milvus/types/Search.ts b/milvus/types/Search.ts index d4b1906d..a30df5d0 100644 --- a/milvus/types/Search.ts +++ b/milvus/types/Search.ts @@ -14,6 +14,7 @@ import { Int8Vector, FieldData, OrderByFields, + GrpcTimeOut, } from '../'; // Highlighter types @@ -132,13 +133,14 @@ export interface SearchSimpleReq extends collectionNameReq { export type HybridSearchSingleReq = Pick< SearchParam, 'anns_field' | 'ignore_growing' | 'group_by_field' -> & { - data: SearchData; // vector to search - expr?: string; // filter expression - exprValues?: keyValueObj; // template values for filter expression, eg: {key: 'value'} - params?: keyValueObj; // extra search parameters - transformers?: OutputTransformers; // provide custom data transformer for specific data type like bf16 or f16 vectors -}; +> & + GrpcTimeOut & { + data: SearchData; // vector to search + expr?: string; // filter expression + exprValues?: keyValueObj; // template values for filter expression, eg: {key: 'value'} + params?: keyValueObj; // extra search parameters + transformers?: OutputTransformers; // provide custom data transformer for specific data type like bf16 or f16 vectors + }; export interface SearchIteratorReq extends Omit< SearchSimpleReq, diff --git a/milvus/utils/Search.ts b/milvus/utils/Search.ts index 91fdad14..cf3143c3 100644 --- a/milvus/utils/Search.ts +++ b/milvus/utils/Search.ts @@ -22,6 +22,7 @@ import { keyValueObj, FunctionObject, FunctionScore, + CLUSTER_ID, buildFieldDataMap, cloneObj, parseToKeyValue, @@ -418,6 +419,13 @@ export const buildSearchRequest = ( ); } + if (!isHybridSearch && params.cluster_id) { + request.search_params = [ + ...(request.search_params as KeyValuePair[]), + { key: CLUSTER_ID, value: params.cluster_id }, + ]; + } + // if exprValues is set, add it to the request(inner) if (userRequest.exprValues) { request.expr_template_values = formatExprValues(userRequest.exprValues); @@ -450,6 +458,9 @@ export const buildSearchRequest = ( 'type' in rerank ); const hasFunctionScore = isFunctionScore(rerank); + const clusterIdParam = params.cluster_id + ? [{ key: CLUSTER_ID, value: params.cluster_id }] + : []; // build highlighter if provided const highlighter = @@ -489,6 +500,7 @@ export const buildSearchRequest = ( key: 'offset', value: searchSimpleReq.offset ?? 0, }, + ...clusterIdParam, ], }, diff --git a/proto b/proto index 966f0620..445ffe01 160000 --- a/proto +++ b/proto @@ -1 +1 @@ -Subproject commit 966f06209e2b02700aacd50a337623ac69e6559b +Subproject commit 445ffe015a5d2e9e27798bd33ba94815d2fb99fa diff --git a/test/grpc/MilvusClient.spec.ts b/test/grpc/MilvusClient.spec.ts index 2ab806dc..d98ac144 100644 --- a/test/grpc/MilvusClient.spec.ts +++ b/test/grpc/MilvusClient.spec.ts @@ -2,6 +2,7 @@ import path from 'path'; import { readFileSync } from 'fs'; import { MilvusClient, + MilvusClientSession, ERROR_REASONS, CONNECT_STATUS, TLS_MODE, @@ -203,9 +204,9 @@ describe(`Milvus client`, () => { expect(client.channelOptions['grpc.keepalive_time_ms']).toEqual(10000); expect(client.channelOptions['grpc.keepalive_timeout_ms']).toEqual(5000); - expect(client.channelOptions['grpc.keepalive_permit_without_calls']).toEqual( - 1 - ); + expect( + client.channelOptions['grpc.keepalive_permit_without_calls'] + ).toEqual(1); }); it(`should allow overriding keepalive options via channelOptions`, async () => { @@ -220,9 +221,9 @@ describe(`Milvus client`, () => { expect(client.channelOptions['grpc.keepalive_time_ms']).toEqual(20000); expect(client.channelOptions['grpc.keepalive_timeout_ms']).toEqual(8000); - expect(client.channelOptions['grpc.keepalive_permit_without_calls']).toEqual( - 1 - ); + expect( + client.channelOptions['grpc.keepalive_permit_without_calls'] + ).toEqual(1); }); it(`should add trace interceptor if enableTrace is true`, async () => { @@ -248,9 +249,7 @@ describe(`Milvus client`, () => { __SKIP_CONNECT__: true, }); // Should have request metadata interceptor (adds client-request-unixmsec) - expect(client.channelOptions.interceptors.length).toBeGreaterThanOrEqual( - 3 - ); + expect(client.channelOptions.interceptors.length).toBeGreaterThanOrEqual(3); }); it(`Expect get node sdk info`, async () => { @@ -258,6 +257,92 @@ describe(`Milvus client`, () => { expect(MilvusClient.sdkInfo.recommendMilvus).toEqual(sdkInfo.milvusVersion); }); + it('should create a session that injects cluster_id into DQL methods', async () => { + const client = new MilvusClient({ + address: IP, + __SKIP_CONNECT__: true, + }); + const session = client.session('in07-xxx'); + const expected = { status: { error_code: 'Success', reason: '' } }; + + const searchSpy = jest + .spyOn(client, 'search') + .mockResolvedValue(expected as any); + const hybridSearchSpy = jest + .spyOn(client, 'hybridSearch') + .mockResolvedValue(expected as any); + const querySpy = jest + .spyOn(client, 'query') + .mockResolvedValue(expected as any); + const getSpy = jest.spyOn(client, 'get').mockResolvedValue(expected as any); + const searchIteratorSpy = jest + .spyOn(client, 'searchIterator') + .mockResolvedValue(expected as any); + const queryIteratorSpy = jest + .spyOn(client, 'queryIterator') + .mockResolvedValue(expected as any); + + await session.search({ + collection_name: 'col', + data: [0.1, 0.2], + cluster_id: 'ignored', + }); + await session.hybridSearch({ + collection_name: 'col', + data: [{ data: [0.1, 0.2], anns_field: 'vector' }], + }); + await session.query({ collection_name: 'col', filter: 'id > 0' }); + await session.get({ collection_name: 'col', ids: [1, 2] }); + await session.searchIterator({ + collection_name: 'col', + data: [0.1, 0.2], + batchSize: 10, + }); + await session.queryIterator({ + collection_name: 'col', + filter: 'id > 0', + batchSize: 10, + }); + + expect(searchSpy).toHaveBeenCalledWith( + expect.objectContaining({ cluster_id: 'in07-xxx' }) + ); + expect(hybridSearchSpy).toHaveBeenCalledWith( + expect.objectContaining({ cluster_id: 'in07-xxx' }) + ); + expect(querySpy).toHaveBeenCalledWith( + expect.objectContaining({ cluster_id: 'in07-xxx' }) + ); + expect(getSpy).toHaveBeenCalledWith( + expect.objectContaining({ cluster_id: 'in07-xxx' }) + ); + expect(searchIteratorSpy).toHaveBeenCalledWith( + expect.objectContaining({ cluster_id: 'in07-xxx' }) + ); + expect(queryIteratorSpy).toHaveBeenCalledWith( + expect.objectContaining({ cluster_id: 'in07-xxx' }) + ); + + session.close(); + expect(() => + session.query({ collection_name: 'col', filter: 'id > 0' }) + ).toThrow('session is closed'); + }); + + it('should reject an invalid cluster id when creating a session', () => { + const client = new MilvusClient({ + address: IP, + __SKIP_CONNECT__: true, + }); + + expect(() => client.session('')).toThrow( + 'clusterId must be a non-empty string' + ); + expect(() => client.session(undefined as any)).toThrow( + 'clusterId must be a non-empty string' + ); + }); + it(`Get milvus version`, async () => { const res = await milvusClient.getVersion(); expect(res).toHaveProperty('version'); diff --git a/test/utils/Data.spec.ts b/test/utils/Data.spec.ts index 6c3b4696..0498234b 100644 --- a/test/utils/Data.spec.ts +++ b/test/utils/Data.spec.ts @@ -15,6 +15,7 @@ import { processVectorData, findKeyValue, FieldPartialUpdateOpType, + CLUSTER_ID, } from '../../milvus'; describe('utils/Data', () => { @@ -43,6 +44,7 @@ describe('utils/Data', () => { limit: 10, offset: 2, order_by: ['price:asc', { field: 'rating', order: 'desc' }], + cluster_id: 'in07-xxx', }); expect(findKeyValue(queryParams.query_params, 'limit')).toBe(10); @@ -50,6 +52,141 @@ describe('utils/Data', () => { expect(findKeyValue(queryParams.query_params, 'order_by_fields')).toBe( 'price:asc,rating:desc' ); + expect(findKeyValue(queryParams.query_params, CLUSTER_ID)).toBe('in07-xxx'); + }); + + it('should forward cluster_id through count query requests', async () => { + const client = new MilvusClient({ + address: 'localhost:19530', + __SKIP_CONNECT__: true, + }); + const querySpy = jest.spyOn(client, 'query').mockResolvedValue({ + status: { error_code: ErrorCode.SUCCESS, reason: '' }, + data: [{ 'count(*)': '7' }], + } as any); + + const result = await client.count({ + collection_name: 'test_collection', + filter: 'id > 0', + db_name: 'db1', + cluster_id: 'in07-xxx', + }); + + expect(result.data).toBe(7); + expect(querySpy).toHaveBeenCalledWith( + expect.objectContaining({ + collection_name: 'test_collection', + expr: 'id > 0', + db_name: 'db1', + cluster_id: 'in07-xxx', + output_fields: ['count(*)'], + }) + ); + }); + + it('should forward cluster_id through search iterator requests', async () => { + const client = new MilvusClient({ + address: 'localhost:19530', + __SKIP_CONNECT__: true, + }); + const countSpy = jest.spyOn(client, 'count').mockResolvedValue({ + status: { error_code: ErrorCode.SUCCESS, reason: '' }, + data: 1, + } as any); + const describeSpy = jest + .spyOn(client, 'describeCollection') + .mockResolvedValue({ + status: { error_code: ErrorCode.SUCCESS, reason: '' }, + collectionID: '100', + } as any); + const searchSpy = jest.spyOn(client, 'search').mockResolvedValue({ + status: { error_code: ErrorCode.SUCCESS, reason: '' }, + results: [{ id: '1' }], + search_iterator_v2_results: { token: 'token', last_bound: 'bound' }, + session_ts: 123, + } as any); + + const iterator = await client.searchIterator({ + collection_name: 'test_collection', + expr: 'id > 0', + filter: 'ignored', + data: [0.1, 0.2], + batchSize: 10, + db_name: 'db1', + cluster_id: 'in07-xxx', + }); + const page = await iterator[Symbol.asyncIterator]().next(); + + expect(page.value).toEqual([{ id: '1' }]); + expect(countSpy).toHaveBeenCalledWith( + expect.objectContaining({ + collection_name: 'test_collection', + expr: 'id > 0', + db_name: 'db1', + cluster_id: 'in07-xxx', + }) + ); + expect(describeSpy).toHaveBeenCalledWith( + expect.objectContaining({ + collection_name: 'test_collection', + db_name: 'db1', + }) + ); + expect(searchSpy).toHaveBeenCalledWith( + expect.objectContaining({ + collection_name: 'test_collection', + cluster_id: 'in07-xxx', + params: expect.objectContaining({ + cluster_id: 'in07-xxx', + collection_id: '100', + }), + }) + ); + }); + + it('should forward cluster_id through query iterator count and query requests', async () => { + const client = new MilvusClient({ + address: 'localhost:19530', + __SKIP_CONNECT__: true, + }); + const countSpy = jest.spyOn(client, 'count').mockResolvedValue({ + status: { error_code: ErrorCode.SUCCESS, reason: '' }, + data: 1, + } as any); + const querySpy = jest.spyOn(client, 'query').mockResolvedValue({ + status: { error_code: ErrorCode.SUCCESS, reason: '' }, + data: [{ id: '1' }], + } as any); + jest.spyOn(client, 'getPkField').mockResolvedValue({ + name: 'id', + data_type: 'VarChar', + } as any); + + const iterator = await client.queryIterator({ + collection_name: 'test_collection', + filter: 'score > 0', + batchSize: 10, + db_name: 'db1', + cluster_id: 'in07-xxx', + }); + const page = await iterator[Symbol.asyncIterator]().next(); + + expect(page.value).toEqual([{ id: '1' }]); + expect(countSpy).toHaveBeenCalledWith( + expect.objectContaining({ + collection_name: 'test_collection', + expr: 'score > 0', + db_name: 'db1', + cluster_id: 'in07-xxx', + }) + ); + expect(querySpy).toHaveBeenCalledWith( + expect.objectContaining({ + collection_name: 'test_collection', + db_name: 'db1', + cluster_id: 'in07-xxx', + }) + ); }); const captureUpsertParams = async ( diff --git a/test/utils/Search.spec.ts b/test/utils/Search.spec.ts index 5a4335ce..1aac0e9a 100644 --- a/test/utils/Search.spec.ts +++ b/test/utils/Search.spec.ts @@ -14,6 +14,9 @@ import { buildSearchParams, PlaceholderType, buildPlaceholderGroupBytes, + CLUSTER_ID, + findKeyValue, + RRFRanker, } from '../../milvus'; describe('utils/Search', () => { it('should build single search request correctly', () => { @@ -122,6 +125,104 @@ describe('utils/Search', () => { expect(searchParamsKeyValuePairObject.ignore_growing).toEqual(false); }); + it('should add cluster_id to search params for search request', () => { + const milvusProtoPath = path.resolve( + __dirname, + '../../proto/proto/milvus.proto' + ); + const milvusProto = protobuf.loadSync(milvusProtoPath); + const describeCollectionResponse = { + status: { error_code: 'Success', reason: '' }, + collection_name: 'test', + collectionID: 0, + consistency_level: 'Session', + schema: { + fields: [ + { name: 'id', dataType: DataType.Int64, is_primary_key: true }, + { + name: 'vector', + dataType: DataType.FloatVector, + data_type: 'FloatVector', + type_params: [{ key: 'dim', value: '3' }], + index_params: [], + }, + ], + }, + anns_fields: { + vector: { + dataType: DataType.FloatVector, + data_type: 'FloatVector', + type_params: [{ key: 'dim', value: '3' }], + index_params: [], + }, + }, + } as any; + + const result = buildSearchRequest( + { + collection_name: 'test', + data: [1, 2, 3], + cluster_id: 'in07-xxx', + }, + describeCollectionResponse, + milvusProto + ); + + expect( + findKeyValue((result.request as any).search_params, CLUSTER_ID) + ).toBe('in07-xxx'); + }); + + it('should add cluster_id to rank params for hybrid search request', () => { + const milvusProtoPath = path.resolve( + __dirname, + '../../proto/proto/milvus.proto' + ); + const milvusProto = protobuf.loadSync(milvusProtoPath); + const describeCollectionResponse = { + status: { error_code: 'Success', reason: '' }, + collection_name: 'test', + collectionID: 0, + consistency_level: 'Session', + schema: { + fields: [ + { name: 'id', dataType: DataType.Int64, is_primary_key: true }, + { + name: 'vector', + dataType: DataType.FloatVector, + data_type: 'FloatVector', + type_params: [{ key: 'dim', value: '3' }], + index_params: [], + }, + ], + }, + anns_fields: { + vector: { + dataType: DataType.FloatVector, + data_type: 'FloatVector', + type_params: [{ key: 'dim', value: '3' }], + index_params: [], + }, + }, + } as any; + + const result = buildSearchRequest( + { + collection_name: 'test', + data: [{ data: [1, 2, 3], anns_field: 'vector' }], + rerank: RRFRanker(), + limit: 10, + cluster_id: 'in07-xxx', + }, + describeCollectionResponse, + milvusProto + ); + + expect(findKeyValue((result.request as any).rank_params, CLUSTER_ID)).toBe( + 'in07-xxx' + ); + }); + it('should build search request with rerank function correctly', () => { // path const milvusProtoPath = path.resolve(