Skip to content

Commit b8fb8ea

Browse files
authored
Add isTrainable field to plugin; align with new simdex code (#33)
1 parent 639a8e0 commit b8fb8ea

13 files changed

Lines changed: 211 additions & 181 deletions

File tree

src/steamship/client/client.py

Lines changed: 6 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,15 @@
77
from steamship.client.operations.tagger import TagRequest
88
from steamship.client.tasks import Tasks
99
from steamship.data import File
10-
from steamship.data.embeddings import EmbedAndSearchRequest, EmbedAndSearchResponse, EmbeddingIndex
10+
from steamship.data.embeddings import EmbedAndSearchRequest, QueryResults, EmbeddingIndex
11+
from steamship.data.search import Hit
1112
from steamship.data.space import Space
1213

1314
__copyright__ = "Steamship"
1415
__license__ = "MIT"
1516

1617
from steamship.extension.file import TagResponse
18+
from steamship.plugin.outputs.block_and_tag_plugin_output import BlockAndTagPluginOutput
1719
from steamship.plugin.outputs.embedded_items_plugin_output import EmbeddedItemsPluginOutput
1820

1921
_logger = logging.getLogger(__name__)
@@ -128,27 +130,6 @@ def scrape(
128130
space=space
129131
)
130132

131-
def embed(
132-
self,
133-
docs: List[str],
134-
pluginInstance: str,
135-
spaceId: str = None,
136-
spaceHandle: str = None,
137-
space: Space = None
138-
) -> Response[EmbeddedItemsPluginOutput]:
139-
req = EmbedRequest(
140-
docs=docs,
141-
pluginInstance=pluginInstance
142-
)
143-
return self.post(
144-
'embedding/create',
145-
req,
146-
expect=EmbeddedItemsPluginOutput,
147-
spaceId=spaceId,
148-
spaceHandle=spaceHandle,
149-
space=space
150-
)
151-
152133
def embed_and_search(
153134
self,
154135
query: str,
@@ -158,17 +139,17 @@ def embed_and_search(
158139
spaceId: str = None,
159140
spaceHandle: str = None,
160141
space: Space = None
161-
) -> Response[EmbedAndSearchResponse]:
142+
) -> Response[QueryResults]:
162143
req = EmbedAndSearchRequest(
163144
query=query,
164145
docs=docs,
165146
pluginInstance=pluginInstance,
166147
k=k
167148
)
168149
return self.post(
169-
'embedding/search',
150+
'plugin/instance/embeddingSearch',
170151
req,
171-
expect=EmbedAndSearchResponse,
152+
expect=QueryResults,
172153
spaceId=spaceId,
173154
spaceHandle=spaceHandle,
174155
space=space

src/steamship/data/embeddings.py

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22
from dataclasses import dataclass
3-
from typing import List, Dict, Union
3+
from typing import List, Dict, Union, TypeVar, Generic
44

55
from steamship.base import Client, Request, Response, metadata_to_str
66
from steamship.data.search import Hit
@@ -13,18 +13,40 @@ class EmbedAndSearchRequest(Request):
1313
pluginInstance: str
1414
k: int = 1
1515

16+
17+
#TODO: These types are not generics like the Swift QueryResult/QueryResults.
18+
@dataclass
19+
class QueryResult():
20+
value: Hit
21+
score: float
22+
index: int
23+
id: str
24+
25+
@staticmethod
26+
def from_dict(d: any, client: Client = None) -> "QueryResult":
27+
value = Hit.from_dict(d.get("value", {}))
28+
return QueryResult(
29+
value = value,
30+
score = d.get('score'),
31+
index = d.get('index'),
32+
id = d.get('id')
33+
)
34+
1635
@dataclass
17-
class EmbedAndSearchResponse(Request):
18-
hits: List[Hit] = None
36+
class QueryResults(Request):
37+
items: List[QueryResult] = None
1938

2039
@staticmethod
21-
def from_dict(d: any, client: Client = None) -> "EmbedAndSearchResponse":
22-
hits = [Hit.from_dict(h) for h in (d.get("hits", []) or [])]
23-
return EmbedAndSearchResponse(
24-
hits=hits
40+
def from_dict(d: any, client: Client = None) -> "QueryResults":
41+
items = [QueryResult.from_dict(h) for h in (d.get("items", []) or [])]
42+
return QueryResults(
43+
items=items
2544
)
2645

2746

47+
48+
49+
2850
@dataclass
2951
class EmbeddedItem:
3052
id: str = None
@@ -145,18 +167,6 @@ class IndexSearchRequest(Request):
145167
includeMetadata: bool = False
146168

147169

148-
@dataclass
149-
class IndexSearchResponse:
150-
hits: List[Hit] = None
151-
152-
@staticmethod
153-
def from_dict(d: any, client: Client = None) -> "IndexSearchResponse":
154-
hits = [Hit.from_dict(h) for h in (d.get("hits", []) or [])]
155-
return IndexSearchResponse(
156-
hits=hits
157-
)
158-
159-
160170
@dataclass
161171
class IndexSnapshotRequest(Request):
162172
indexId: str
@@ -467,7 +477,7 @@ def search(
467477
spaceId: str = None,
468478
spaceHandle: str = None,
469479
space: any = None
470-
) -> Response[IndexSearchResponse]:
480+
) -> Response[QueryResults]:
471481
if type(query) == list:
472482
req = IndexSearchRequest(
473483
self.id,
@@ -487,7 +497,7 @@ def search(
487497
ret = self.client.post(
488498
'embedding-index/search',
489499
req,
490-
expect=IndexSearchResponse,
500+
expect=QueryResults,
491501
spaceId=spaceId,
492502
spaceHandle=spaceHandle,
493503
space=space

src/steamship/data/plugin.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class Plugin:
2020

2121
@dataclass
2222
class CreatePluginRequest(Request):
23+
isTrainable: bool
2324
id: str = None
2425
name: str = None
2526
type: str = None
@@ -69,9 +70,9 @@ class GetPluginRequest(Request):
6970

7071

7172
class PluginType:
72-
embedder = "embedder"
7373
parser = "parser"
7474
classifier = "classifier"
75+
tagger = "tagger"
7576

7677

7778
class PluginAdapterType:
@@ -133,6 +134,7 @@ def from_dict(d: any, client: Client = None) -> "Plugin":
133134
@staticmethod
134135
def create(
135136
client: Client,
137+
isTrainable: bool,
136138
name: str,
137139
description: str,
138140
type: str,
@@ -152,6 +154,7 @@ def create(
152154
metadata = json.dumps(metadata)
153155

154156
req = CreatePluginRequest(
157+
isTrainable=isTrainable,
155158
name=name,
156159
type=type,
157160
transport=transport,

src/steamship/data/tags/text_tag.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,4 @@ class TextTag:
3030
isOov = "isOov"
3131
isStop = "isStop"
3232
lang = "lang"
33+
embedding = "embedding"

tests/client/helpers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,10 +168,10 @@ def deploy_app(py_name: str, versionConfigTemplate : Dict[str, any] = None, inst
168168

169169

170170
@contextlib.contextmanager
171-
def deploy_plugin(py_name: str, plugin_type: str, versionConfigTemplate : Dict[str, any] = None, instanceConfig : Dict[str, any] = None):
171+
def deploy_plugin(py_name: str, plugin_type: str, versionConfigTemplate : Dict[str, any] = None, instanceConfig : Dict[str, any] = None, isTrainable: bool = False):
172172
client = _steamship()
173173
name = _random_name()
174-
plugin = Plugin.create(client, name=name, description='test', type=plugin_type, transport="jsonOverHttp",
174+
plugin = Plugin.create(client, isTrainable=isTrainable, name=name, description='test', type=plugin_type, transport="jsonOverHttp",
175175
isPublic=False)
176176
assert (plugin.error is None)
177177
assert (plugin.data is not None)

tests/client/operations/test_embed.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from steamship import PluginInstance
1+
from steamship import PluginInstance, File
22
from steamship.base import Client
33

44
from tests.client.helpers import _steamship
@@ -8,19 +8,31 @@
88

99
_TEST_EMBEDDER = "test-embedder"
1010

11+
def count_embeddings(file: File):
12+
embeddings = 0
13+
for block in file.blocks:
14+
for tag in block.tags:
15+
if tag.kind == 'text' and tag.name == 'embedding':
16+
embeddings += 1
17+
return embeddings
1118

1219
def basic_embeddings(steamship: Client, pluginInstance: str):
13-
e1 = steamship.embed(["This is a test"], pluginInstance=pluginInstance)
14-
e1b = steamship.embed(["Banana"], pluginInstance=pluginInstance)
15-
assert (len(e1.data.embeddings) == 1)
16-
assert (len(e1.data.embeddings[0]) > 1)
20+
e1 = steamship.tag("This is a test", pluginInstance=pluginInstance)
21+
e1b = steamship.tag("Banana", pluginInstance=pluginInstance)
22+
e1.wait()
23+
e1b.wait()
24+
assert (count_embeddings(e1.data.file) == 1)
25+
assert (count_embeddings(e1b.data.file) == 1)
26+
assert (len(e1.data.file.blocks[0].tags[0].value['embedding']) > 1)
1727

18-
e2 = steamship.embed(["This is a test"], pluginInstance=pluginInstance)
19-
assert (len(e2.data.embeddings) == 1)
20-
assert (len(e2.data.embeddings[0]) == len(e1.data.embeddings[0]))
28+
e2 = steamship.tag("This is a test", pluginInstance=pluginInstance)
29+
e2.wait()
30+
assert (count_embeddings(e2.data.file) == 1)
31+
assert (len(e2.data.file.blocks[0].tags[0].value['embedding']) == len(e1.data.file.blocks[0].tags[0].value['embedding']))
2132

22-
e4 = steamship.embed(["This is a test"], pluginInstance=pluginInstance)
23-
assert (len(e4.data.embeddings) == 1)
33+
e4 = steamship.tag("This is a test", pluginInstance=pluginInstance)
34+
e4.wait()
35+
assert (count_embeddings(e4.data.file) == 1)
2436

2537

2638
def test_basic_embeddings():
@@ -38,8 +50,8 @@ def basic_embedding_search(steamship: Client, pluginInstance: str):
3850
]
3951
query = "Who should I talk to about new employee setup?"
4052
results = steamship.embed_and_search(query, docs, pluginInstance=pluginInstance)
41-
assert (len(results.data.hits) == 1)
42-
assert (results.data.hits[0].value == "Jonathan can help you with new employee onboarding")
53+
assert (len(results.data.items) == 1)
54+
assert (results.data.items[0].value.value == "Jonathan can help you with new employee onboarding")
4355

4456

4557
def test_basic_embedding_search():

tests/client/operations/test_embed_file.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,9 @@ def test_file_parse():
6262
embedResp.wait()
6363

6464
res = index.search("What color are roses?").data
65-
assert (len(res.hits) == 1)
65+
assert (len(res.items) == 1)
6666
# Because the simdex now indexes entire blocks and not sentences, the result of this is the whole block text
67-
assert (res.hits[0].value == " ".join([P1_1, P1_2]))
67+
assert (res.items[0].value.value == " ".join([P1_1, P1_2]))
6868

6969
a.delete()
7070

@@ -115,14 +115,14 @@ def test_file_index():
115115
index = a.index(pluginInstance=embedder.handle)
116116

117117
res = index.search("What color are roses?").data
118-
assert (len(res.hits) == 1)
118+
assert (len(res.items) == 1)
119119
# Because the simdex now indexes entire blocks and not sentences, the result of this is the whole block text
120-
assert (res.hits[0].value == " ".join([P1_1, P1_2]))
120+
assert (res.items[0].value.value == " ".join([P1_1, P1_2]))
121121

122122
res = index.search("What flavors does cake come in?").data
123-
assert (len(res.hits) == 1)
123+
assert (len(res.items) == 1)
124124
# Because the simdex now indexes entire blocks and not sentences, the result of this is the whole block text
125-
assert (res.hits[0].value == " ".join([P4_1, P4_2]))
125+
assert (res.items[0].value.value == " ".join([P4_1, P4_2]))
126126

127127
index.delete()
128128
a.delete()
@@ -172,12 +172,12 @@ def test_file_embed_lookup():
172172
index.insert_file(b.id, blockType='sentence', reindex=True)
173173

174174
res = index.search("What does Ted like to do?").data
175-
assert (len(res.hits) == 1)
176-
assert (res.hits[0].value == content_a)
175+
assert (len(res.items) == 1)
176+
assert (res.items[0].value.value == content_a)
177177

178178
res = index.search("What does Grace like to do?").data
179-
assert (len(res.hits) == 1)
180-
assert (res.hits[0].value == content_b)
179+
assert (len(res.items) == 1)
180+
assert (res.items[0].value.value == content_b)
181181

182182
# Now we list the items
183183
itemsa = index.list_items(fileId=a.id).data

0 commit comments

Comments
 (0)