From 95ca09c6e19d993afae440bb5b16a1e9685e6ae2 Mon Sep 17 00:00:00 2001 From: Mariusz Gumowski Date: Tue, 26 Aug 2025 09:55:17 +0200 Subject: [PATCH 1/3] fix regex format for model_api --- src/model_api/adapters/ovms_adapter.py | 21 +++---- src/model_api/models/model.py | 2 +- tests/unit/adapters/test_ovms_adapter.py | 76 ++++++++++++++++++++++++ tests/unit/{ => adapters}/test_utils.py | 0 tests/unit/{ => models}/test_types.py | 0 5 files changed, 85 insertions(+), 14 deletions(-) create mode 100644 tests/unit/adapters/test_ovms_adapter.py rename tests/unit/{ => adapters}/test_utils.py (100%) rename tests/unit/{ => models}/test_types.py (100%) diff --git a/src/model_api/adapters/ovms_adapter.py b/src/model_api/adapters/ovms_adapter.py index df0581a8..28b60658 100644 --- a/src/model_api/adapters/ovms_adapter.py +++ b/src/model_api/adapters/ovms_adapter.py @@ -186,22 +186,17 @@ def _parse_model_arg(target_model: str): if not isinstance(target_model, str): msg = "target_model must be str" raise TypeError(msg) - # Expected format:
<address>:<port>/models/<model_name>[:<model_version>] - if not re.fullmatch( - r"(\w+\.*\-*)*\w+:\d+\/v2/models\/[a-zA-Z0-9._-]+(\:\d+)*", + + # Expected format:
<address>:<port>/v2/models/<model_name>[/versions/<model_version>] + match = re.fullmatch( + r"(.+)\/v2\/models\/([^\/]+)(?:(?:\/versions\/)(\d+))?(?:\/)?", target_model, - ): + ) + if not match: msg = "invalid --model option format" raise ValueError(msg) - service_url, _, _, model = target_model.split("/") - model_spec = model.split(":") - if len(model_spec) == 1: - # model version not specified - use latest - return service_url, model_spec[0], "" - if len(model_spec) == 2: - return service_url, model_spec[0], model_spec[1] - msg = "Invalid target_model format" - raise ValueError(msg) + + return match.group(1), match.group(2), match.group(3) or "" def _prepare_inputs(dict_data: dict, inputs_meta: dict[str, Metadata]): diff --git a/src/model_api/models/model.py b/src/model_api/models/model.py index ccabdc87..5dc4b1b1 100644 --- a/src/model_api/models/model.py +++ b/src/model_api/models/model.py @@ -174,7 +174,7 @@ def create_model( if isinstance(model, InferenceAdapter): inference_adapter = model elif isinstance(model, str) and re.compile( - r"(\w+\.*\-*)*\w+:\d+\/v2/models\/[a-zA-Z0-9._-]+(\:\d+)*", + r"(.+)\/v2\/models\/([^\/]+)(?:(?:\/versions\/)(\d+))?(?:\/)?", ).fullmatch(model): inference_adapter = OVMSAdapter(model) else: diff --git a/tests/unit/adapters/test_ovms_adapter.py b/tests/unit/adapters/test_ovms_adapter.py new file mode 100644 index 00000000..35acc5c4 --- /dev/null +++ b/tests/unit/adapters/test_ovms_adapter.py @@ -0,0 +1,76 @@ +# +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +import pytest +from model_api.adapters.ovms_adapter import _parse_model_arg + + +class TestParseModelArg: + """Test cases for the _parse_model_arg function.""" + + def test_valid_url_with_version(self): + """Test parsing a valid URL with version specified.""" + target_model = "http://localhost:9000/v2/models/my_model/versions/123" + service_url, model_name, version = _parse_model_arg(target_model) + + assert service_url == "http://localhost:9000" + assert model_name == 
"my_model" + assert version == "123" + + def test_valid_url_without_version(self): + """Test parsing a valid URL without version specified.""" + target_model = "http://localhost:9000/v2/models/345$%^!@#$model" + service_url, model_name, version = _parse_model_arg(target_model) + + assert service_url == "http://localhost:9000" + assert model_name == "345$%^!@#$model" + assert version == "" + + def test_valid_url_with_trailing_slash(self): + """Test parsing a valid URL with trailing slash.""" + target_model = "http://localhost:9000/v2/models/my_model/" + service_url, model_name, version = _parse_model_arg(target_model) + + assert service_url == "http://localhost:9000" + assert model_name == "my_model" + assert version == "" + + def test_valid_url_with_version_and_trailing_slash(self): + """Test parsing a valid URL with version and trailing slash.""" + target_model = "http://localhost:9000/v2/models/my_model/versions/456/" + service_url, model_name, version = _parse_model_arg(target_model) + + assert service_url == "http://localhost:9000" + assert model_name == "my_model" + assert version == "456" + + def test_valid_url_https(self): + """Test parsing a valid HTTPS URL.""" + target_model = "https://example.com:8080/v2/models/test_model/versions/1" + service_url, model_name, version = _parse_model_arg(target_model) + + assert service_url == "https://example.com:8080" + assert model_name == "test_model" + assert version == "1" + + + @pytest.mark.parametrize( + "target_model,description", + [ + ("http://localhost:9000/models/my_model", "missing v2/models path"), + ("http://localhost:9000/v2/models/my_model/version/123", "wrong versions format"), + ("http://localhost:9000/v2/models//versions/123", "empty model name"), + ("http://localhost:9000/v2/models/", "no model name"), + ("http://localhost:9000/v2", "incomplete URL"), + ("http://localhost:9000/v2/models/my_model/versions/latest", "non-numeric version"), + ("http://localhost:9000/v2/models/my_model/extra/path", "extra 
path"), + ("http://localhost:9000/v2/models/my_model/versions/", "no version specified"), + ("", "empty"), + ], + ) + def test_invalid_url_formats(self, target_model, description): + """Test parsing various invalid URL formats.""" + with pytest.raises(ValueError, match="invalid --model option format"): + _parse_model_arg(target_model) diff --git a/tests/unit/test_utils.py b/tests/unit/adapters/test_utils.py similarity index 100% rename from tests/unit/test_utils.py rename to tests/unit/adapters/test_utils.py diff --git a/tests/unit/test_types.py b/tests/unit/models/test_types.py similarity index 100% rename from tests/unit/test_types.py rename to tests/unit/models/test_types.py From 4e42c666f837eaa1b93d383949ef9bbdea99e01b Mon Sep 17 00:00:00 2001 From: Mariusz Gumowski Date: Tue, 26 Aug 2025 09:57:27 +0200 Subject: [PATCH 2/3] fix style --- tests/unit/adapters/test_ovms_adapter.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/unit/adapters/test_ovms_adapter.py b/tests/unit/adapters/test_ovms_adapter.py index 35acc5c4..e6802666 100644 --- a/tests/unit/adapters/test_ovms_adapter.py +++ b/tests/unit/adapters/test_ovms_adapter.py @@ -4,6 +4,7 @@ # import pytest + from model_api.adapters.ovms_adapter import _parse_model_arg @@ -14,7 +15,7 @@ def test_valid_url_with_version(self): """Test parsing a valid URL with version specified.""" target_model = "http://localhost:9000/v2/models/my_model/versions/123" service_url, model_name, version = _parse_model_arg(target_model) - + assert service_url == "http://localhost:9000" assert model_name == "my_model" assert version == "123" @@ -23,7 +24,7 @@ def test_valid_url_without_version(self): """Test parsing a valid URL without version specified.""" target_model = "http://localhost:9000/v2/models/345$%^!@#$model" service_url, model_name, version = _parse_model_arg(target_model) - + assert service_url == "http://localhost:9000" assert model_name == "345$%^!@#$model" assert version == "" @@ 
-32,7 +33,7 @@ def test_valid_url_with_trailing_slash(self): """Test parsing a valid URL with trailing slash.""" target_model = "http://localhost:9000/v2/models/my_model/" service_url, model_name, version = _parse_model_arg(target_model) - + assert service_url == "http://localhost:9000" assert model_name == "my_model" assert version == "" @@ -41,7 +42,7 @@ def test_valid_url_with_version_and_trailing_slash(self): """Test parsing a valid URL with version and trailing slash.""" target_model = "http://localhost:9000/v2/models/my_model/versions/456/" service_url, model_name, version = _parse_model_arg(target_model) - + assert service_url == "http://localhost:9000" assert model_name == "my_model" assert version == "456" @@ -50,14 +51,13 @@ def test_valid_url_https(self): """Test parsing a valid HTTPS URL.""" target_model = "https://example.com:8080/v2/models/test_model/versions/1" service_url, model_name, version = _parse_model_arg(target_model) - + assert service_url == "https://example.com:8080" assert model_name == "test_model" assert version == "1" - @pytest.mark.parametrize( - "target_model,description", + ("target_model", "description"), [ ("http://localhost:9000/models/my_model", "missing v2/models path"), ("http://localhost:9000/v2/models/my_model/version/123", "wrong versions format"), From 966e04726046c89626e7a5e2803aa869974a1f7a Mon Sep 17 00:00:00 2001 From: Mariusz Gumowski Date: Tue, 26 Aug 2025 10:39:15 +0200 Subject: [PATCH 3/3] refactor --- src/model_api/adapters/ovms_adapter.py | 121 ++++++++++++----------- src/model_api/models/model.py | 5 +- tests/unit/adapters/test_ovms_adapter.py | 16 +-- 3 files changed, 74 insertions(+), 68 deletions(-) diff --git a/src/model_api/adapters/ovms_adapter.py b/src/model_api/adapters/ovms_adapter.py index 28b60658..52b99363 100644 --- a/src/model_api/adapters/ovms_adapter.py +++ b/src/model_api/adapters/ovms_adapter.py @@ -1,12 +1,12 @@ # -# Copyright (C) 2020-2024 Intel Corporation +# Copyright (C) 2020-2025 Intel 
Corporation # SPDX-License-Identifier: Apache-2.0 # from __future__ import annotations import re -from typing import Any, Callable +from typing import Any, Callable, Final import numpy as np @@ -17,6 +17,26 @@ class OVMSAdapter(InferenceAdapter): """Inference adapter that allows working with models served by the OpenVINO Model Server""" + # OVMS model URL regex pattern + # Expected format:
<address>:<port>/v2/models/<model_name>[/versions/<model_version>] + _OVMS_MODEL_URL_PATTERN: Final = re.compile( + r"(.+)\/v2\/models\/([^\/]+)(?:(?:\/versions\/)(\d+))?(?:\/)?", + ) + + # Triton to NumPy precision mapping + _triton2np_precision: Final = { + "INT64": np.int64, + "UINT64": np.uint64, + "FLOAT": np.float32, + "UINT32": np.uint32, + "INT32": np.int32, + "HALF": np.float16, + "INT16": np.int16, + "INT8": np.int8, + "UINT8": np.uint8, + "FP32": np.float32, + } + def __init__(self, target_model: str): """ Initializes OVMS adapter. @@ -26,7 +46,7 @@ def __init__(self, target_model: str): """ import tritonclient.http as httpclient - service_url, self.model_name, self.model_version = _parse_model_arg( + service_url, self.model_name, self.model_version = self.parse_model_arg( target_model, ) self.client = httpclient.InferenceServerClient(service_url) @@ -83,7 +103,7 @@ def infer_sync(self, dict_data: dict) -> dict: Returns: dict: model raw outputs. """ - inputs = _prepare_inputs(dict_data, self.inputs) + inputs = self._prepare_inputs(dict_data) raw_result = self.client.infer( model_name=self.model_name, model_version=self.model_version, @@ -98,7 +118,7 @@ def infer_sync(self, dict_data: dict) -> dict: def infer_async(self, dict_data: dict, callback_data: Any): """A stub method imitating async inference with a blocking call.""" - inputs = _prepare_inputs(dict_data, self.inputs) + inputs = self._prepare_inputs(dict_data) raw_result = self.client.infer( model_name=self.model_name, model_version=self.model_version, @@ -166,60 +186,49 @@ def save_model(self, path: str, weights_path: str | None = None, version: str | msg = "OVMSAdapter does not support saving a model" raise NotImplementedError(msg) + @staticmethod + def is_ovms_model(target_model: str) -> bool: + """Checks if the given string is a valid OVMS model URL.""" + if not isinstance(target_model, str): + return False -_triton2np_precision = { - "INT64": np.int64, - "UINT64": np.uint64, - "FLOAT": np.float32, - "UINT32": np.uint32, - "INT32": 
np.int32, - "HALF": np.float16, - "INT16": np.int16, - "INT8": np.int8, - "UINT8": np.uint8, - "FP32": np.float32, -} - + return OVMSAdapter._OVMS_MODEL_URL_PATTERN.fullmatch(target_model) is not None -def _parse_model_arg(target_model: str): - """Parses OVMS model URL.""" - if not isinstance(target_model, str): - msg = "target_model must be str" - raise TypeError(msg) + @staticmethod + def parse_model_arg(target_model: str): + """Parses OVMS model URL.""" + if not isinstance(target_model, str): + msg = "target_model must be str" + raise TypeError(msg) - # Expected format:
<address>:<port>/v2/models/<model_name>[/versions/<model_version>] - match = re.fullmatch( - r"(.+)\/v2\/models\/([^\/]+)(?:(?:\/versions\/)(\d+))?(?:\/)?", - target_model, - ) - if not match: - msg = "invalid --model option format" - raise ValueError(msg) - - return match.group(1), match.group(2), match.group(3) or "" + match = OVMSAdapter._OVMS_MODEL_URL_PATTERN.fullmatch(target_model) + if not match: + msg = "invalid --model option format" + raise ValueError(msg) + return match.group(1), match.group(2), match.group(3) or "" -def _prepare_inputs(dict_data: dict, inputs_meta: dict[str, Metadata]): - """Converts raw model inputs into OVMS-specific representation.""" - import tritonclient.http as httpclient + def _prepare_inputs(self, dict_data: dict) -> list: + """Converts raw model inputs into OVMS-specific representation.""" + import tritonclient.http as httpclient - inputs = [] - for input_name, input_data in dict_data.items(): - if input_name not in inputs_meta: - msg = "Input data does not match model inputs" - raise ValueError(msg) - input_info = inputs_meta[input_name] - model_precision = _triton2np_precision[input_info.precision] - if isinstance(input_data, np.ndarray) and input_data.dtype != model_precision: - input_data = input_data.astype(model_precision) - elif isinstance(input_data, list): - input_data = np.array(input_data, dtype=model_precision) - - infer_input = httpclient.InferInput( - input_name, - input_data.shape, - input_info.precision, - ) - infer_input.set_data_from_numpy(input_data) - inputs.append(infer_input) - return inputs + inputs = [] + for input_name, input_data in dict_data.items(): + if input_name not in self.inputs: + msg = "Input data does not match model inputs" + raise ValueError(msg) + input_info = self.inputs[input_name] + model_precision = self._triton2np_precision[input_info.precision] + if isinstance(input_data, np.ndarray) and input_data.dtype != model_precision: + input_data = input_data.astype(model_precision) + elif isinstance(input_data, list): + input_data = 
np.array(input_data, dtype=model_precision) + + infer_input = httpclient.InferInput( + input_name, + input_data.shape, + input_info.precision, + ) + infer_input.set_data_from_numpy(input_data) + inputs.append(infer_input) + return inputs diff --git a/src/model_api/models/model.py b/src/model_api/models/model.py index 5dc4b1b1..47168dac 100644 --- a/src/model_api/models/model.py +++ b/src/model_api/models/model.py @@ -6,7 +6,6 @@ from __future__ import annotations import logging as log -import re from contextlib import contextmanager from typing import TYPE_CHECKING, Any, Callable, NoReturn, Type @@ -173,9 +172,7 @@ def create_model( inference_adapter: InferenceAdapter if isinstance(model, InferenceAdapter): inference_adapter = model - elif isinstance(model, str) and re.compile( - r"(.+)\/v2\/models\/([^\/]+)(?:(?:\/versions\/)(\d+))?(?:\/)?", - ).fullmatch(model): + elif isinstance(model, str) and OVMSAdapter.is_ovms_model(model): inference_adapter = OVMSAdapter(model) else: if core is None: diff --git a/tests/unit/adapters/test_ovms_adapter.py b/tests/unit/adapters/test_ovms_adapter.py index e6802666..fe308fc8 100644 --- a/tests/unit/adapters/test_ovms_adapter.py +++ b/tests/unit/adapters/test_ovms_adapter.py @@ -5,16 +5,16 @@ import pytest -from model_api.adapters.ovms_adapter import _parse_model_arg +from model_api.adapters.ovms_adapter import OVMSAdapter class TestParseModelArg: - """Test cases for the _parse_model_arg function.""" + """Test cases for the parse_model_arg function.""" def test_valid_url_with_version(self): """Test parsing a valid URL with version specified.""" target_model = "http://localhost:9000/v2/models/my_model/versions/123" - service_url, model_name, version = _parse_model_arg(target_model) + service_url, model_name, version = OVMSAdapter.parse_model_arg(target_model) assert service_url == "http://localhost:9000" assert model_name == "my_model" @@ -23,7 +23,7 @@ def test_valid_url_with_version(self): def 
test_valid_url_without_version(self): """Test parsing a valid URL without version specified.""" target_model = "http://localhost:9000/v2/models/345$%^!@#$model" - service_url, model_name, version = _parse_model_arg(target_model) + service_url, model_name, version = OVMSAdapter.parse_model_arg(target_model) assert service_url == "http://localhost:9000" assert model_name == "345$%^!@#$model" @@ -32,7 +32,7 @@ def test_valid_url_without_version(self): def test_valid_url_with_trailing_slash(self): """Test parsing a valid URL with trailing slash.""" target_model = "http://localhost:9000/v2/models/my_model/" - service_url, model_name, version = _parse_model_arg(target_model) + service_url, model_name, version = OVMSAdapter.parse_model_arg(target_model) assert service_url == "http://localhost:9000" assert model_name == "my_model" @@ -41,7 +41,7 @@ def test_valid_url_with_trailing_slash(self): def test_valid_url_with_version_and_trailing_slash(self): """Test parsing a valid URL with version and trailing slash.""" target_model = "http://localhost:9000/v2/models/my_model/versions/456/" - service_url, model_name, version = _parse_model_arg(target_model) + service_url, model_name, version = OVMSAdapter.parse_model_arg(target_model) assert service_url == "http://localhost:9000" assert model_name == "my_model" @@ -50,7 +50,7 @@ def test_valid_url_with_version_and_trailing_slash(self): def test_valid_url_https(self): """Test parsing a valid HTTPS URL.""" target_model = "https://example.com:8080/v2/models/test_model/versions/1" - service_url, model_name, version = _parse_model_arg(target_model) + service_url, model_name, version = OVMSAdapter.parse_model_arg(target_model) assert service_url == "https://example.com:8080" assert model_name == "test_model" @@ -73,4 +73,4 @@ def test_valid_url_https(self): def test_invalid_url_formats(self, target_model, description): """Test parsing various invalid URL formats.""" with pytest.raises(ValueError, match="invalid --model option format"): 
- _parse_model_arg(target_model) + OVMSAdapter.parse_model_arg(target_model)