Skip to content

Commit ee23841

Browse files
Format and remove unnecessary imports
1 parent a133124 commit ee23841

4 files changed

Lines changed: 83 additions & 64 deletions

File tree

spacy_llm/models/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from .bedrock import titan_express, titan_lite
12
from .hf import dolly_hf, openllama_hf, stablelm_hf
23
from .langchain import query_langchain
34
from .rest import anthropic, cohere, noop, openai, palm
@@ -12,4 +13,6 @@
1213
"openllama_hf",
1314
"palm",
1415
"query_langchain",
16+
"titan_lite",
17+
"titan_express",
1518
]

spacy_llm/models/bedrock/model.py

Lines changed: 54 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,100 +1,108 @@
1-
import os
21
import json
2+
import os
33
import warnings
44
from enum import Enum
5-
from requests import HTTPError
6-
from typing import Any, Dict, Iterable, Optional, Type, List, Sized, Tuple
7-
8-
from confection import SimpleFrozenDict
5+
from typing import Any, Dict, Iterable, List, Optional
96

10-
from ...registry import registry
11-
12-
try:
13-
import boto3
14-
import botocore
15-
from botocore.config import Config
16-
except ImportError as err:
17-
print("To use Bedrock, you need to install boto3. Use `pip install boto3` ")
18-
raise err
197

208
class Models(str, Enum):
219
# Completion models
2210
TITAN_EXPRESS = "amazon.titan-text-express-v1"
2311
TITAN_LITE = "amazon.titan-text-lite-v1"
2412

25-
class Bedrock():
13+
14+
class Bedrock:
2615
def __init__(
27-
self,
28-
model_id: str,
29-
region: str,
30-
config: Dict[Any, Any],
31-
max_retries: int = 5
16+
self, model_id: str, region: str, config: Dict[Any, Any], max_retries: int = 5
3217
):
33-
3418
self._region = region
3519
self._model_id = model_id
3620
self._config = config
3721
self._max_retries = max_retries
38-
39-
# @property
40-
def get_session(self) -> Dict[str, str]:
22+
23+
def get_session_kwargs(self) -> Dict[str, Optional[str]]:
4124

4225
# Fetch and check the credentials
43-
profile = os.getenv("AWS_PROFILE") if not None else ""
26+
profile = os.getenv("AWS_PROFILE") if not None else ""
4427
secret_key_id = os.getenv("AWS_ACCESS_KEY_ID")
4528
secret_access_key = os.getenv("AWS_SECRET_ACCESS_KEY")
4629
session_token = os.getenv("AWS_SESSION_TOKEN")
4730

4831
if profile is None:
4932
warnings.warn(
5033
"Could not find the AWS_PROFILE to access the Amazon Bedrock . Ensure you have an AWS_PROFILE "
51-
"set up by making it available as an environment variable 'AWS_PROFILE'."
52-
)
34+
"set up by making it available as an environment variable AWS_PROFILE."
35+
)
5336

5437
if secret_key_id is None:
5538
warnings.warn(
5639
"Could not find the AWS_ACCESS_KEY_ID to access the Amazon Bedrock . Ensure you have an AWS_ACCESS_KEY_ID "
57-
"set up by making it available as an environment variable 'AWS_ACCESS_KEY_ID'."
40+
"set up by making it available as an environment variable AWS_ACCESS_KEY_ID."
5841
)
42+
5943
if secret_access_key is None:
6044
warnings.warn(
6145
"Could not find the AWS_SECRET_ACCESS_KEY to access the Amazon Bedrock . Ensure you have an AWS_SECRET_ACCESS_KEY "
62-
"set up by making it available as an environment variable 'AWS_SECRET_ACCESS_KEY'."
46+
"set up by making it available as an environment variable AWS_SECRET_ACCESS_KEY."
6347
)
48+
6449
if session_token is None:
6550
warnings.warn(
6651
"Could not find the AWS_SESSION_TOKEN to access the Amazon Bedrock . Ensure you have an AWS_SESSION_TOKEN "
67-
"set up by making it available as an environment variable 'AWS_SESSION_TOKEN'."
52+
"set up by making it available as an environment variable AWS_SESSION_TOKEN."
6853
)
6954

7055
assert secret_key_id is not None
7156
assert secret_access_key is not None
7257
assert session_token is not None
73-
74-
session_kwargs = {"profile_name":profile, "region_name":self._region, "aws_access_key_id":secret_key_id, "aws_secret_access_key":secret_access_key, "aws_session_token":session_token}
75-
bedrock = boto3.Session(**session_kwargs)
76-
return bedrock
7758

78-
def __call__(self, prompts: Iterable[str]) -> Iterable[str]:
59+
session_kwargs = {
60+
"profile_name": profile,
61+
"region_name": self._region,
62+
"aws_access_key_id": secret_key_id,
63+
"aws_secret_access_key": secret_access_key,
64+
"aws_session_token": session_token,
65+
}
66+
return session_kwargs
67+
68+
def __call__(self, prompts: Iterable[str]) -> Iterable[str]:
7969
api_responses: List[str] = []
8070
prompts = list(prompts)
81-
api_config = Config(retries = dict(max_attempts = self._max_retries))
8271

