Skip to content

Commit d48cca4

Browse files
committed
revert: restore AWS client files to main — no AWS changes in this PR
1 parent 548f741 commit d48cca4

3 files changed

Lines changed: 20 additions & 289 deletions

File tree

src/cohere/aws_client.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,14 @@
44
import typing
55

66
import httpx
7-
from httpx import URL, ByteStream
7+
from httpx import URL, SyncByteStream, ByteStream
88

99
from . import GenerateStreamedResponse, Generation, \
1010
NonStreamedChatResponse, EmbedResponse, StreamedChatResponse, RerankResponse, ApiMeta, ApiMetaTokens, \
1111
ApiMetaBilledUnits
1212
from .client import Client, ClientEnvironment
1313
from .core import construct_type
1414
from .manually_maintained.lazy_aws_deps import lazy_boto3, lazy_botocore
15-
from .manually_maintained.streaming import Streamer
1615
from .client_v2 import ClientV2
1716

1817
class AwsClient(Client):
@@ -113,6 +112,16 @@ def get_event_hooks(
113112
})
114113

115114

115+
class Streamer(SyncByteStream):
116+
lines: typing.Iterator[bytes]
117+
118+
def __init__(self, lines: typing.Iterator[bytes]):
119+
self.lines = lines
120+
121+
def __iter__(self) -> typing.Iterator[bytes]:
122+
return self.lines
123+
124+
116125
response_mapping: typing.Dict[str, typing.Any] = {
117126
"chat": NonStreamedChatResponse,
118127
"embed": EmbedResponse,
@@ -230,7 +239,6 @@ def _event_hook(request: httpx.Request) -> None:
230239
)
231240
request.url = URL(url)
232241
request.headers["host"] = request.url.host
233-
headers["host"] = request.url.host
234242

235243
if endpoint == "rerank":
236244
body["api_version"] = get_api_version(version=api_version)
@@ -282,4 +290,4 @@ def get_api_version(*, version: str):
282290
"v2": 2,
283291
}
284292

285-
return int_version.get(version, 1)
293+
return int_version.get(version, 1)

src/cohere/manually_maintained/cohere_aws/client.py

Lines changed: 8 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -20,29 +20,19 @@ class Client:
2020
def __init__(
2121
self,
2222
aws_region: typing.Optional[str] = None,
23-
mode: Mode = Mode.SAGEMAKER,
2423
):
2524
"""
2625
By default we assume region configured in AWS CLI (`aws configure get region`). You can change the region with
2726
`aws configure set region us-west-2` or override it with `region_name` parameter.
2827
"""
29-
self.mode = mode
28+
self._client = lazy_boto3().client("sagemaker-runtime", region_name=aws_region)
29+
self._service_client = lazy_boto3().client("sagemaker", region_name=aws_region)
3030
if os.environ.get('AWS_DEFAULT_REGION') is None:
3131
os.environ['AWS_DEFAULT_REGION'] = aws_region
32+
self._sess = lazy_sagemaker().Session(sagemaker_client=self._service_client)
33+
self.mode = Mode.SAGEMAKER
3234

33-
if self.mode == Mode.SAGEMAKER:
34-
self._client = lazy_boto3().client("sagemaker-runtime", region_name=aws_region)
35-
self._service_client = lazy_boto3().client("sagemaker", region_name=aws_region)
36-
self._sess = lazy_sagemaker().Session(sagemaker_client=self._service_client)
37-
elif self.mode == Mode.BEDROCK:
38-
self._client = lazy_boto3().client("bedrock-runtime", region_name=aws_region)
39-
self._service_client = lazy_boto3().client("bedrock", region_name=aws_region)
40-
self._sess = None
41-
self._endpoint_name = None
4235

43-
def _require_sagemaker(self) -> None:
44-
if self.mode != Mode.SAGEMAKER:
45-
raise CohereError("This method is only supported in SageMaker mode.")
4636

4737
def _does_endpoint_exist(self, endpoint_name: str) -> bool:
4838
try:
@@ -60,7 +50,6 @@ def connect_to_endpoint(self, endpoint_name: str) -> None:
6050
Raises:
6151
CohereError: Connection to the endpoint failed.
6252
"""
63-
self._require_sagemaker()
6453
if not self._does_endpoint_exist(endpoint_name):
6554
raise CohereError(f"Endpoint {endpoint_name} does not exist.")
6655
self._endpoint_name = endpoint_name
@@ -148,7 +137,6 @@ def create_endpoint(
148137
will be used to get the role. This should work when one uses the client inside SageMaker. If this errors
149138
out, the default role "ServiceRoleSagemaker" will be used, which generally works outside of SageMaker.
150139
"""
151-
self._require_sagemaker()
152140
# First, check if endpoint already exists
153141
if self._does_endpoint_exist(endpoint_name):
154142
if recreate:
@@ -562,15 +550,11 @@ def embed(
562550
variant: Optional[str] = None,
563551
input_type: Optional[str] = None,
564552
model_id: Optional[str] = None,
565-
output_dimension: Optional[int] = None,
566-
embedding_types: Optional[List[str]] = None,
567-
) -> Union[Embeddings, Dict[str, List]]:
553+
) -> Embeddings:
568554
json_params = {
569555
'texts': texts,
570556
'truncate': truncate,
571-
"input_type": input_type,
572-
"output_dimension": output_dimension,
573-
"embedding_types": embedding_types,
557+
"input_type": input_type
574558
}
575559
for key, value in list(json_params.items()):
576560
if value is None:
@@ -607,10 +591,7 @@ def _sagemaker_embed(self, json_params: Dict[str, Any], variant: str):
607591
# ValidationError, e.g. when variant is bad
608592
raise CohereError(str(e))
609593

610-
embeddings = response['embeddings']
611-
if isinstance(embeddings, dict):
612-
return embeddings
613-
return Embeddings(embeddings)
594+
return Embeddings(response['embeddings'])
614595

615596
def _bedrock_embed(self, json_params: Dict[str, Any], model_id: str):
616597
if not model_id:
@@ -631,10 +612,7 @@ def _bedrock_embed(self, json_params: Dict[str, Any], model_id: str):
631612
# ValidationError, e.g. when variant is bad
632613
raise CohereError(str(e))
633614

634-
embeddings = response['embeddings']
635-
if isinstance(embeddings, dict):
636-
return embeddings
637-
return Embeddings(embeddings)
615+
return Embeddings(response['embeddings'])
638616

639617

640618
def rerank(self,
@@ -827,7 +805,6 @@ def export_finetune(
827805
This should work when one uses the client inside SageMaker. If this errors out,
828806
the default role "ServiceRoleSagemaker" will be used, which generally works outside SageMaker.
829807
"""
830-
self._require_sagemaker()
831808
if name == "model":
832809
raise ValueError("name cannot be 'model'")
833810

@@ -971,7 +948,6 @@ def summarize(
971948
additional_command: Optional[str] = "",
972949
variant: Optional[str] = None
973950
) -> Summary:
974-
self._require_sagemaker()
975951

976952
if self._endpoint_name is None:
977953
raise CohereError("No endpoint connected. "
@@ -1013,7 +989,6 @@ def summarize(
1013989

1014990

1015991
def delete_endpoint(self) -> None:
1016-
self._require_sagemaker()
1017992
if self._endpoint_name is None:
1018993
raise CohereError("No endpoint connected.")
1019994
try:

0 commit comments

Comments
 (0)