diff --git a/assets/training/model_management/components/convert_model_to_mlflow/spec.yaml b/assets/training/model_management/components/convert_model_to_mlflow/spec.yaml index 4a23c598f2..3ce408b29e 100644 --- a/assets/training/model_management/components/convert_model_to_mlflow/spec.yaml +++ b/assets/training/model_management/components/convert_model_to_mlflow/spec.yaml @@ -13,6 +13,7 @@ environment: azureml://registries/azureml/environments/model-management/versions code: ../../src/ command: | + # pip install transformers==4.48.0 torch==2.2.2 numpy==1.23.5 pandas==1.5.3 urllib3==1.26.19 --no-cache-dir --force-reinstall # TODO: This has one disadvantage as shell logs wont be pushed to appinsights set -ex IFS=',' read -ra pip_pkgs <<< "$[[${{inputs.extra_pip_requirements}}]]" diff --git a/assets/training/model_management/components/validation_trigger_import/spec.yaml b/assets/training/model_management/components/validation_trigger_import/spec.yaml index a21f5b8d4c..2fcd5f71a0 100644 --- a/assets/training/model_management/components/validation_trigger_import/spec.yaml +++ b/assets/training/model_management/components/validation_trigger_import/spec.yaml @@ -181,7 +181,7 @@ outputs: is_deterministic: True -environment: azureml://registries/azureml/environments/python-sdk-v2/versions/25 +environment: azureml://registries/azureml/environments/python-sdk-v2/versions/29 code: ../../src command: python run_model_validate.py --validation-info ${{outputs.validation_info}} diff --git a/assets/training/model_management/environments/model-management/context/Dockerfile b/assets/training/model_management/environments/model-management/context/Dockerfile index 9a3d6bda04..e7f68ed9bb 100644 --- a/assets/training/model_management/environments/model-management/context/Dockerfile +++ b/assets/training/model_management/environments/model-management/context/Dockerfile @@ -26,6 +26,7 @@ RUN pip install -r requirements.txt --no-cache-dir # Vulnerability fix RUN pip install Pillow gunicorn onnx==1.17.0 idna tqdm requests==2.32.1 tornado==6.4.2 certifi==2024.07.04 urllib3==1.26.19 scikit-learn==1.5.1 mlflow==2.20.3 mlflow-skinny==2.20.3 marshmallow==3.23.2 +transformers==4.48.0 torch==2.2.2 numpy==1.23.5 pandas==1.5.3 azureml-evaluate-mlflow --no-cache-dir --force-reinstall # List pip packages RUN pip list diff --git a/assets/training/model_management/src/azureml/model/mgmt/processors/pyfunc/convertors.py b/assets/training/model_management/src/azureml/model/mgmt/processors/pyfunc/convertors.py index 0fcc355498..f35ce01ef0 100644 --- a/assets/training/model_management/src/azureml/model/mgmt/processors/pyfunc/convertors.py +++ b/assets/training/model_management/src/azureml/model/mgmt/processors/pyfunc/convertors.py @@ -45,7 +45,7 @@ MLflowSchemaLiterals as VirchowMLFlowSchemaLiterals, MLflowLiterals as VirchowMLflowLiterals from azureml.model.mgmt.processors.pyfunc.hibou_b.config import \ MLflowSchemaLiterals as HibouBMLFlowSchemaLiterals, MLflowLiterals as HibouBMLflowLiterals - +from azureml.model.mgmt.utils.common_utils import get_mlclient logger = get_logger(__name__) @@ -114,6 +114,12 @@ def _save( :type metadata: Optional[Dict]. Defaults to {}. """ signatures = self._signatures or self.get_model_signature() + if not self._vllm_enabled: + mlclient = get_mlclient("azureml") + mlFlow_image = mlclient.environments.get("mlflow-model-inference", label="latest") + metadata["azureml.base_image"] = "mcr.microsoft.com/azureml/curated/mlflow-model-inference:" \ + + str(mlFlow_image.version) + logger.info("Metadata: {}".format(metadata)) # set metadata info metadata.update(fetch_mlflow_acft_metadata( base_model_name=self._model_id, diff --git a/assets/training/model_management/src/azureml/model/mgmt/processors/transformers/convertors.py b/assets/training/model_management/src/azureml/model/mgmt/processors/transformers/convertors.py index b650f13555..3667560e11 100644 --- a/assets/training/model_management/src/azureml/model/mgmt/processors/transformers/convertors.py +++ b/assets/training/model_management/src/azureml/model/mgmt/processors/transformers/convertors.py @@ -183,6 +183,12 @@ def _save( metadata["azureml.base_image"] = "mcr.microsoft.com/azureml/curated/foundation-model-inference:" \ + str(vllm_image.version) logger.info("Metadata: {}".format(metadata)) + else: + mlclient = get_mlclient("azureml") + mlFlow_image = mlclient.environments.get("mlflow-model-inference", label="latest") + metadata["azureml.base_image"] = "mcr.microsoft.com/azureml/curated/mlflow-model-inference:" \ + + str(mlFlow_image.version) + logger.info("Metadata: {}".format(metadata)) if self._model_flavor == "OSS": try: