@@ -20,29 +20,19 @@ class Client:
2020 def __init__ (
2121 self ,
2222 aws_region : typing .Optional [str ] = None ,
23- mode : Mode = Mode .SAGEMAKER ,
2423 ):
2524 """
2625 By default we assume region configured in AWS CLI (`aws configure get region`). You can change the region with
2726 `aws configure set region us-west-2` or override it with `region_name` parameter.
2827 """
29- self .mode = mode
28+ self ._client = lazy_boto3 ().client ("sagemaker-runtime" , region_name = aws_region )
29+ self ._service_client = lazy_boto3 ().client ("sagemaker" , region_name = aws_region )
3030 if os .environ .get ('AWS_DEFAULT_REGION' ) is None :
3131 os .environ ['AWS_DEFAULT_REGION' ] = aws_region
32+ self ._sess = lazy_sagemaker ().Session (sagemaker_client = self ._service_client )
33+ self .mode = Mode .SAGEMAKER
3234
33- if self .mode == Mode .SAGEMAKER :
34- self ._client = lazy_boto3 ().client ("sagemaker-runtime" , region_name = aws_region )
35- self ._service_client = lazy_boto3 ().client ("sagemaker" , region_name = aws_region )
36- self ._sess = lazy_sagemaker ().Session (sagemaker_client = self ._service_client )
37- elif self .mode == Mode .BEDROCK :
38- self ._client = lazy_boto3 ().client ("bedrock-runtime" , region_name = aws_region )
39- self ._service_client = lazy_boto3 ().client ("bedrock" , region_name = aws_region )
40- self ._sess = None
41- self ._endpoint_name = None
4235
43- def _require_sagemaker (self ) -> None :
44- if self .mode != Mode .SAGEMAKER :
45- raise CohereError ("This method is only supported in SageMaker mode." )
4636
4737 def _does_endpoint_exist (self , endpoint_name : str ) -> bool :
4838 try :
@@ -60,7 +50,6 @@ def connect_to_endpoint(self, endpoint_name: str) -> None:
6050 Raises:
6151 CohereError: Connection to the endpoint failed.
6252 """
63- self ._require_sagemaker ()
6453 if not self ._does_endpoint_exist (endpoint_name ):
6554 raise CohereError (f"Endpoint { endpoint_name } does not exist." )
6655 self ._endpoint_name = endpoint_name
@@ -148,7 +137,6 @@ def create_endpoint(
148137 will be used to get the role. This should work when one uses the client inside SageMaker. If this errors
149138 out, the default role "ServiceRoleSagemaker" will be used, which generally works outside of SageMaker.
150139 """
151- self ._require_sagemaker ()
152140 # First, check if endpoint already exists
153141 if self ._does_endpoint_exist (endpoint_name ):
154142 if recreate :
@@ -562,15 +550,11 @@ def embed(
562550 variant : Optional [str ] = None ,
563551 input_type : Optional [str ] = None ,
564552 model_id : Optional [str ] = None ,
565- output_dimension : Optional [int ] = None ,
566- embedding_types : Optional [List [str ]] = None ,
567- ) -> Union [Embeddings , Dict [str , List ]]:
553+ ) -> Embeddings :
568554 json_params = {
569555 'texts' : texts ,
570556 'truncate' : truncate ,
571- "input_type" : input_type ,
572- "output_dimension" : output_dimension ,
573- "embedding_types" : embedding_types ,
557+ "input_type" : input_type
574558 }
575559 for key , value in list (json_params .items ()):
576560 if value is None :
@@ -607,10 +591,7 @@ def _sagemaker_embed(self, json_params: Dict[str, Any], variant: str):
607591 # ValidationError, e.g. when variant is bad
608592 raise CohereError (str (e ))
609593
610- embeddings = response ['embeddings' ]
611- if isinstance (embeddings , dict ):
612- return embeddings
613- return Embeddings (embeddings )
594+ return Embeddings (response ['embeddings' ])
614595
615596 def _bedrock_embed (self , json_params : Dict [str , Any ], model_id : str ):
616597 if not model_id :
@@ -631,10 +612,7 @@ def _bedrock_embed(self, json_params: Dict[str, Any], model_id: str):
631612 # ValidationError, e.g. when variant is bad
632613 raise CohereError (str (e ))
633614
634- embeddings = response ['embeddings' ]
635- if isinstance (embeddings , dict ):
636- return embeddings
637- return Embeddings (embeddings )
615+ return Embeddings (response ['embeddings' ])
638616
639617
640618 def rerank (self ,
@@ -827,7 +805,6 @@ def export_finetune(
827805 This should work when one uses the client inside SageMaker. If this errors out,
828806 the default role "ServiceRoleSagemaker" will be used, which generally works outside SageMaker.
829807 """
830- self ._require_sagemaker ()
831808 if name == "model" :
832809 raise ValueError ("name cannot be 'model'" )
833810
@@ -971,7 +948,6 @@ def summarize(
971948 additional_command : Optional [str ] = "" ,
972949 variant : Optional [str ] = None
973950 ) -> Summary :
974- self ._require_sagemaker ()
975951
976952 if self ._endpoint_name is None :
977953 raise CohereError ("No endpoint connected. "
@@ -1013,7 +989,6 @@ def summarize(
1013989
1014990
1015991 def delete_endpoint (self ) -> None :
1016- self ._require_sagemaker ()
1017992 if self ._endpoint_name is None :
1018993 raise CohereError ("No endpoint connected." )
1019994 try :
0 commit comments