diff --git a/src/model_api/adapters/ovms_adapter.py b/src/model_api/adapters/ovms_adapter.py index df0581a8..52b99363 100644 --- a/src/model_api/adapters/ovms_adapter.py +++ b/src/model_api/adapters/ovms_adapter.py @@ -1,12 +1,12 @@ # -# Copyright (C) 2020-2024 Intel Corporation +# Copyright (C) 2020-2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 # from __future__ import annotations import re -from typing import Any, Callable +from typing import Any, Callable, Final import numpy as np @@ -17,6 +17,26 @@ class OVMSAdapter(InferenceAdapter): """Inference adapter that allows working with models served by the OpenVINO Model Server""" + # OVMS model URL regex pattern + # Expected format:
<address>:<port>/v2/models/<model_name>[/versions/<model_version>] + _OVMS_MODEL_URL_PATTERN: Final = re.compile( + r"(.+)\/v2\/models\/([^\/]+)(?:(?:\/versions\/)(\d+))?(?:\/)?", + ) + + # Triton to NumPy precision mapping + _triton2np_precision: Final = { + "INT64": np.int64, + "UINT64": np.uint64, + "FLOAT": np.float32, + "UINT32": np.uint32, + "INT32": np.int32, + "HALF": np.float16, + "INT16": np.int16, + "INT8": np.int8, + "UINT8": np.uint8, + "FP32": np.float32, + } + def __init__(self, target_model: str): """ Initializes OVMS adapter. @@ -26,7 +46,7 @@ def __init__(self, target_model: str): """ import tritonclient.http as httpclient - service_url, self.model_name, self.model_version = _parse_model_arg( + service_url, self.model_name, self.model_version = self.parse_model_arg( target_model, ) self.client = httpclient.InferenceServerClient(service_url) @@ -83,7 +103,7 @@ def infer_sync(self, dict_data: dict) -> dict: Returns: dict: model raw outputs. """ - inputs = _prepare_inputs(dict_data, self.inputs) + inputs = self._prepare_inputs(dict_data) raw_result = self.client.infer( model_name=self.model_name, model_version=self.model_version, @@ -98,7 +118,7 @@ def infer_sync(self, dict_data: dict) -> dict: def infer_async(self, dict_data: dict, callback_data: Any): """A stub method imitating async inference with a blocking call.""" - inputs = _prepare_inputs(dict_data, self.inputs) + inputs = self._prepare_inputs(dict_data) raw_result = self.client.infer( model_name=self.model_name, model_version=self.model_version, @@ -166,65 +186,49 @@ def save_model(self, path: str, weights_path: str | None = None, version: str | msg = "OVMSAdapter does not support saving a model" raise NotImplementedError(msg) + @staticmethod + def is_ovms_model(target_model: str) -> bool: + """Checks if the given string is a valid OVMS model URL.""" + if not isinstance(target_model, str): + return False -_triton2np_precision = { - "INT64": np.int64, - "UINT64": np.uint64, - "FLOAT": np.float32, - "UINT32": np.uint32, - "INT32": 
np.int32, - "HALF": np.float16, - "INT16": np.int16, - "INT8": np.int8, - "UINT8": np.uint8, - "FP32": np.float32, -} - - -def _parse_model_arg(target_model: str): - """Parses OVMS model URL.""" - if not isinstance(target_model, str): - msg = "target_model must be str" - raise TypeError(msg) - # Expected format:
<address>:<port>/models/<model_name>[:<model_version>] - if not re.fullmatch( - r"(\w+\.*\-*)*\w+:\d+\/v2/models\/[a-zA-Z0-9._-]+(\:\d+)*", - target_model, - ): - msg = "invalid --model option format" - raise ValueError(msg) - service_url, _, _, model = target_model.split("/") - model_spec = model.split(":") - if len(model_spec) == 1: - # model version not specified - use latest - return service_url, model_spec[0], "" - if len(model_spec) == 2: - return service_url, model_spec[0], model_spec[1] - msg = "Invalid target_model format" - raise ValueError(msg) - - -def _prepare_inputs(dict_data: dict, inputs_meta: dict[str, Metadata]): - """Converts raw model inputs into OVMS-specific representation.""" - import tritonclient.http as httpclient - - inputs = [] - for input_name, input_data in dict_data.items(): - if input_name not in inputs_meta: - msg = "Input data does not match model inputs" + return OVMSAdapter._OVMS_MODEL_URL_PATTERN.fullmatch(target_model) is not None + + @staticmethod + def parse_model_arg(target_model: str): + """Parses OVMS model URL.""" + if not isinstance(target_model, str): + msg = "target_model must be str" + raise TypeError(msg) + + match = OVMSAdapter._OVMS_MODEL_URL_PATTERN.fullmatch(target_model) + if not match: + msg = "invalid --model option format" raise ValueError(msg) - input_info = inputs_meta[input_name] - model_precision = _triton2np_precision[input_info.precision] - if isinstance(input_data, np.ndarray) and input_data.dtype != model_precision: - input_data = input_data.astype(model_precision) - elif isinstance(input_data, list): - input_data = np.array(input_data, dtype=model_precision) - - infer_input = httpclient.InferInput( - input_name, - input_data.shape, - input_info.precision, - ) - infer_input.set_data_from_numpy(input_data) - inputs.append(infer_input) - return inputs + + return match.group(1), match.group(2), match.group(3) or "" + + def _prepare_inputs(self, dict_data: dict) -> list: + """Converts raw model inputs into OVMS-specific representation.""" + import 
tritonclient.http as httpclient + + inputs = [] + for input_name, input_data in dict_data.items(): + if input_name not in self.inputs: + msg = "Input data does not match model inputs" + raise ValueError(msg) + input_info = self.inputs[input_name] + model_precision = self._triton2np_precision[input_info.precision] + if isinstance(input_data, np.ndarray) and input_data.dtype != model_precision: + input_data = input_data.astype(model_precision) + elif isinstance(input_data, list): + input_data = np.array(input_data, dtype=model_precision) + + infer_input = httpclient.InferInput( + input_name, + input_data.shape, + input_info.precision, + ) + infer_input.set_data_from_numpy(input_data) + inputs.append(infer_input) + return inputs diff --git a/src/model_api/models/model.py b/src/model_api/models/model.py index ccabdc87..47168dac 100644 --- a/src/model_api/models/model.py +++ b/src/model_api/models/model.py @@ -6,7 +6,6 @@ from __future__ import annotations import logging as log -import re from contextlib import contextmanager from typing import TYPE_CHECKING, Any, Callable, NoReturn, Type @@ -173,9 +172,7 @@ def create_model( inference_adapter: InferenceAdapter if isinstance(model, InferenceAdapter): inference_adapter = model - elif isinstance(model, str) and re.compile( - r"(\w+\.*\-*)*\w+:\d+\/v2/models\/[a-zA-Z0-9._-]+(\:\d+)*", - ).fullmatch(model): + elif isinstance(model, str) and OVMSAdapter.is_ovms_model(model): inference_adapter = OVMSAdapter(model) else: if core is None: diff --git a/tests/unit/adapters/test_ovms_adapter.py b/tests/unit/adapters/test_ovms_adapter.py new file mode 100644 index 00000000..fe308fc8 --- /dev/null +++ b/tests/unit/adapters/test_ovms_adapter.py @@ -0,0 +1,76 @@ +# +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +import pytest + +from model_api.adapters.ovms_adapter import OVMSAdapter + + +class TestParseModelArg: + """Test cases for the parse_model_arg function.""" + + def 
test_valid_url_with_version(self): + """Test parsing a valid URL with version specified.""" + target_model = "http://localhost:9000/v2/models/my_model/versions/123" + service_url, model_name, version = OVMSAdapter.parse_model_arg(target_model) + + assert service_url == "http://localhost:9000" + assert model_name == "my_model" + assert version == "123" + + def test_valid_url_without_version(self): + """Test parsing a valid URL without version specified.""" + target_model = "http://localhost:9000/v2/models/345$%^!@#$model" + service_url, model_name, version = OVMSAdapter.parse_model_arg(target_model) + + assert service_url == "http://localhost:9000" + assert model_name == "345$%^!@#$model" + assert version == "" + + def test_valid_url_with_trailing_slash(self): + """Test parsing a valid URL with trailing slash.""" + target_model = "http://localhost:9000/v2/models/my_model/" + service_url, model_name, version = OVMSAdapter.parse_model_arg(target_model) + + assert service_url == "http://localhost:9000" + assert model_name == "my_model" + assert version == "" + + def test_valid_url_with_version_and_trailing_slash(self): + """Test parsing a valid URL with version and trailing slash.""" + target_model = "http://localhost:9000/v2/models/my_model/versions/456/" + service_url, model_name, version = OVMSAdapter.parse_model_arg(target_model) + + assert service_url == "http://localhost:9000" + assert model_name == "my_model" + assert version == "456" + + def test_valid_url_https(self): + """Test parsing a valid HTTPS URL.""" + target_model = "https://example.com:8080/v2/models/test_model/versions/1" + service_url, model_name, version = OVMSAdapter.parse_model_arg(target_model) + + assert service_url == "https://example.com:8080" + assert model_name == "test_model" + assert version == "1" + + @pytest.mark.parametrize( + ("target_model", "description"), + [ + ("http://localhost:9000/models/my_model", "missing v2/models path"), + 
("http://localhost:9000/v2/models/my_model/version/123", "wrong versions format"), + ("http://localhost:9000/v2/models//versions/123", "empty model name"), + ("http://localhost:9000/v2/models/", "no model name"), + ("http://localhost:9000/v2", "incomplete URL"), + ("http://localhost:9000/v2/models/my_model/versions/latest", "non-numeric version"), + ("http://localhost:9000/v2/models/my_model/extra/path", "extra path"), + ("http://localhost:9000/v2/models/my_model/versions/", "no version specified"), + ("", "empty"), + ], + ) + def test_invalid_url_formats(self, target_model, description): + """Test parsing various invalid URL formats.""" + with pytest.raises(ValueError, match="invalid --model option format"): + OVMSAdapter.parse_model_arg(target_model) diff --git a/tests/unit/test_utils.py b/tests/unit/adapters/test_utils.py similarity index 100% rename from tests/unit/test_utils.py rename to tests/unit/adapters/test_utils.py diff --git a/tests/unit/test_types.py b/tests/unit/models/test_types.py similarity index 100% rename from tests/unit/test_types.py rename to tests/unit/models/test_types.py