diff --git a/sdk-core/src/main/java/io/milvus/response/SearchResultsWrapper.java b/sdk-core/src/main/java/io/milvus/response/SearchResultsWrapper.java index 399d52087..29081b1eb 100644 --- a/sdk-core/src/main/java/io/milvus/response/SearchResultsWrapper.java +++ b/sdk-core/src/main/java/io/milvus/response/SearchResultsWrapper.java @@ -253,7 +253,7 @@ public long getNumQueries() { return results.getNumQueries(); } - private static final class Position { + public static final class Position { private final long offset; private final long k; @@ -271,7 +271,7 @@ public long getK() { } } - private Position getOffsetByIndex(int indexOfTarget) { + public Position getOffsetByIndex(int indexOfTarget) { List kList = results.getTopksList(); // if the server didn't return separate topK, use same topK value "0" diff --git a/sdk-core/src/main/java/io/milvus/v2/service/vector/request/SearchReq.java b/sdk-core/src/main/java/io/milvus/v2/service/vector/request/SearchReq.java index fc36013fa..e919625a1 100644 --- a/sdk-core/src/main/java/io/milvus/v2/service/vector/request/SearchReq.java +++ b/sdk-core/src/main/java/io/milvus/v2/service/vector/request/SearchReq.java @@ -23,6 +23,7 @@ import io.milvus.v2.common.IndexParam; import io.milvus.v2.service.collection.request.CreateCollectionReq; import io.milvus.v2.service.vector.request.data.BaseVector; +import io.milvus.v2.service.vector.request.highlighter.Highlighter; import java.util.ArrayList; import java.util.HashMap; @@ -69,6 +70,9 @@ public class SearchReq { // Boolean, Long, Double, String, List, List, List, List private Map filterTemplateValues; + // milvus v2.6.9 supports highlighter for search results + private Highlighter highlighter; + private SearchReq(SearchReqBuilder builder) { this.databaseName = builder.databaseName; this.collectionName = builder.collectionName; @@ -95,6 +99,7 @@ private SearchReq(SearchReqBuilder builder) { this.functionScore = builder.functionScore; this.filterTemplateValues = builder.filterTemplateValues; this.timezone = builder.timezone; + this.highlighter = builder.highlighter; } // Getters and Setters @@ -294,6 +299,10 @@ public void setFilterTemplateValues(Map filterTemplateValues) { this.filterTemplateValues = filterTemplateValues; } + public Highlighter getHighlighter() { + return highlighter; + } + @Override public String toString() { return "SearchReq{" + @@ -319,6 +328,7 @@ public String toString() { ", groupSize=" + groupSize + ", strictGroupSize=" + strictGroupSize + ", ranker=" + ranker + + ", highlighter=" + (highlighter == null ? "null" : (highlighter.highlightType() + ":" + highlighter.getParams())) + ", functionScore=" + functionScore + // ", filterTemplateValues=" + filterTemplateValues + '}'; @@ -354,6 +364,7 @@ public static class SearchReqBuilder { private CreateCollectionReq.Function ranker; private FunctionScore functionScore; private Map filterTemplateValues = new HashMap<>(); // default value + private Highlighter highlighter; private SearchReqBuilder() { } @@ -487,6 +498,11 @@ public SearchReqBuilder filterTemplateValues(Map filterTemplateV return this; } + public SearchReqBuilder highlighter(Highlighter highlighter) { + this.highlighter = highlighter; + return this; + } + public SearchReq build() { return new SearchReq(this); } diff --git a/sdk-core/src/main/java/io/milvus/v2/service/vector/request/highlighter/Highlighter.java b/sdk-core/src/main/java/io/milvus/v2/service/vector/request/highlighter/Highlighter.java new file mode 100644 index 000000000..5e3869050 --- /dev/null +++ b/sdk-core/src/main/java/io/milvus/v2/service/vector/request/highlighter/Highlighter.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.milvus.v2.service.vector.request.highlighter; + +import java.util.Map; + +public interface Highlighter { + String highlightType(); + + Map getParams(); +} diff --git a/sdk-core/src/main/java/io/milvus/v2/service/vector/request/highlighter/LexicalHighlighter.java b/sdk-core/src/main/java/io/milvus/v2/service/vector/request/highlighter/LexicalHighlighter.java new file mode 100644 index 000000000..78bd2a915 --- /dev/null +++ b/sdk-core/src/main/java/io/milvus/v2/service/vector/request/highlighter/LexicalHighlighter.java @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.milvus.v2.service.vector.request.highlighter; + +import io.milvus.common.utils.JsonUtils; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +public class LexicalHighlighter implements Highlighter { + private final List highlightQueries; + private final Boolean highlightSearchText; + private final List preTags; + private final List postTags; + private final Integer fragmentOffset; + private final Integer fragmentSize; + private final Integer numOfFragments; + + public LexicalHighlighter(LexicalHighlighterBuilder builder) { + this.highlightQueries = builder.highlightQueries; + this.highlightSearchText = builder.highlightSearchText; + this.preTags = builder.preTags; + this.postTags = builder.postTags; + this.fragmentOffset = builder.fragmentOffset; + this.fragmentSize = builder.fragmentSize; + this.numOfFragments = builder.numOfFragments; + } + + @Override + public String highlightType() { + return "Lexical"; + } + + @Override + public Map getParams() { + Map params = new java.util.HashMap<>(); + if (this.highlightQueries != null) { + // serialize the list of HighlightQuery to a JSON array string using Gson + params.put("highlight_queries", JsonUtils.toJson(this.highlightQueries)); + } + if (this.highlightSearchText != null) { + params.put("highlight_search_text", this.highlightSearchText.toString()); + } + if (this.preTags != null) { + params.put("pre_tags", JsonUtils.toJson(this.preTags)); + } + if (this.postTags != null) { + params.put("post_tags", JsonUtils.toJson(this.postTags)); + } + if (this.fragmentOffset != null) { + params.put("fragment_offset", this.fragmentOffset.toString()); + } + if (this.fragmentSize != null) { + params.put("fragment_size", this.fragmentSize.toString()); + } + if (this.numOfFragments != null) { + params.put("num_of_fragments", this.numOfFragments.toString()); + } + return params; + } + + public static class HighlightQuery { + public String type; + public String field; + public String text; + + public HighlightQuery(String type, String field, String query) { + this.type = type; + this.field = field; + this.text = query; + } + + @Override + public String toString() { + return JsonUtils.toJson(this); + } + } + + public static class LexicalHighlighterBuilder { + private List highlightQueries; + private Boolean highlightSearchText; + private List preTags; + private List postTags; + private Integer fragmentOffset; + private Integer fragmentSize; + private Integer numOfFragments; + + public LexicalHighlighterBuilder() { + } + + public LexicalHighlighterBuilder highlightQueries(List queries) { + this.highlightQueries = queries; + return this; + } + + public LexicalHighlighterBuilder addHighlightQuery(HighlightQuery q) { + if (this.highlightQueries == null) this.highlightQueries = new ArrayList<>(); + this.highlightQueries.add(q); + return this; + } + + public LexicalHighlighterBuilder highlightSearchText(Boolean highlightSearchText) { + this.highlightSearchText = highlightSearchText; + return this; + } + + public LexicalHighlighterBuilder preTags(List preTags) { + this.preTags = preTags; + return this; + } + + public LexicalHighlighterBuilder addPreTag(String tag) { + if (this.preTags == null) this.preTags = new ArrayList<>(); + this.preTags.add(tag); + return this; + } + + public LexicalHighlighterBuilder postTags(List postTags) { + this.postTags = postTags; + return this; + } + + public LexicalHighlighterBuilder addPostTag(String tag) { + if (this.postTags == null) this.postTags = new ArrayList<>(); + this.postTags.add(tag); + return this; + } + + public LexicalHighlighterBuilder fragmentOffset(Integer offset) { + this.fragmentOffset = offset; + return this; + } + + public LexicalHighlighterBuilder fragmentSize(Integer size) { + this.fragmentSize = size; + return this; + } + + public LexicalHighlighterBuilder numOfFragments(Integer num) { + this.numOfFragments = num; + return this; + } + + public LexicalHighlighter build() { + return new LexicalHighlighter(this); + } + } + + public static LexicalHighlighterBuilder builder() { + return new LexicalHighlighterBuilder(); + } +} diff --git a/sdk-core/src/main/java/io/milvus/v2/service/vector/request/highlighter/SemanticHighlighter.java b/sdk-core/src/main/java/io/milvus/v2/service/vector/request/highlighter/SemanticHighlighter.java new file mode 100644 index 000000000..b50f1da12 --- /dev/null +++ b/sdk-core/src/main/java/io/milvus/v2/service/vector/request/highlighter/SemanticHighlighter.java @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.milvus.v2.service.vector.request.highlighter; + +import io.milvus.common.utils.JsonUtils; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +public class SemanticHighlighter implements Highlighter { + private final List queries; + private final List inputFields; + private final List preTags; + private final List postTags; + private final Float threshold; + private final Boolean highlightOnly; + private final String modelDeploymentID; + private final Integer maxClientBatchSize; + + public SemanticHighlighter(SemanticHighlighterBuilder builder) { + this.queries = builder.queries; + this.inputFields = builder.inputFields; + this.preTags = builder.preTags; + this.postTags = builder.postTags; + this.threshold = builder.threshold; + this.highlightOnly = builder.highlightOnly; + this.modelDeploymentID = builder.modelDeploymentID; + this.maxClientBatchSize = builder.maxClientBatchSize; + } + + @Override + public String highlightType() { + return "Semantic"; + } + + @Override + public Map getParams() { + Map params = new java.util.HashMap<>(); + if (this.queries != null) { + params.put("queries", JsonUtils.toJson(this.queries)); + } + if (this.inputFields != null) { + params.put("input_fields", JsonUtils.toJson(this.inputFields)); + } + if (this.preTags != null) { + params.put("pre_tags", JsonUtils.toJson(this.preTags)); + } + if (this.postTags != null) { + params.put("post_tags", JsonUtils.toJson(this.postTags)); + } + if (this.threshold != null) { + params.put("threshold", this.threshold.toString()); + } + if (this.highlightOnly != null) { + params.put("highlight_only", this.highlightOnly.toString()); + } + if (this.modelDeploymentID != null) { + params.put("model_deployment_id", this.modelDeploymentID); + } + if (this.maxClientBatchSize != null) { + params.put("max_client_batch_size", this.maxClientBatchSize.toString()); + } + return params; + } + + public static class SemanticHighlighterBuilder { + private List queries; + private List inputFields; + private List preTags; + private List postTags; + private Float threshold; + private Boolean highlightOnly; + private String modelDeploymentID; + private Integer maxClientBatchSize; + + public SemanticHighlighterBuilder() { + } + + public SemanticHighlighterBuilder queries(List queries) { + this.queries = queries; + return this; + } + + public SemanticHighlighterBuilder addQuery(String q) { + if (this.queries == null) this.queries = new ArrayList<>(); + this.queries.add(q); + return this; + } + + public SemanticHighlighterBuilder inputFields(List inputFields) { + this.inputFields = inputFields; + return this; + } + + public SemanticHighlighterBuilder addInputField(String f) { + if (this.inputFields == null) this.inputFields = new ArrayList<>(); + this.inputFields.add(f); + return this; + } + + public SemanticHighlighterBuilder preTags(List preTags) { + this.preTags = preTags; + return this; + } + + public SemanticHighlighterBuilder addPreTag(String tag) { + if (this.preTags == null) this.preTags = new ArrayList<>(); + this.preTags.add(tag); + return this; + } + + public SemanticHighlighterBuilder postTags(List postTags) { + this.postTags = postTags; + return this; + } + + public SemanticHighlighterBuilder addPostTag(String tag) { + if (this.postTags == null) this.postTags = new ArrayList<>(); + this.postTags.add(tag); + return this; + } + + public SemanticHighlighterBuilder threshold(Float threshold) { + this.threshold = threshold; + return this; + } + + public SemanticHighlighterBuilder highlightOnly(Boolean highlightOnly) { + this.highlightOnly = highlightOnly; + return this; + } + + public SemanticHighlighterBuilder modelDeploymentID(String modelDeploymentID) { + this.modelDeploymentID = modelDeploymentID; + return this; + } + + public SemanticHighlighterBuilder maxClientBatchSize(Integer size) { + this.maxClientBatchSize = size; + return this; + } + + public SemanticHighlighter build() { + return new SemanticHighlighter(this); + } + } + + public static SemanticHighlighterBuilder builder() { + return new SemanticHighlighterBuilder(); + } +} diff --git a/sdk-core/src/main/java/io/milvus/v2/service/vector/response/SearchResp.java b/sdk-core/src/main/java/io/milvus/v2/service/vector/response/SearchResp.java index d8bf776bd..d6fb1762a 100644 --- a/sdk-core/src/main/java/io/milvus/v2/service/vector/response/SearchResp.java +++ b/sdk-core/src/main/java/io/milvus/v2/service/vector/response/SearchResp.java @@ -19,10 +19,10 @@ package io.milvus.v2.service.vector.response; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import org.apache.commons.collections4.CollectionUtils; +import org.apache.commons.collections4.MapUtils; + +import java.util.*; public class SearchResp { private List> searchResults; @@ -102,12 +102,14 @@ public static class SearchResult { private Float score; private Object id; private String primaryKey; + private Map highlightResults; private SearchResult(SearchResultBuilder builder) { this.entity = builder.entity; this.score = builder.score; this.id = builder.id; this.primaryKey = builder.primaryKey; + this.highlightResults = builder.highlightResults == null ? new HashMap<>() : builder.highlightResults; } public static SearchResultBuilder builder() { @@ -146,9 +148,23 @@ public void setPrimaryKey(String primaryKey) { this.primaryKey = primaryKey; } + public Map getHighlightResults() { + return highlightResults; + } + + public HighlightResult getHighlightResult(String fieldName) { + return this.highlightResults.get(fieldName); + } + + public void addHighlightResult(String fieldName, HighlightResult highlightResult) { + if (this.highlightResults == null) this.highlightResults = new HashMap<>(); + this.highlightResults.put(fieldName, highlightResult); + } + @Override public String toString() { - return "{" + getPrimaryKey() + ": " + getId() + ", Score: " + getScore() + ", OutputFields: " + entity + "}"; + return "{" + getPrimaryKey() + ": " + getId() + ", Score: " + getScore() + ", OutputFields: " + entity + + (MapUtils.isEmpty(highlightResults) ? "" : (", HighlightResults: " + highlightResults)) + "}"; } public static class SearchResultBuilder { @@ -156,6 +172,7 @@ public static class SearchResultBuilder { private Float score; private Object id; private String primaryKey = "id"; + private Map highlightResults = new HashMap<>(); public SearchResultBuilder entity(Map entity) { this.entity = entity; @@ -177,9 +194,96 @@ public SearchResultBuilder primaryKey(String primaryKey) { return this; } + public SearchResultBuilder highlightResults(Map highlightResults) { + this.highlightResults = highlightResults; + return this; + } + + public SearchResultBuilder addHighlightResult(String fieldName, HighlightResult highlightResult) { + if (this.highlightResults == null) this.highlightResults = new HashMap<>(); + this.highlightResults.put(fieldName, highlightResult); + return this; + } + public SearchResult build() { return new SearchResult(this); } } } + + public static class HighlightResult { + private final String fieldName; + private final List fragments; + private final List scores; + + private HighlightResult(HighlightResultBuilder builder) { + this.fieldName = builder.fieldName; + this.fragments = builder.fragments; + this.scores = builder.scores; + } + + public static HighlightResultBuilder builder() { + return new HighlightResultBuilder(); + } + + public String getFieldName() { + return fieldName; + } + + public List getFragments() { + return fragments; + } + + public List getScores() { + return scores; + } + + @Override + public String toString() { + return "HighlightResult{" + + "fieldName='" + fieldName + '\'' + + ", fragments=" + fragments + + ", scores=" + scores + + '}'; + } + + public static class HighlightResultBuilder { + private String fieldName = ""; + private List fragments = new ArrayList<>(); + private List scores = new ArrayList<>(); + + public HighlightResultBuilder fieldName(String fieldName) { + this.fieldName = fieldName; + return this; + } + + public HighlightResultBuilder fragments(List fragments) { + this.fragments = fragments; + return this; + } + + public HighlightResultBuilder addFragment(String fragment) { + if (this.fragments == null) this.fragments = new ArrayList<>(); + this.fragments.add(fragment); + return this; + } + + public HighlightResultBuilder scores(List scores) { + this.scores = scores; + return this; + } + + public HighlightResultBuilder addScore(Float score) { + if (this.scores == null) this.scores = new ArrayList<>(); + this.scores.add(score); + return this; + } + + public HighlightResult build() { + return new HighlightResult(this); + } + } + + } + } diff --git a/sdk-core/src/main/java/io/milvus/v2/utils/ConvertUtils.java b/sdk-core/src/main/java/io/milvus/v2/utils/ConvertUtils.java index 2e63ed629..b98ef5bb6 100644 --- a/sdk-core/src/main/java/io/milvus/v2/utils/ConvertUtils.java +++ b/sdk-core/src/main/java/io/milvus/v2/utils/ConvertUtils.java @@ -96,12 +96,38 @@ public List> getEntities(SearchResults response) { long numQueries = response.getResults().getNumQueries(); List> searchResults = new ArrayList<>(); for (int i = 0; i < numQueries; i++) { - searchResults.add(searchResultsWrapper.getIDScore(i).stream().map(idScore -> SearchResp.SearchResult.builder() - .entity(idScore.getFieldValues()) - .score(idScore.getScore()) - .primaryKey(idScore.getPrimaryKey()) - .id(idScore.getStrID().isEmpty() ? idScore.getLongID() : idScore.getStrID()) - .build()).collect(Collectors.toList())); + List singleResults = new ArrayList<>(); + for (SearchResultsWrapper.IDScore idScore : searchResultsWrapper.getIDScore(i)) { + singleResults.add(SearchResp.SearchResult.builder() + .entity(idScore.getFieldValues()) + .score(idScore.getScore()) + .primaryKey(idScore.getPrimaryKey()) + .id(idScore.getStrID().isEmpty() ? idScore.getLongID() : idScore.getStrID()) + .build()); + } + + // set highlight + SearchResultsWrapper.Position position = searchResultsWrapper.getOffsetByIndex(i); + long offset = position.getOffset(); + long k = position.getK(); + List highlightResults = response.getResults().getHighlightResultsList(); + for (HighlightResult highlightResult : highlightResults) { + String fieldName = highlightResult.getFieldName(); + List highlightDatas = highlightResult.getDatasList(); + for (long j = 0; j < k; j++) { + HighlightData highlightData = highlightDatas.get((int) (offset + j)); + List fragments = highlightData.getFragmentsList(); + List scores = highlightData.getScoresList(); + SearchResp.HighlightResult highlightResultObj = SearchResp.HighlightResult.builder() + .fieldName(fieldName) + .fragments(fragments) + .scores(scores) + .build(); + singleResults.get((int) j).addHighlightResult(fieldName, highlightResultObj); + } + } + + searchResults.add(singleResults); } return searchResults; } diff --git a/sdk-core/src/main/java/io/milvus/v2/utils/VectorUtils.java b/sdk-core/src/main/java/io/milvus/v2/utils/VectorUtils.java index babf79973..af2b41910 100644 --- a/sdk-core/src/main/java/io/milvus/v2/utils/VectorUtils.java +++ b/sdk-core/src/main/java/io/milvus/v2/utils/VectorUtils.java @@ -37,6 +37,7 @@ import io.milvus.v2.service.vector.request.*; import io.milvus.v2.service.vector.request.FunctionScore; import io.milvus.v2.service.vector.request.data.BaseVector; +import io.milvus.v2.service.vector.request.highlighter.Highlighter; import io.milvus.v2.service.vector.request.ranker.RRFRanker; import io.milvus.v2.service.vector.request.ranker.WeightedRanker; import org.apache.commons.collections4.CollectionUtils; @@ -369,6 +370,22 @@ public SearchRequest ConvertToGrpcSearchRequest(SearchReq request) { builder.setFunctionScore(convertOneFunction(ranker)); } + // set highlighter + Highlighter highlighter = request.getHighlighter(); + if (highlighter != null) { + io.milvus.grpc.Highlighter.Builder hlBuilder = io.milvus.grpc.Highlighter.newBuilder() + .setType(HighlightType.valueOf(highlighter.highlightType())); + Map hlParams = highlighter.getParams(); + hlParams.forEach((key, value) -> { + hlBuilder.addParams( + KeyValuePair.newBuilder() + .setKey(key) + .setValue(value) + .build()); + }); + builder.setHighlighter(hlBuilder.build()); + } + return builder.build(); } diff --git a/sdk-core/src/main/milvus-proto b/sdk-core/src/main/milvus-proto index 5b5ad7223..fd9875a85 160000 --- a/sdk-core/src/main/milvus-proto +++ b/sdk-core/src/main/milvus-proto @@ -1 +1 @@ -Subproject commit 5b5ad7223d65baa6eb3a4bd075969027b10cadf8 +Subproject commit fd9875a85e8e5cdb38d1de8856a1e62fdc85f0fa