|
| 1 | +import os |
| 2 | + |
| 3 | +import boto3 |
| 4 | + |
| 5 | +from modelgauge.dynamic_sut_factory import DynamicSUTFactory, ModelNotSupportedError |
| 6 | +from modelgauge.secret_values import InjectSecret, RawSecrets |
| 7 | +from modelgauge.sut import SUT |
| 8 | +from modelgauge.sut_definition import SUTDefinition |
| 9 | +from modelgauge.suts.aws_bedrock_client import AmazonBedrockSut, AwsAccessKeyId, AwsSecretAccessKey |
| 10 | + |
| 11 | + |
| 12 | +class AWSBedrockSUTFactory(DynamicSUTFactory): |
| 13 | + DRIVER_NAME = "aws" |
| 14 | + |
| 15 | + def __init__(self, raw_secrets: RawSecrets): |
| 16 | + super().__init__(raw_secrets) |
| 17 | + self._client = None # Lazy load. |
| 18 | + |
| 19 | + @property |
| 20 | + def client(self): |
| 21 | + if self._client is None: |
| 22 | + self._client = boto3.client( |
| 23 | + service_name="bedrock", |
| 24 | + region_name=os.getenv("AWS_REGION", "us-east-1"), |
| 25 | + aws_access_key_id=self.injected_secrets()[0].value, |
| 26 | + aws_secret_access_key=self.injected_secrets()[1].value, |
| 27 | + ) |
| 28 | + return self._client |
| 29 | + |
| 30 | + def _convert_model_id(self, model_id: str) -> SUTDefinition: |
| 31 | + """Convert AWS model IDs (maker.model[:version?]) to our standard format.""" |
| 32 | + maker, model_name = model_id.split(".", maxsplit=1) |
| 33 | + model_name = model_name.replace(":", ".") |
| 34 | + return SUTDefinition({"maker": maker, "model": model_name, "driver": self.DRIVER_NAME}) |
| 35 | + |
| 36 | + def _get_available_models(self, maker: str): |
| 37 | + response = self.client.list_foundation_models() |
| 38 | + models = {} |
| 39 | + for m in response["modelSummaries"]: |
| 40 | + if m.get("modelLifecycle", {}).get("status") != "ACTIVE": |
| 41 | + continue |
| 42 | + models[m["modelId"]] = self._convert_model_id(m["modelId"]) |
| 43 | + return models |
| 44 | + |
| 45 | + def _get_model_id(self, sut_definition: SUTDefinition): |
| 46 | + models = self._get_available_models(sut_definition.to_dynamic_sut_metadata().maker) |
| 47 | + for model_id, model_definition in models.items(): |
| 48 | + if str(model_definition.to_dynamic_sut_metadata()) == str(sut_definition.to_dynamic_sut_metadata()): |
| 49 | + return model_id |
| 50 | + supported_models = [model_def.to_dynamic_sut_metadata().external_model_name() for model_def in models.values()] |
| 51 | + raise ModelNotSupportedError( |
| 52 | + f"Model {sut_definition.external_model_name()} not found among AWS Bedrock models. AWS carries the following models from maker {sut_definition.get("maker")}: {supported_models} " |
| 53 | + ) |
| 54 | + |
| 55 | + def get_secrets(self) -> list[InjectSecret]: |
| 56 | + return [InjectSecret(AwsAccessKeyId), InjectSecret(AwsSecretAccessKey)] |
| 57 | + |
| 58 | + def make_sut(self, sut_definition: SUTDefinition) -> SUT: |
| 59 | + model_id = self._get_model_id(sut_definition) |
| 60 | + return AmazonBedrockSut(sut_definition.dynamic_uid, model_id, *self.injected_secrets()) |
0 commit comments