Commit 643eeb9

Caikit embeddings examples + local run documentation
Signed-off-by: Flavia Beo <flavia.beo@ibm.com>
1 parent c12cb82 commit 643eeb9

5 files changed

Lines changed: 495 additions & 0 deletions

examples/embeddings/README.md

Lines changed: 239 additions & 0 deletions
@@ -0,0 +1,239 @@
# Set up and run the Caikit embeddings server locally

### Setting up a virtual environment using Python venv

If you use [venv](https://docs.python.org/3/library/venv.html), make sure you are in an activated `venv` when running `python` in the example commands that follow. Use `deactivate` to exit the `venv`.

```shell
python3 -m venv venv
source venv/bin/activate
```

### Models

To create a model configuration and artifacts, the best practice is to run the module's `bootstrap()` and `save()` methods. This will:

* Load the model by name (from the Hugging Face Hub or a repository) or from a local directory. The model is loaded using the sentence-transformers library.
* Save a `config.yml` which:
  * Ties the model to the module (with a `module_id` GUID)
  * Sets the `artifacts_path` to the default "artifacts" subdirectory
* Save the model in the artifacts subdirectory

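As a minimal sketch of the layout described above, the following checks that an output directory contains a `config.yml` file and an `artifacts` subdirectory. The helper name `looks_like_saved_model` is hypothetical and is not part of caikit, and the check assumes the default `artifacts` path was kept.

```python
from pathlib import Path


def looks_like_saved_model(model_dir: str, artifacts_subdir: str = "artifacts") -> bool:
    """Return True if model_dir has a config.yml file and an artifacts subdirectory.

    Hypothetical helper for illustration only: it inspects the directory layout
    and does not validate the contents of config.yml.
    """
    root = Path(model_dir)
    return (root / "config.yml").is_file() and (root / artifacts_subdir).is_dir()
```

Calling `looks_like_saved_model("<OUTPUT_DIR>")` after a save is a quick sanity check before moving the directory into the models folder.
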
This can be done by running the `bootstrap_model.py` script in your virtual environment.

```shell
source venv/bin/activate
./demo/server/bootstrap_model.py -m <MODEL_NAME_OR_PATH> -o <OUTPUT_DIR>
```

To avoid overwriting your files, `save()` returns an error if the output directory already exists. You may want to save to a temporary name first and, once the save succeeds, move the output directory to a `<model-id>` directory under your local models directory.

### Starting the Caikit Runtime

Run caikit-runtime configured to use the caikit-nlp library. Set up the following environment variables:

```bash
export RUNTIME_HTTP_ENABLED=true
export RUNTIME_LOCAL_MODELS_DIR=/models
export RUNTIME_LAZY_LOAD_LOCAL_MODELS=true
```

In one terminal, start the runtime server:

```bash
source venv/bin/activate
pip install -r requirements.txt
caikit-runtime
```

### Embedding retrieval example Python client

In another terminal, run the example client code to retrieve embeddings.

```shell
source venv/bin/activate
cd demo/client
MODEL=<model-id> python embeddings.py
```

The client code calls the model and queries for embeddings using 2 example sentences.

You should see output similar to the following:

```ShellSession
$ python embeddings.py
INPUT TEXTS: ['test first sentence', 'another test sentence']
OUTPUT: {
  "results": [
    [
      -0.17895537614822388,
      0.03200146183371544,
      -0.030327674001455307,
      ...
    ],
    [
      -0.17895537614822388,
      0.03200146183371544,
      -0.030327674001455307,
      ...
    ]
  ],
  "producerId": {
    "name": "EmbeddingModule",
    "version": "0.0.1"
  },
  "inputTokenCount": "9"
}
LENGTH: 2 x 384
```

### Sentence similarity example Python client

In another terminal, run the client code to infer sentence similarity.

```shell
source venv/bin/activate
cd demo/client
MODEL=<model-id> python sentence_similarity.py
```

The client code calls the model and queries sentence similarity using one source sentence and two other sentences (hardcoded in `sentence_similarity.py`). The result contains a cosine similarity score for the source sentence compared with each of the other sentences.

You should see output similar to the following:

```ShellSession
$ python sentence_similarity.py
SOURCE SENTENCE: first sentence
SENTENCES: ['test first sentence', 'another test sentence']
OUTPUT: {
  "result": {
    "scores": [
      1.0000001192092896
    ]
  },
  "producerId": {
    "name": "EmbeddingModule",
    "version": "0.0.1"
  },
  "inputTokenCount": "9"
}
```

### Reranker example Python client

In another terminal, run the client code to execute the reranker task using both gRPC and REST.

```shell
source venv/bin/activate
cd demo/client
MODEL=<model-id> python reranker.py
```

You should see output similar to the following:

```ShellSession
$ python reranker.py
======================
TOP N: 3
QUERIES: ['first sentence', 'any sentence']
DOCUMENTS: [{'text': 'first sentence', 'title': 'first title'}, {'_text': 'another sentence', 'more': 'more attributes here'}, {'text': 'a doc with a nested metadata', 'meta': {'foo': 'bar', 'i': 999, 'f': 12.34}}]
======================
RESPONSE from gRPC:
===
QUERY: first sentence
score: 0.9999997019767761 index: 0 text: first sentence
score: 0.7350112199783325 index: 1 text: another sentence
score: 0.10398174077272415 index: 2 text: a doc with a nested metadata
===
QUERY: any sentence
score: 0.6631797552108765 index: 0 text: first sentence
score: 0.6505964398384094 index: 1 text: another sentence
score: 0.11903437972068787 index: 2 text: a doc with a nested metadata
===================
RESPONSE from HTTP:
{
  "results": [
    {
      "query": "first sentence",
      "scores": [
        {
          "document": {
            "text": "first sentence",
            "title": "first title"
          },
          "index": 0,
          "score": 0.9999997019767761,
          "text": "first sentence"
        },
        {
          "document": {
            "_text": "another sentence",
            "more": "more attributes here"
          },
          "index": 1,
          "score": 0.7350112199783325,
          "text": "another sentence"
        },
        {
          "document": {
            "text": "a doc with a nested metadata",
            "meta": {
              "foo": "bar",
              "i": 999,
              "f": 12.34
            }
          },
          "index": 2,
          "score": 0.10398174077272415,
          "text": "a doc with a nested metadata"
        }
      ]
    },
    {
      "query": "any sentence",
      "scores": [
        {
          "document": {
            "text": "first sentence",
            "title": "first title"
          },
          "index": 0,
          "score": 0.6631797552108765,
          "text": "first sentence"
        },
        {
          "document": {
            "_text": "another sentence",
            "more": "more attributes here"
          },
          "index": 1,
          "score": 0.6505964398384094,
          "text": "another sentence"
        },
        {
          "document": {
            "text": "a doc with a nested metadata",
            "meta": {
              "foo": "bar",
              "i": 999,
              "f": 12.34
            }
          },
          "index": 2,
          "score": 0.11903437972068787,
          "text": "a doc with a nested metadata"
        }
      ]
    }
  ],
  "producerId": {
    "name": "EmbeddingModule",
    "version": "0.0.1"
  },
  "inputTokenCount": "9"
}
```
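
Judging from the sample output above, each query's documents are listed by descending score, each entry carries its original document index, and `TOP N` limits how many are returned. A minimal sketch of that ranking step follows, assuming the per-document scores have already been computed. The function name `top_n_by_score` is hypothetical and this is not the library's implementation.

```python
from typing import Dict, List, Sequence, Union


def top_n_by_score(
    scores: Sequence[float], top_n: int
) -> List[Dict[str, Union[int, float]]]:
    """Rank document indices by descending score and keep the first top_n."""
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return [{"index": i, "score": scores[i]} for i in ranked[:top_n]]
```

For example, `top_n_by_score([0.10, 0.99, 0.73], top_n=2)` keeps the documents at index 1 and index 2, in that order.
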

examples/embeddings/embeddings.py

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Standard
from os import path
import os
import sys

# Third Party
import grpc

# Local
from caikit.runtime.service_factory import ServicePackageFactory
import caikit

# Add the runtime/library to the path
sys.path.append(
    path.abspath(path.join(path.dirname(__file__), "../../"))
)

# Load configuration for Caikit runtime
CONFIG_PATH = path.realpath(
    path.join(path.dirname(__file__), "config.yml")
)
caikit.configure(CONFIG_PATH)

# NOTE: The model ID must be the name of a model folder,
# relative to the models directory.
MODEL_ID = os.getenv("MODEL", "mini")

inference_service = ServicePackageFactory().get_service_package(
    ServicePackageFactory.ServiceType.INFERENCE,
)

# Fall back to the defaults when the variables are unset or empty
port = os.getenv("CAIKIT_EMBEDDINGS_PORT") or 8085
host = os.getenv("CAIKIT_EMBEDDINGS_HOST") or "localhost"
channel = grpc.insecure_channel(f"{host}:{port}")
client_stub = inference_service.stub_class(channel)

# Create request object

texts = ["test first sentence", "another test sentence"]
request = inference_service.messages.EmbeddingTasksRequest(texts=texts)

# Fetch predictions from server (infer)
response = client_stub.EmbeddingTasksPredict(
    request, metadata=[("mm-model-id", MODEL_ID)]
)

# Print response
print("INPUT TEXTS: ", texts)
print("RESULTS: [")
for d in response.results.vectors:
    woo = d.WhichOneof("data")  # which one of data_<float_type>s did we get?
    print(getattr(d, woo).values)
print("]")
print("LENGTH: ", len(response.results.vectors), " x ",
      len(getattr(response.results.vectors[0], woo).values))
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
caikit[runtime-grpc,runtime-http]
caikit-nlp