83-
def _request(json_data: Dict[str, Any]) -> Dict[str, Any]:
84-
session = self.get_session()
85-
print("Session:", session)
72+
def _request(json_data: str) -> str:
73+
try:
74+
import boto3
75+
except ImportError as err:
76+
warnings.warn(
77+
"To use Bedrock, you need to install boto3. Use pip install boto3 "
78+
)
79+
raise err
80+
from botocore.config import Config
81+
82+
session_kwargs = self.get_session_kwargs()
83+
session = boto3.Session(**session_kwargs)
84+
api_config = Config(retries=dict(max_attempts=self._max_retries))
8685
bedrock = session.client(service_name="bedrock-runtime", config=api_config)
87-
accept = 'application/json'
88-
contentType = 'application/json'
89-
r = bedrock.invoke_model(body=json_data, modelId=self._model_id, accept=accept, contentType=contentType)
90-
responses = json.loads(r['body'].read().decode())['results'][0]['outputText']
86+
accept = "application/json"
87+
contentType = "application/json"
88+
r = bedrock.invoke_model(
89+
body=json_data,
90+
modelId=self._model_id,
91+
accept=accept,
92+
contentType=contentType,
93+
)
94+
responses = json.loads(r["body"].read().decode())["results"][0][
95+
"outputText"
96+
]
9197
return responses
9298

9399
for prompt in prompts:
94100
if self._model_id in [Models.TITAN_LITE, Models.TITAN_EXPRESS]:
95-
responses = _request(json.dumps({"inputText": prompt, "textGenerationConfig":self._config}))
96-
if "error" in responses:
97-
return responses["error"]
101+
responses = _request(
102+
json.dumps(
103+
{"inputText": prompt, "textGenerationConfig": self._config}
104+
)
105+
)
98106

99107
api_responses.append(responses)
100108

spacy_llm/models/bedrock/registry.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,50 @@
1-
from typing import Any, Callable, Dict, Iterable
1+
from typing import Any, Callable, Dict, Iterable, List
22

33
from confection import SimpleFrozenDict
44

55
from ...registry import registry
66
from .model import Bedrock, Models
77

8-
_DEFAULT_RETRIES = 5
9-
_DEFAULT_TEMPERATURE = 0.0
10-
_DEFAULT_MAX_TOKEN_COUNT = 512
11-
_DEFAULT_TOP_P = 1
12-
_DEFAULT_STOP_SEQUENCES = []
8+
_DEFAULT_RETRIES: int = 5
9+
_DEFAULT_TEMPERATURE: float = 0.0
10+
_DEFAULT_MAX_TOKEN_COUNT: int = 512
11+
_DEFAULT_TOP_P: int = 1
12+
_DEFAULT_STOP_SEQUENCES: List[str] = []
13+
1314

1415
@registry.llm_models("spacy.Bedrock.Titan.Express.v1")
1516
def titan_express(
1617
region: str,
1718
model_id: Models = Models.TITAN_EXPRESS,
18-
config: Dict[Any, Any] = SimpleFrozenDict(temperature=_DEFAULT_TEMPERATURE, maxTokenCount=_DEFAULT_MAX_TOKEN_COUNT, stopSequences=_DEFAULT_STOP_SEQUENCES, topP =_DEFAULT_TOP_P),
19-
max_retries: int = _DEFAULT_RETRIES
19+
config: Dict[Any, Any] = SimpleFrozenDict(
20+
temperature=_DEFAULT_TEMPERATURE,
21+
maxTokenCount=_DEFAULT_MAX_TOKEN_COUNT,
22+
stopSequences=_DEFAULT_STOP_SEQUENCES,
23+
topP=_DEFAULT_TOP_P,
24+
),
25+
max_retries: int = _DEFAULT_RETRIES,
2026
) -> Callable[[Iterable[str]], Iterable[str]]:
2127
"""Returns Bedrock instance for 'amazon-titan-express' model using boto3 to prompt API.
2228
model_id (ModelId): ID of the deployed model (titan-express)
2329
region (str): Specify the AWS region for the service
2430
config (Dict[Any, Any]): LLM config passed on to the model's initialization.
2531
"""
2632
return Bedrock(
27-
model_id = model_id,
28-
region = region,
29-
config=config,
30-
max_retries=max_retries
33+
model_id=model_id, region=region, config=config, max_retries=max_retries
3134
)
3235

36+
3337
@registry.llm_models("spacy.Bedrock.Titan.Lite.v1")
3438
def titan_lite(
3539
region: str,
3640
model_id: Models = Models.TITAN_LITE,
37-
config: Dict[Any, Any] = SimpleFrozenDict(temperature=_DEFAULT_TEMPERATURE, maxTokenCount=_DEFAULT_MAX_TOKEN_COUNT, stopSequences=_DEFAULT_STOP_SEQUENCES, topP =_DEFAULT_TOP_P),
38-
max_retries: int = _DEFAULT_RETRIES
41+
config: Dict[Any, Any] = SimpleFrozenDict(
42+
temperature=_DEFAULT_TEMPERATURE,
43+
maxTokenCount=_DEFAULT_MAX_TOKEN_COUNT,
44+
stopSequences=_DEFAULT_STOP_SEQUENCES,
45+
topP=_DEFAULT_TOP_P,
46+
),
47+
max_retries: int = _DEFAULT_RETRIES,
3948
) -> Callable[[Iterable[str]], Iterable[str]]:
4049
"""Returns Bedrock instance for 'amazon-titan-lite' model using boto3 to prompt API.
4150
region (str): Specify the AWS region for the service
@@ -44,9 +53,8 @@ def titan_lite(
4453
config (Dict[Any, Any]): LLM config passed on to the model's initialization.
4554
"""
4655
return Bedrock(
47-
model_id = model_id,
48-
region = region,
56+
model_id=model_id,
57+
region=region,
4958
config=config,
5059
max_retries=max_retries,
5160
)
52-

usage_examples/ner_v3_titan/fewshot.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,4 @@ path = "${paths.examples}"
2929

3030
[components.llm.model]
3131
@llm_models = "spacy.Bedrock.Titan.Express.v1"
32-
region = us-east-1
32+
region = <aws-region>

0 commit comments

Comments
 (0)