diff --git a/assets/common/components/delete_endpoint/spec.yaml b/assets/common/components/delete_endpoint/spec.yaml new file mode 100644 index 0000000000..9e55975a1f --- /dev/null +++ b/assets/common/components/delete_endpoint/spec.yaml @@ -0,0 +1,40 @@ +$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json +name: delete_endpoint +version: 0.0.7 +type: command + +is_deterministic: True + +display_name: Delete Endpoint +description: + Deletes an endpoint resource. + +environment: azureml://registries/azureml/environments/python-sdk-v2/versions/28 + +code: ../../src +command: >- + python delete_endpoint.py + $[[--model_deployment_details ${{inputs.model_deployment_details}}]] + $[[--endpoint_name ${{inputs.endpoint_name}}]] + $[[--deployment_name ${{inputs.deployment_name}}]] + +inputs: + # Output of registering component + model_deployment_details: + type: uri_file + optional: true + description: JSON file that contains the deployment details. + + endpoint_name: + type: string + optional: true + description: Name of the endpoint to delete. + + deployment_name: + type: string + optional: true + description: Name of the deployment to delete. + +tags: + Preview: "" + Internal: "" \ No newline at end of file diff --git a/assets/common/components/deploy_inference_model/spec.yaml b/assets/common/components/deploy_inference_model/spec.yaml new file mode 100644 index 0000000000..407c0a2416 --- /dev/null +++ b/assets/common/components/deploy_inference_model/spec.yaml @@ -0,0 +1,234 @@ +$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json +name: deploy_inference_model +version: 0.0.1 +type: command + +is_deterministic: True + +display_name: Deploy model +description: + Deploy a model to a workspace. The component works on compute with [MSI](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-create-manage-compute-instance?tabs=python) attached. + +environment: azureml://registries/azureml/environments/python-sdk-v2/versions/29 + +code: ../../src +command: >- + python deploy_inference_model.py + $[[--registration_details_folder ${{inputs.registration_details_folder}}]] + $[[--model_id ${{inputs.model_id}}]] + $[[--environment_id ${{inputs.environment_id}}]] + $[[--inference_payload ${{inputs.inference_payload}}]] + $[[--inference_payload_str ${{inputs.inference_payload_str}}]] + $[[--endpoint_name ${{inputs.endpoint_name}}]] + $[[--deployment_name ${{inputs.deployment_name}}]] + $[[--instance_type ${{inputs.instance_type}}]] + $[[--instance_count ${{inputs.instance_count}}]] + $[[--max_concurrent_requests_per_instance ${{inputs.max_concurrent_requests_per_instance}}]] + $[[--request_timeout_ms ${{inputs.request_timeout_ms}}]] + $[[--max_queue_wait_ms ${{inputs.max_queue_wait_ms}}]] + $[[--failure_threshold_readiness_probe ${{inputs.failure_threshold_readiness_probe}}]] + $[[--success_threshold_readiness_probe ${{inputs.success_threshold_readiness_probe}}]] + $[[--timeout_readiness_probe ${{inputs.timeout_readiness_probe}}]] + $[[--period_readiness_probe ${{inputs.period_readiness_probe}}]] + $[[--initial_delay_readiness_probe ${{inputs.initial_delay_readiness_probe}}]] + $[[--failure_threshold_liveness_probe ${{inputs.failure_threshold_liveness_probe}}]] + $[[--timeout_liveness_probe ${{inputs.timeout_liveness_probe}}]] + $[[--period_liveness_probe ${{inputs.period_liveness_probe}}]] + $[[--initial_delay_liveness_probe ${{inputs.initial_delay_liveness_probe}}]] + $[[--egress_public_network_access ${{inputs.egress_public_network_access}}]] + --model_deployment_details ${{outputs.model_deployment_details}} + --model_inference_response ${{outputs.model_inference_response}} + --deploy_error ${{outputs.deploy_error}} + +inputs: + # Output of registering component + registration_details_folder: + type: uri_folder + optional: true + description: Folder containing model registration details in a JSON file named model_registration_details.json + + model_id: + type: string + optional: true + description: | + Asset ID of the model registered in workspace/registry. + Registry - azureml://registries//models//versions/ + Workspace - azureml:: + + environment_id: + type: string + optional: true + description: | + Asset ID of the environment registered in workspace/registry. + Registry - azureml://registries//environments//versions/ + Workspace - azureml:: + + inference_payload: + type: uri_file + optional: true + description: JSON payload which would be used to validate deployment + + inference_payload_str: + type: string + optional: true + description: Serialized JSON payload which would be used to validate deployment + + endpoint_name: + type: string + optional: true + description: Name of the endpoint + + deployment_name: + type: string + optional: true + default: default + description: Name of the deployment + + instance_type: + type: string + optional: true + enum: + - Standard_DS1_v2 + - Standard_DS2_v2 + - Standard_DS3_v2 + - Standard_DS4_v2 + - Standard_DS5_v2 + - Standard_F2s_v2 + - Standard_F4s_v2 + - Standard_F8s_v2 + - Standard_F16s_v2 + - Standard_F32s_v2 + - Standard_F48s_v2 + - Standard_F64s_v2 + - Standard_F72s_v2 + - Standard_FX24mds + - Standard_FX36mds + - Standard_FX48mds + - Standard_E2s_v3 + - Standard_E4s_v3 + - Standard_E8s_v3 + - Standard_E16s_v3 + - Standard_E32s_v3 + - Standard_E48s_v3 + - Standard_E64s_v3 + - Standard_NC4as_T4_v3 + - Standard_NC6s_v2 + - Standard_NC6s_v3 + - Standard_NC8as_T4_v3 + - Standard_NC12s_v2 + - Standard_NC12s_v3 + - Standard_NC16as_T4_v3 + - Standard_NC24s_v2 + - Standard_NC24s_v3 + - Standard_NC24rs_v3 + - Standard_NC24ads_A100_v4 + - Standard_NC48ads_A100_v4 + - Standard_NC96ads_A100_v4 + - Standard_NC64as_T4_v3 + - Standard_ND40rs_v2 + - Standard_ND96asr_v4 + - Standard_ND96amsr_A100_v4 + default: Standard_NC24s_v3 + description: Compute instance type to deploy model. Make sure that instance type is available and have enough quota available. + + instance_count: + type: integer + optional: true + default: 1 + description: Number of instances you want to use for deployment. Make sure instance type have enough quota available. + + max_concurrent_requests_per_instance: + type: integer + default: 1 + optional: true + description: Maximum concurrent requests to be handled per instance + + request_timeout_ms: + type: integer + default: 60000 + optional: true + description: Request timeout in ms. Max limit is 90000. + + max_queue_wait_ms: + type: integer + default: 60000 + optional: true + description: Maximum queue wait time of a request in ms + + failure_threshold_readiness_probe: + type: integer + default: 10 + optional: true + description: The number of times system will try after failing the readiness probe + + success_threshold_readiness_probe: + type: integer + default: 1 + optional: true + description: The minimum consecutive successes for the readiness probe to be considered successful after having failed + + timeout_readiness_probe: + type: integer + default: 10 + optional: true + description: The number of seconds after which the readiness probe times out + + period_readiness_probe: + type: integer + default: 10 + optional: true + description: How often (in seconds) to perform the readiness probe + + initial_delay_readiness_probe: + type: integer + default: 10 + optional: true + description: The number of seconds after the container has started before the readiness probe is initiated + + failure_threshold_liveness_probe: + type: integer + default: 30 + optional: true + description: The number of times system will try after failing the liveness probe + + timeout_liveness_probe: + type: integer + default: 10 + optional: true + description: The number of seconds after which the liveness probe times out + + period_liveness_probe: + type: integer + default: 10 + optional: true + description: How often (in seconds) to perform the liveness probe + + initial_delay_liveness_probe: + type: integer + default: 10 + optional: true + description: The number of seconds after the container has started before the liveness probe is initiated + + egress_public_network_access: + type: string + default: enabled + optional: true + enum: + - enabled + - disabled + description: Setting it to disabled secures the deployment by restricting communication between the deployment and the Azure resources used by it + +outputs: + model_deployment_details: + type: uri_file + description: Json file to which deployment details will be written + model_inference_response: + type: uri_file + description: JSON file containing inference results + deploy_error: + type: uri_file + description: File containing error messages or stack traces from the validation step. + +tags: + Preview: "" + Internal: "" diff --git a/assets/common/src/deploy.py b/assets/common/src/deploy.py index a4afa64493..c0558feebe 100644 --- a/assets/common/src/deploy.py +++ b/assets/common/src/deploy.py @@ -338,4 +338,4 @@ def main(): # run script if __name__ == "__main__": # run main function - main() + main() \ No newline at end of file diff --git a/assets/common/src/deploy_inference_model.py b/assets/common/src/deploy_inference_model.py new file mode 100644 index 0000000000..0fa2f68be7 --- /dev/null +++ b/assets/common/src/deploy_inference_model.py @@ -0,0 +1,429 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Run Model deployment module.""" +import argparse +import json +import re +import time +import base64 +import traceback + +from azure.ai.ml.entities import ( + ManagedOnlineEndpoint, + ManagedOnlineDeployment, + OnlineRequestSettings, + ProbeSettings, +) +from azureml._common._error_definition import AzureMLError +from azureml._common.exceptions import AzureMLException +from pathlib import Path + +from utils.config import AppName, ComponentVariables +from utils.common_utils import get_mlclient, get_model_name +from utils.logging_utils import custom_dimensions, get_logger +from utils.exceptions import ( + swallow_all_exceptions, + OnlineEndpointInvocationError, + EndpointCreationError, + DeploymentCreationError, +) + + +MAX_REQUEST_TIMEOUT = 90000 +MAX_INSTANCE_COUNT = 20 +MAX_DEPLOYMENT_LOG_TAIL_LINES = 10000 + +logger = get_logger(__name__) +custom_dimensions.app_name = AppName.DEPLOY_MODEL + + +def parse_args(): + """Return arguments.""" + parser = argparse.ArgumentParser() + + # Defaults for managed online endpoint has been picked mostly from: + # https://learn.microsoft.com/en-us/azure/machine-learning/reference-yaml-deployment-managed-online + # Some of the defaults have been tweaked to cater to large models. + + # add arguments + parser.add_argument( + "--registration_details_folder", + type=Path, + help="Folder containing model registration details in a JSON file named model_registration_details.json", + ) + parser.add_argument( + "--model_id", + type=str, + help="Registered mlflow model id", + ) + parser.add_argument( + "--environment_id", + type=str, + required=False, + help="AzureML environment ID to use for deployment", + ) + parser.add_argument( + "--inference_payload", + type=Path, + help="Json file with inference endpoint payload.", + ) + parser.add_argument( + "--inference_payload_str", + type=str, + help="Serialized JSON payload for inference.", + ) + parser.add_argument( + "--endpoint_name", + type=str, + help="Name of the endpoint", + ) + parser.add_argument("--deployment_name", type=str, help="Name of the the deployment") + parser.add_argument( + "--instance_type", + type=str, + help="Compute instance type to deploy model", + default="Standard_NC24s_v3", + ) + parser.add_argument( + "--instance_count", + type=int, + help="Number of compute instances to deploy model", + default=1, + choices=range(1, MAX_INSTANCE_COUNT), + ) + parser.add_argument( + "--max_concurrent_requests_per_instance", + type=int, + default=1, + help="Maximum concurrent requests to be handled per instance", + ) + parser.add_argument( + "--request_timeout_ms", + type=int, + default=60000, # 1min + help="Request timeout in ms.", + ) + parser.add_argument( + "--max_queue_wait_ms", + type=int, + default=60000, # 1min + help="Maximum queue wait time of a request in ms", + ) + parser.add_argument( + "--failure_threshold_readiness_probe", + type=int, + default=10, + help="No of times system will try after failing the readiness probe", + ) + parser.add_argument( + "--success_threshold_readiness_probe", + type=int, + default=1, + help="The minimum consecutive successes for the readiness probe to be considered successful, after fail", + ) + parser.add_argument( + "--timeout_readiness_probe", + type=int, + default=10, + help="The number of seconds after which the readiness probe times out", + ) + parser.add_argument( + "--period_readiness_probe", + type=int, + default=10, + help="How often (in seconds) to perform the readiness probe", + ) + parser.add_argument( + "--initial_delay_readiness_probe", + type=int, + default=10, + help="The number of seconds after the container has started before the readiness probe is initiated", + ) + parser.add_argument( + "--failure_threshold_liveness_probe", + type=int, + default=30, + help="No of times system will try after failing the liveness probe", + ) + parser.add_argument( + "--timeout_liveness_probe", + type=int, + default=10, + help="The number of seconds after which the liveness probe times out", + ) + parser.add_argument( + "--period_liveness_probe", + type=int, + default=10, + help="How often (in seconds) to perform the liveness probe", + ) + parser.add_argument( + "--initial_delay_liveness_probe", + type=int, + default=10, + help="The number of seconds after the container has started before the liveness probe is initiated", + ) + parser.add_argument( + "--egress_public_network_access", + type=str, + default="enabled", + help="Secures the deployment by restricting interaction between deployment and Azure resources used by it", + ) + parser.add_argument( + "--model_deployment_details", + type=str, + help="Json file to which deployment details will be written", + ) + parser.add_argument( + "--model_inference_response", + type=str, + help="Path to the inference response JSON file.", + ) + parser.add_argument( + "--deploy_error", + type=str, + help="Path to the inference response JSON file.", + ) + # parse args + args = parser.parse_args() + logger.info(f"Args received {args}") + print("args received ", args) + + # Validating passed input values + if args.max_concurrent_requests_per_instance < 1: + raise Exception("Arg max_concurrent_requests_per_instance cannot be less than 1") + if args.request_timeout_ms < 1 or args.request_timeout_ms > MAX_REQUEST_TIMEOUT: + raise Exception(f"Arg request_timeout_ms should lie between 1 and {MAX_REQUEST_TIMEOUT}") + if args.max_queue_wait_ms < 1 or args.max_queue_wait_ms > MAX_REQUEST_TIMEOUT: + raise Exception(f"Arg max_queue_wait_ms should lie between 1 and {MAX_REQUEST_TIMEOUT}") + + return args + + +def create_endpoint_and_deployment(ml_client, model_id, environment_id, endpoint_name, deployment_name, args): + """Create endpoint and deployment and return details.""" + endpoint = ManagedOnlineEndpoint(name=endpoint_name, auth_mode="aad_token") + + # deployment + deployment = ManagedOnlineDeployment( + name=deployment_name, + endpoint_name=endpoint_name, + model=model_id, + environment=environment_id, + instance_type=args.instance_type, + instance_count=args.instance_count, + request_settings=OnlineRequestSettings( + max_concurrent_requests_per_instance=args.max_concurrent_requests_per_instance, + request_timeout_ms=args.request_timeout_ms, + max_queue_wait_ms=args.max_queue_wait_ms, + ), + liveness_probe=ProbeSettings( + failure_threshold=args.failure_threshold_liveness_probe, + timeout=args.timeout_liveness_probe, + period=args.period_liveness_probe, + initial_delay=args.initial_delay_liveness_probe, + ), + readiness_probe=ProbeSettings( + failure_threshold=args.failure_threshold_readiness_probe, + success_threshold=args.success_threshold_readiness_probe, + timeout=args.timeout_readiness_probe, + period=args.period_readiness_probe, + initial_delay=args.initial_delay_readiness_probe, + ), + egress_public_network_access=args.egress_public_network_access, + ) + + try: + logger.info(f"Creating endpoint {endpoint_name}") + ml_client.begin_create_or_update(endpoint).wait() + endpoint = ml_client.online_endpoints.get(endpoint.name) + logger.info(f"Endpoint created {endpoint.id}") + except Exception as e: + raise AzureMLException._with_error( + AzureMLError.create(EndpointCreationError, exception=e) + ) + + try: + logger.info(f"Creating deployment {deployment}") + ml_client.online_deployments.begin_create_or_update(deployment).wait() + except Exception as e: + try: + logger.error("Deployment failed. Printing deployment logs") + logs = ml_client.online_deployments.get_logs( + name=deployment_name, + endpoint_name=endpoint_name, + lines=MAX_DEPLOYMENT_LOG_TAIL_LINES + ) + logger.error(logs) + except Exception as ex: + logger.error(f"Error in fetching deployment logs: {ex}") + + raise AzureMLException._with_error( + AzureMLError.create(DeploymentCreationError, exception=e) + ) + + logger.info(f"Deployment successful. Updating endpoint to take 100% traffic for deployment {deployment_name}") + + # deployment to take 100% traffic + endpoint.traffic = {deployment.name: 100} + try: + ml_client.begin_create_or_update(endpoint).wait() + endpoint = ml_client.online_endpoints.get(endpoint.name) + except Exception as e: + error_msg = f"Error occured while updating endpoint traffic. Deployment should be usable. Exception - {e}" + raise Exception(error_msg) + + logger.info(f"Endpoint updated to take 100% traffic for deployment {deployment_name}") + return endpoint, deployment + + +@swallow_all_exceptions(logger) +def main(): + """Run main function.""" + try: + args = parse_args() + logger.info(f"Arguments: {args}") + ml_client = get_mlclient() + + error_message = "" + if args.model_deployment_details: + with open(args.model_deployment_details, "w") as outfile: + json.dump({}, outfile) + + if args.model_inference_response: + with open(args.model_inference_response, "w") as f: + json.dump({}, f, indent=4) + + if args.deploy_error: + with open(args.deploy_error, "w") as error_file: + error_file.write(error_message) + + # get environment id + environment_id = args.environment_id if hasattr(args, "environment_id") else None + + # get registered model id + if args.model_id: + model_id = str(args.model_id) + elif args.registration_details_folder: + registration_details_file = args.registration_details_folder/ComponentVariables.REGISTRATION_DETAILS_JSON_FILE + if registration_details_file.exists(): + try: + with open(registration_details_file) as f: + model_info = json.load(f) + model_id = model_info["id"] + except Exception as e: + raise Exception(f"model_registration_details json file is missing model information {e}.") + else: + raise Exception(f"{ComponentVariables.REGISTRATION_DETAILS_JSON_FILE} is missing inside folder.") + else: + raise Exception("Arguments model_id and registration_details both are missing.") + + # Endpoint has following restrictions: + # 1. Name must begin with lowercase letter + # 2. Followed by lowercase letters, hyphen or numbers + # 3. End with a lowercase letter or number + + # 1. Replace underscores and slashes by hyphens and convert them to lower case. + # 2. Take 21 chars from model name and append '-' & timstamp(10chars) to it + model_name = get_model_name(model_id) + + endpoint_name = re.sub("[^A-Za-z0-9]", "-", model_name).lower()[:21] + endpoint_name = f"{endpoint_name}-{int(time.time())}" + endpoint_name = endpoint_name + + endpoint_name = args.endpoint_name if args.endpoint_name else endpoint_name + deployment_name = args.deployment_name if args.deployment_name else "default" + + endpoint, deployment = create_endpoint_and_deployment( + ml_client=ml_client, + endpoint_name=endpoint_name, + deployment_name=deployment_name, + model_id=model_id, + environment_id=environment_id, + args=args + ) + + response = None + if args.inference_payload or args.inference_payload_str: + print("Invoking inference with test payload ...") + try: + start_time = time.time() + if args.inference_payload_str: + print(f"Inference payload string: {args.inference_payload_str}") + decoded_bytes = base64.b64decode(args.inference_payload_str) + + # Convert bytes to string + decoded_str = decoded_bytes.decode('utf-8') + logger.info(f"Decoded string: {decoded_str}") + + payload = json.loads(decoded_str) + logger.info(f"Payload:\n {payload}") + + with open("payload.json", "w") as temp_file: + json.dump(payload, temp_file) + + response = ml_client.online_endpoints.invoke( + endpoint_name=endpoint_name, + deployment_name=deployment_name, + request_file="payload.json", + ) + elif args.inference_payload: + response = ml_client.online_endpoints.invoke( + endpoint_name=endpoint_name, + deployment_name=deployment_name, + request_file=args.inference_payload, + ) + + end_time = time.time() + inference_time_ms = int((end_time - start_time) * 1000) + + logger.info(f"Endpoint invoked successfully with inference time :{inference_time_ms} ms " + + f"and response: {response}") + # Save inference response + if args.model_inference_response: + inference_result = { + "response": response, + "inference_time": inference_time_ms + } + with open(args.model_inference_response, "w") as f: + json.dump(inference_result, f, indent=4) + logger.info(f"Saved inference response and inference time to output JSON file: {inference_result}") + except Exception as e: + raise AzureMLException._with_error( + AzureMLError.create(OnlineEndpointInvocationError, exception=e) + ) + + print("Saving deployment details ...") + + # write deployment details to file + endpoint_type = "aml_online_inference" + deployment_details = { + "endpoint_name": endpoint.name, + "deployment_name": deployment.name, + "endpoint_uri": endpoint.__dict__["_scoring_uri"], + "endpoint_type": endpoint_type, + "instance_type": args.instance_type, + "instance_count": args.instance_count, + "max_concurrent_requests_per_instance": args.max_concurrent_requests_per_instance, + } + json_object = json.dumps(deployment_details, indent=4) + with open(args.model_deployment_details, "w") as outfile: + outfile.write(json_object) + logger.info("Saved deployment details in output json file.") + + except Exception as e: + # Capture the full traceback + stack_trace = traceback.format_exc() + error_message = f"Model deployment failed.\n{stack_trace}" + logger.error(f"error_message: {error_message}, deploy_error_path: {args.deploy_error}") + + # Write the error message to the specified error output file + if args.deploy_error: + with open(args.deploy_error, "w") as error_file: + error_file.write(error_message) + + +if __name__ == "__main__": + # run main function + main() diff --git a/assets/training/model_management/components/publish_validation_results_selfserve/asset.yaml b/assets/training/model_management/components/publish_validation_results_selfserve/asset.yaml new file mode 100644 index 0000000000..9d4136ecd3 --- /dev/null +++ b/assets/training/model_management/components/publish_validation_results_selfserve/asset.yaml @@ -0,0 +1,11 @@ +type: component +spec: spec.yaml +categories: + [ + "CommonBench Baselining", + "Benchmarking", + "Run Benchmark", + "Publish Results", + "Self-Serve API", + "API Inferencing" + ] diff --git a/assets/training/model_management/components/publish_validation_results_selfserve/spec.yaml b/assets/training/model_management/components/publish_validation_results_selfserve/spec.yaml new file mode 100644 index 0000000000..db5ab152ff --- /dev/null +++ b/assets/training/model_management/components/publish_validation_results_selfserve/spec.yaml @@ -0,0 +1,61 @@ +$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json +type: command +is_deterministic: true + +name: publish_validation_results_selfserve +version: 0.0.1 +display_name: Publish model validation results to Self-Serve +description: | + This component publishes model validation results to the Self-Serve database. + +environment: azureml://registries/azureml/environments/model-management/versions/47 + +inputs: + selfserve_base_url: + type: string + optional: false + default: "https://int.api.azureml-test.ms" + description: Base URL of the model publisher self-serve API + model_name: + type: string + optional: false + description: Name of the model (e.g., VerboGenie) + model_version: + type: integer + optional: false + description: Model onboarding version (e.g., 5) + publisher_name: + type: string + optional: false + description: Name of the model publisher (e.g., ContosoAI) + sku: + type: string + optional: false + default: "Standard_NC24ads_A100_v4" + description: Suggested SKU based on benchmark results + validation_id: + type: string + optional: false + description: ID of the validation run (used for updating status in self-serve) + metrics_storage_uri: + type: uri_file + optional: true + mode: ro_mount + description: Path to the file containing the validation metrics csv storage path + validation_error: + type: uri_file + optional: true + description: Error message or stack trace from the inference validation step + +code: ../../src + +command: >- + python publish_validation_results_selfserve.py + --selfserve-base-url ${{inputs.selfserve_base_url}} + --model-name ${{inputs.model_name}} + --model-version ${{inputs.model_version}} + --publisher-name ${{inputs.publisher_name}} + --validation-id ${{inputs.validation_id}} + --sku ${{inputs.sku}} + $[[ --metrics-storage-uri ${{inputs.metrics_storage_uri}}]] + $[[ --validation-error ${{inputs.validation_error}}]] \ No newline at end of file diff --git a/assets/training/model_management/components/run_inference_validation/asset.yaml b/assets/training/model_management/components/run_inference_validation/asset.yaml new file mode 100644 index 0000000000..c01772d398 --- /dev/null +++ b/assets/training/model_management/components/run_inference_validation/asset.yaml @@ -0,0 +1,3 @@ +type: component +spec: spec.yaml +categories: ["Model"] diff --git a/assets/training/model_management/components/run_inference_validation/spec.yaml b/assets/training/model_management/components/run_inference_validation/spec.yaml new file mode 100644 index 0000000000..3feba65c60 --- /dev/null +++ b/assets/training/model_management/components/run_inference_validation/spec.yaml @@ -0,0 +1,70 @@ +$schema: https://azuremlschemas.azureedge.net/latest/commandComponent.schema.json +name: run_inference_validation +version: 0.0.1 +type: command + +is_deterministic: True + +display_name: Run Inference Validation +description: Compares the expected inference response with the actual response from model deployment. + +environment: azureml://registries/azureml/environments/model-management/versions/47 + +code: ../../src +command: >- + python run_inference_validation.py + --inference_payload ${{inputs.inference_payload}} + $[[--expected_response ${{inputs.expected_response}}]] + $[[--inference_response ${{inputs.inference_response}}]] + $[[--deployment_error ${{inputs.deployment_error}}]] + --validation-id ${{inputs.validation_id}} + --sku ${{inputs.sku}} + --validation_results ${{outputs.validation_results}} + --metrics_storage_uri ${{outputs.metrics_storage_uri}} + --validation_error ${{outputs.validation_error}} + +inputs: + inference_payload: + type: string + description: JSON input payload used for inference. + + expected_response: + type: string + optional: true + description: JSON file containing the expected inference response. + + inference_response: + type: uri_file + optional: true + description: JSON file containing the actual inference response from the deployed model. + + sku: + type: string + optional: false + default: "Standard_NC24ads_A100_v4" + description: Suggested SKU based on benchmark results + + validation_id: + type: string + optional: false + description: ID of the validation run (used for updating status in self-serve) + + deployment_error: + type: uri_file + optional: true + description: Error message or stack trace from the inference validation step + +outputs: + validation_results: + type: uri_folder + description: JSON file containing the validation results. + metrics_storage_uri: + type: uri_file + description: JSON file containing the validation metrics csv storage path + validation_error: + type: uri_file + description: File containing error messages or stack traces from the validation step. + +tags: + Preview: "" + Internal: "" diff --git a/assets/training/model_management/components/validate_model_inference/asset.yaml b/assets/training/model_management/components/validate_model_inference/asset.yaml new file mode 100644 index 0000000000..c01772d398 --- /dev/null +++ b/assets/training/model_management/components/validate_model_inference/asset.yaml @@ -0,0 +1,3 @@ +type: component +spec: spec.yaml +categories: ["Model"] diff --git a/assets/training/model_management/components/validate_model_inference/spec.yaml b/assets/training/model_management/components/validate_model_inference/spec.yaml new file mode 100644 index 0000000000..49524c34c8 --- /dev/null +++ b/assets/training/model_management/components/validate_model_inference/spec.yaml @@ -0,0 +1,198 @@ +$schema: https://azuremlschemas.azureedge.net/latest/pipelineComponent.schema.json +type: pipeline + +name: validate_model_inference +display_name: Validate Model Inference +description: deploy a model and validate it using a sample payload +version: 0.0.1 + +inputs: + compute: + type: string + optional: true + default: serverless + description: Compute for model deployment and inferencing + + instance_type: + type: string + optional: true + enum: + - Standard_DS1_v2 + - Standard_DS2_v2 + - Standard_DS3_v2 + - Standard_DS4_v2 + - Standard_DS5_v2 + - Standard_F2s_v2 + - Standard_F4s_v2 + - Standard_F8s_v2 + - Standard_F16s_v2 + - Standard_F32s_v2 + - Standard_F48s_v2 + - Standard_F64s_v2 + - Standard_F72s_v2 + - Standard_FX24mds + - Standard_FX36mds + - Standard_FX48mds + - Standard_E2s_v3 + - Standard_E4s_v3 + - Standard_E8s_v3 + - Standard_E16s_v3 + - Standard_E32s_v3 + - Standard_E48s_v3 + - Standard_E64s_v3 + - Standard_NC4as_T4_v3 + - Standard_NC6s_v2 + - Standard_NC6s_v3 + - Standard_NC8as_T4_v3 + - Standard_NC12s_v2 + - Standard_NC12s_v3 + - Standard_NC16as_T4_v3 + - Standard_NC24s_v2 + - Standard_NC24s_v3 + - Standard_NC24rs_v3 + - Standard_NC24ads_A100_v4 + - Standard_NC48ads_A100_v4 + - Standard_NC96ads_A100_v4 + - Standard_NC64as_T4_v3 + - Standard_ND40rs_v2 + - Standard_ND96asr_v4 + - Standard_ND96amsr_A100_v4 + default: Standard_NC6s_v3 + description: Compute instance type to deploy model. Make sure that instance type is available and have enough quota available. + + instance_count: + type: integer + optional: true + default: 1 + description: Number of instances you want to use for deployment. Make sure instance type have enough quota available. + + model_id: + type: string + optional: false + description: | + Asset ID of the model registered in workspace/registry. + Registry - azureml://registries//models//versions/ + Workspace - azureml:: + + environment_id: + type: string + optional: false + description: | + Asset ID of the environment registered in workspace/registry. + Registry - azureml://registries//environments//versions/ + Workspace - azureml:: + + model_name: + type: string + optional: false + description: Name of the model to validate. + + model_version: + type: integer + optional: false + description: Model onboarding version (e.g., 5) + + publisher_name: + type: string + optional: false + description: Name of the model publisher (e.g., ContosoAI) + + selfserve_base_url: + type: string + optional: true + default: "https://int.api.azureml-test.ms" + description: Base URL of the model publisher self-serve API + + sku: + type: string + optional: true + default: "Standard_NC24ads_A100_v4" + description: SKU of the deployed model endpoint. + + inference_payload: + type: string + optional: false + description: JSON payload which would be used to validate deployment + + endpoint_name: + type: string + optional: true + description: Name of the endpoint + + deployment_name: + type: string + optional: true + default: default + description: Name of the deployment + + validation_id: + type: string + optional: true + description: ID of the validation run (used for updating status in self-serve) + + inference_response: + type: string + optional: true + description: JSON file containing the expected inference response. + +# Pipeline outputs +outputs: + validation_results: + description: Output file containing the validation results. + type: uri_folder + +jobs: + online_deployment_model: + type: command + component: azureml:deploy_inference_model:0.0.1 + compute: ${{parent.inputs.compute}} + inputs: + model_id: ${{parent.inputs.model_id}} + environment_id: ${{parent.inputs.environment_id}} + inference_payload_str: ${{parent.inputs.inference_payload}} + endpoint_name: ${{parent.inputs.endpoint_name}} + deployment_name: ${{parent.inputs.deployment_name}} + instance_type: ${{parent.inputs.instance_type}} + instance_count: ${{parent.inputs.instance_count}} + outputs: + model_deployment_details: + type: uri_file + model_inference_response: + type: uri_file + deploy_error: + type: uri_file + + run_inference_validation: + type: command + component: azureml:run_inference_validation:0.0.1 + inputs: + validation_id: ${{parent.inputs.validation_id}} + sku: ${{parent.inputs.instance_type}} + inference_payload: ${{parent.inputs.inference_payload}} + expected_response: ${{parent.inputs.inference_response}} + inference_response: ${{parent.jobs.online_deployment_model.outputs.model_inference_response}} + deployment_error: ${{parent.jobs.online_deployment_model.outputs.deploy_error}} + outputs: + validation_results: ${{parent.outputs.validation_results}} + validation_error: + type: uri_file + + delete_endpoints: + type: command + component: azureml:delete_endpoint:0.0.7 + inputs: + model_deployment_details: ${{parent.jobs.online_deployment_model.outputs.model_deployment_details}} + endpoint_name: ${{parent.inputs.endpoint_name}} + + publish_results: + type: command + component: azureml:publish_validation_results_selfserve:0.0.1 + inputs: + publisher_name: ${{parent.inputs.publisher_name}} + model_name: ${{parent.inputs.model_name}} + model_version: ${{parent.inputs.model_version}} + sku: ${{parent.inputs.instance_type}} + validation_id: ${{parent.inputs.validation_id}} + selfserve_base_url: ${{parent.inputs.selfserve_base_url}} + metrics_storage_uri: ${{parent.jobs.run_inference_validation.outputs.metrics_storage_uri}} + validation_error: ${{parent.jobs.run_inference_validation.outputs.validation_error}} diff --git a/assets/training/model_management/src/azureml/model/mgmt/config.py b/assets/training/model_management/src/azureml/model/mgmt/config.py index 2cd6b5f5d4..114e1dcc03 100644 --- a/assets/training/model_management/src/azureml/model/mgmt/config.py +++ b/assets/training/model_management/src/azureml/model/mgmt/config.py @@ -47,6 +47,8 @@ class AppName: DOWNLOAD_MODEL = "download_model" CONVERT_MODEL_TO_MLFLOW = "convert_model_to_mlflow" VALIDATION_TRIGGER_IMPORT = "validation_trigger_import" + RUN_INFERENCE_VALIDATION = "run_inference_validation" + PUBHLISH_VALIDATION_RESULTS_SELF_SERVE = "publish_validation_results_self_serve" class LoggerConfig: diff --git a/assets/training/model_management/src/publish_validation_results_selfserve.py b/assets/training/model_management/src/publish_validation_results_selfserve.py new file mode 100644 index 0000000000..e6fa644f6e --- /dev/null +++ b/assets/training/model_management/src/publish_validation_results_selfserve.py @@ -0,0 +1,190 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Update model onboarding version with CommonBench results.""" + +import sys +import os +import json +import requests +import argparse +from datetime import datetime, timezone +from azure.identity import ManagedIdentityCredential +from azure.ai.ml.identity import AzureMLOnBehalfOfCredential +from azureml.model.mgmt.config import AppName +from azureml.model.mgmt.utils.logging_utils import custom_dimensions, get_logger + + +logger = get_logger(__name__) +custom_dimensions.app_name = AppName.PUBHLISH_VALIDATION_RESULTS_SELF_SERVE + + +def read_results_from_file(file_path): + """Read the metrics results from the given file path.""" + try: + with open(file_path, 'r') as f: + results_dict = json.load(f) + print(f"Results loaded from {file_path}") + return results_dict + except Exception as e: + print(f"Error reading from file: {e}") + return None + + +def get_auth_token(): + """Generate auth token for Azure API.""" + is_obo = False + tokenUri = "https://management.azure.com/.default" + token = None + + try: + credential = AzureMLOnBehalfOfCredential() + token = credential.get_token(tokenUri).token + is_obo = True + except Exception: + logger.warning( + "Failed to get user credentials, fetching MSI credentials") + + if not is_obo: + try: + msi_client_id = os.environ.get("DEFAULT_IDENTITY_CLIENT_ID") + credential = ManagedIdentityCredential(client_id=msi_client_id) + token = credential.get_token(tokenUri).token + except Exception as ex: + raise Exception(f"Failed to get MSI credentials : {ex}") + + return token + + +def update_model_onboarding_version( + publisher_name, + model_name, + model_version, + sku, + validation_id, + selfserve_base_url, + metrics_storage_uri, + error_message +): + """Update model onboarding version with benchmark results.""" + current_time = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") + if not metrics_storage_uri: + validation_success = False + metrics_url = None + else: + metrics_path_dict = read_results_from_file(metrics_storage_uri) + metrics_url = metrics_path_dict.get( + "api_inference_path") if metrics_path_dict else None + validation_success = metrics_url is not None + + validation_result = [] + logger.info(f"validation_success: {validation_success}, metrics_url: {metrics_url}, metrics_storage_uri: {metrics_storage_uri}") + + if validation_id: + validation_result.append({ + "Id": validation_id, + "type": "API_VALIDATION", + "passed": True, + "message": "API inference passed successfully", + "validationResultUrl": metrics_url, + "errorMessage": error_message if error_message else None, + "status": "Completed" if validation_success else "Failed", + "createdTime": current_time, + "updatedTime": current_time, + "sku": sku + }) + else: + logger.error( + "Validation ID is None, not updating validation results in self-serve") + sys.exit(1) + + payload = { + "passed": True, + "status": "Completed", + "message": "Validation Successful", + "validationResult": validation_result + } + + api_url = ( + f"{selfserve_base_url}/model-publisher-self-serve/publishers/{publisher_name}/models/{model_name}" + f"/model-onboarding-version/{model_version}/updateModelOnboardingVersion?api-version=2024-12-31" + ) + + headers = { + "Authorization": f"Bearer {get_auth_token()}", + "Content-Type": "application/json", + "User-Agent": "AzureML-ModelPublishing/1.0" + } + + try: + logger.info(f"Sending request to {api_url} \n, headers: {headers} \n, payload: {payload}") + + response = requests.put(api_url, headers=headers, json=payload) + + logger.info(f"Response: {response.text}") + + if response.ok: + logger.info( + f"Successfully updated model onboarding version. Response: {response.status_code}") + return {"status_code": response.status_code} + else: + logger.error( + f"Failed to update model onboarding version. Status code: {response.status_code}") + logger.error(f"Response content: {response.text}") + raise Exception( + f"Request failed with status code {response.status_code}: {response.text}") + except requests.RequestException as e: + logger.error(f"Request failed: {e}") + raise + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Update model onboarding version with CommonBench validation results") + + parser.add_argument("--publisher-name", required=True, + help="Name of the model publisher (e.g., ContosoAI)") + parser.add_argument("--model-name", required=True, + help="Name of the model (e.g., VerboGenie)") + parser.add_argument("--model-version", required=True, + help="Model onboarding version (e.g., 5)") + parser.add_argument("--selfserve-base-url", required=True, + default="https://int.api.azureml-test.ms", + help="Base URL of the model publisher self-serve API") + parser.add_argument("--validation-id", required=True, + help="Run ID of the validation run") + parser.add_argument("--metrics-storage-uri", required=False, + help="URI to the storage where validation metrics are stored") + parser.add_argument("--sku", required=False, + default="Standard_NC24ads_A100_v4", + help="Suggested SKU based on benchmark results") + parser.add_argument("--validation-error", required=False, + help="Path to the file containing validation error messages or stack traces") + + args = parser.parse_args() + logger.info(f"Arguments: {args}") + + error_message = "" + if args.validation_error: + try: + with open(args.validation_error, "r") as f: + validation_error_message = f.read().strip() + error_message += f"Validation Error: {validation_error_message}\n" + except Exception as e: + logger.warning(f"Failed to read validation_error file: {e}") + + try: + result = update_model_onboarding_version( + args.publisher_name, + args.model_name, + args.model_version, + args.sku, + args.validation_id, + args.selfserve_base_url, + args.metrics_storage_uri, + error_message + ) + logger.info("Model onboarding version update completed successfully") + except Exception as e: + logger.error(f"Failed to update model onboarding version: {e}") + sys.exit(1) diff --git a/assets/training/model_management/src/run_inference_validation.py b/assets/training/model_management/src/run_inference_validation.py new file mode 100644 index 0000000000..afc9df212f --- /dev/null +++ b/assets/training/model_management/src/run_inference_validation.py @@ -0,0 +1,422 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Validate the structure of expected and actual inference response JSON files.""" + +import base64 +import json +import argparse +import os +import sys +import traceback +import re +from datetime import datetime, timezone +from azureml.core import Run +from azureml.model.mgmt.utils.common_utils import get_mlclient +from azureml.model.mgmt.config import AppName +from azureml.model.mgmt.utils.logging_utils import custom_dimensions, get_logger + + +logger = get_logger(__name__) +custom_dimensions.app_name = AppName.RUN_INFERENCE_VALIDATION + + +def load_json(file_path): + """Load JSON data from a file. If the loaded data is a string, try to parse it as JSON.""" + try: + with open(file_path, "r") as f: + data = json.load(f) + # If data is a string, parse it as JSON. + if isinstance(data, str): + try: + data = json.loads(data) + except Exception as e: + logger.warning(f"Error parsing JSON from string in {file_path}: {e}") + return data + except Exception as e: + logger.warning(f"Error loading JSON file {file_path}: {e}") + return None + + +def load_json_from_string(json_string): + """Load JSON data from a string.""" + try: + data = json.loads(json_string) + return data + except Exception as e: + logger.warning(f"Error parsing JSON from string: {e}") + return None + + +def set_nested_value(d, keys, value): + """ + Helper to set a value into a nested dictionary/list from a list of keys/indexes. + """ + for i, key in enumerate(keys): + is_last = i == len(keys) - 1 + if isinstance(key, int): + while len(d) <= key: + d.append({} if not is_last else None) + if is_last: + d[key] = value + else: + if not isinstance(d[key], (dict, list)): + d[key] = {} + d = d[key] + else: + if key not in d or not isinstance(d[key], (dict, list)): + d[key] = {} if not is_last else None + if is_last: + d[key] = value + else: + d = d[key] + +def parse_key_path(key): + """ + Converts a key string like '[0].a.b[1]' to a list of keys: [0, 'a', 'b', 1] + """ + parts = re.findall(r'\[(\d+)\]|([^.]+)', key) + return [int(i) if i else j for i, j in parts] + +def build_nested_json(flat_dict): + """ + Converts a flat key-path dictionary to nested JSON. + """ + result = {} if flat_dict else None + for key_path, value in flat_dict.items(): + keys = parse_key_path(key_path) + if isinstance(keys[0], int): + if not isinstance(result, list): + result = [] + set_nested_value(result, keys, value) + return result + +def get_json_structure_with_values(data, parent_key=''): + """ + Recursively extract key paths and their values from nested JSON. + Returns a dictionary of full_key_path: value + """ + items = {} + if isinstance(data, dict): + for k, v in data.items(): + full_key = f"{parent_key}.{k}" if parent_key else k + if isinstance(v, (dict, list)): + items.update(get_json_structure_with_values(v, full_key)) + else: + items[full_key] = v + elif isinstance(data, list): + for index, item in enumerate(data): + full_key = f"{parent_key}[{index}]" if parent_key else f"[{index}]" + if isinstance(item, (dict, list)): + items.update(get_json_structure_with_values(item, full_key)) + else: + items[full_key] = item + return items + +def compare_structures(expected_response, actual_response): + """ + Compare JSON structures and return full nested added/removed diffs. + """ + expected_structure = get_json_structure_with_values(expected_response) + actual_structure = get_json_structure_with_values(actual_response) + + logger.info(f"Expected flat structure: {expected_structure}") + logger.info(f"Actual flat structure: {actual_structure}") + + added_keys = actual_structure.keys() - expected_structure.keys() + removed_keys = expected_structure.keys() - actual_structure.keys() + + added_flat = {key: actual_structure[key] for key in added_keys} + removed_flat = {key: expected_structure[key] for key in removed_keys} + + added_nested = build_nested_json(added_flat) + removed_nested = build_nested_json(removed_flat) + + structure_match = not added_flat and not removed_flat + + result = { + "structure_match": structure_match, + "structural_difference": { + "added": added_nested, + "removed": removed_nested + } + } + + logger.info("Comparison result:") + logger.info(json.dumps(result, indent=4)) + + return result + +def save_validation_result(request_details, output_dir, validation_id, sku, status): + """Save validation results to a JSON file.""" + try: + logger.info(f"Saving validation result to {output_dir}") + # Create the output directory if it doesn't exist + os.makedirs(output_dir, exist_ok=True) + logger.info(f"Output directory: {output_dir}") + output_path = os.path.join(output_dir, "validation_result.json") + logger.info(f"Output path: {output_path}") + + current_time = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") + validation_result = { + "id": validation_id, + "sku": sku, + "createdTime": current_time, + "updatedTime": current_time, + "type": "MAAP_INFERENCING", + "status": status, + "requestDetails": request_details + } + + with open(output_path, "w") as f: + json.dump(validation_result, f, indent=4) + logger.info(f"Validation result saved to {output_path}") + except Exception as e: + logger.error(f"Error saving validation result: {e}") + raise Exception(f"Failed to get MSI credentials : {e}") + + +def replace_name_in_path(path_template, name_value): + """Replace the placeholder in the output path with the actual job name.""" + return path_template.replace('${{name}}', name_value) + + +def fetch_storage_uri(): + """Return the storage URI of the output file from the AzureML pipeline run.""" + try: + run = Run.get_context() + run_details = run.get_details() + output_data = run_details['runDefinition']['outputData']['validation_results']['outputLocation']['uri'] + output_data_path = output_data['path'] + + output_data_uri = replace_name_in_path(output_data_path, run.id) + # Extract datastore name and path from the AzureML URI + datastore_name, path = extract_datastore_info(output_data_uri) + + # Construct the storage URI + storage_uri = get_storage_url(datastore_name) + folder_uri = f"{storage_uri}/{path}" + # Construct the full path to the validation_result.json file + full_file_uri = f"{folder_uri}/validation_result.json" + + logger.info(f"Full storage URI (file): {full_file_uri}") + + return full_file_uri # This is the full path to validation_result.json + except Exception as e: + logger.error(f"Error fetching storage URI: {e}") + return None + + +def store_metrics_paths(metrics_file_path): + """Store the paths of the metrics CSV files in a JSON file.""" + base_path = fetch_storage_uri() + + logger.info(f"validation_result_path: {base_path}") + result_dict = {} + result_dict['api_inference_path'] = base_path + if result_dict: + write_results_to_file(result_dict, metrics_file_path) + + +def fetch_path(output_dir): + """Return the relative path of the data from the output directory.""" + try: + # Calculate relative path from the job folder + rel_path = os.path.relpath(output_dir, os.getcwd()) + logger.info(f"api inference validation relative path: {rel_path}") + result_dict = { + 'api_inference_path': rel_path + } + return result_dict + except Exception as e: + logger.error(f"Error calculating relative path: {e}") + return {} + + +def write_results_to_file(results_dict, file_path): + """Write the results dictionary to a JSON file.""" + try: + with open(file_path, 'w') as f: + json.dump(results_dict, f, indent=4) + logger.info(f"Results written to {file_path} in JSON format") + return True + except Exception as e: + logger.error(f"Error writing to file: {e}") + return False + + +def get_storage_url(datastore_name): + """Retrieve the storage URL for the specified datastore.""" + # Get MLClient instance + ml_client = get_mlclient() + datastore = ml_client.datastores.get(datastore_name) + storage_account_name = datastore.account_name + container_name = datastore.container_name + endpoint = datastore.endpoint + + storage_uri = f"https://{storage_account_name}.blob.{endpoint}/{container_name}" + logger.info(f"validation result storage: {storage_uri}") + + return storage_uri + + +def extract_datastore_info(datastore_uri_path): + """Extract both datastore name and path from an Azure ML datastore URI path.""" + # Check if it's a valid datastore URI + if not datastore_uri_path.startswith('azureml://datastores/'): + return None, None + + parts = datastore_uri_path.split('/') + + # The datastore name should be the part after 'datastores/' + if len(parts) >= 5 and parts[0] == 'azureml:' and parts[1] == '' and parts[2] == 'datastores' and 'paths' in parts: + datastore_name = parts[3] + + # Find the index of 'paths' in the URI + paths_index = parts.index('paths') + + # Join everything after 'paths/' to form the path + path = '/'.join(parts[(paths_index + 1):]) + + return datastore_name, path + + return None, None + + +def run_inference_validation(): + """Perform the inference validation logic.""" + try: + args = parse_args() + error_message = "" + if args.deployment_error: + try: + with open(args.deployment_error, "r") as f: + deployment_error = f.read().strip() + error_message += deployment_error + except Exception as e: + logger.warning(f"Failed to read deployment_error file: {e}") + + if args.validation_error: + with open(args.validation_error, "w") as error_file: + error_file.write(error_message) + inference_payload = None + if args.inference_payload: + decoded_bytes = base64.b64decode(args.inference_payload) + + # Convert bytes to string + decoded_str = decoded_bytes.decode('utf-8') + logger.info(f"Decoded string: {decoded_str}") + + inference_payload = json.loads(decoded_str) + + expected_response = None + if args.expected_response: + decoded_bytes = base64.b64decode(args.expected_response) + + # Convert bytes to string + decoded_str = decoded_bytes.decode('utf-8') + logger.info(f"Decoded string: {decoded_str}") + expected_response = json.loads(decoded_str) + + inference_output = None + if args.inference_response: + inference_output = load_json(args.inference_response) + if not inference_output: + logger.error("Inference response is missing or invalid.") + + inference_response = None + if inference_output: + inference_response = inference_output.get("response") + if isinstance(inference_response, str): + try: + inference_response = json.loads(inference_response) + except json.JSONDecodeError as e: + logger.warning(f"Failed to parse actual response as JSON: {e}") + + if inference_response is None: + logger.warning("Actual response is missing or invalid. Setting it to an empty structure.") + inference_response = {} + + inference_time = inference_output.get("inference_time", 0) if inference_output else 0 + logger.info(f"inference_payload: {inference_payload}, expected response: {expected_response}, " + f"actual response: {inference_response}") + + # Infer success status based on the presence of a valid response + success_status = inference_response is not None and bool(inference_response) + status = "Success" if success_status else "Failed" + + request_details = { + "providedRequest": inference_payload, + "providedResponse": expected_response, + "actualResponse": inference_response, + "responseTimeMs": inference_time, + "errorMessage": error_message, + "structuralDiff": None, + } + logger.info(f"Request details: {request_details}") + if expected_response and inference_response: + comparison_result = compare_structures(expected_response, inference_response) + request_details["structuralDiff"] = comparison_result.get("structural_difference", []) + + # Save the validation result. + save_validation_result(request_details, args.validation_results, args.validation_id, args.sku, status) + logger.info(f"validation_result: {request_details}, Validation result saved to {args.validation_results}") + + store_metrics_paths(args.metrics_storage_uri) + except Exception as e: + stack_trace = traceback.format_exc() + error_message = f"Model validation failed.\n{stack_trace}" + logger.error(error_message) + # Save the error message in the request details + request_details = { + "providedRequest": None, + "providedResponse": None, + "actualResponse": None, + "responseTimeMs": 0, + "errorMessage": error_message, + "structuralDiff": None, + } + + # Save the validation result with the error message + save_validation_result(request_details, args.validation_results, args.validation_id, args.sku, "Failed") + store_metrics_paths(args.metrics_storage_uri) + + # Write the error message to the specified error output file + if args.validation_error: + with open(args.validation_error, "w") as error_file: + error_file.write(error_message) + +def main(): + run_inference_validation() + + +def parse_args(): + """Compare expected and actual inference response structures.""" + parser = argparse.ArgumentParser() + parser.add_argument("--inference_payload", type=str, required=True, + help="Serialized JSON payload for inference") + parser.add_argument("--expected_response", type=str, required=False, + help="Path to the expected inference response JSON file.") + parser.add_argument("--inference_response", type=str, required=False, + help="Path to the actual inference response JSON file.") + parser.add_argument("--deployment_error", type=str, required=False, + help="Path to the deployment_error.") + parser.add_argument("--validation_results", type=str, required=True, + help="Path to save validation results.") + parser.add_argument("--metrics_storage_uri", type=str, required=True, + help="Path to store the metrics.") + parser.add_argument("--sku", required=False, + default="Standard_NC24ads_A100_v4", + help="Suggested SKU based on benchmark results") + parser.add_argument("--validation-id", required=True, + help="Run ID of the validation run") + parser.add_argument("--validation_error", type=str, required=False, + help="Path to the file where error messages or stack traces will be written.") + + args = parser.parse_args() + logger.info(f"Arguments: {args}") + return args + + +if __name__ == "__main__": + main()