diff --git a/src/model_api/adapters/ovms_adapter.py b/src/model_api/adapters/ovms_adapter.py
index df0581a8..52b99363 100644
--- a/src/model_api/adapters/ovms_adapter.py
+++ b/src/model_api/adapters/ovms_adapter.py
@@ -1,12 +1,12 @@
#
-# Copyright (C) 2020-2024 Intel Corporation
+# Copyright (C) 2020-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#
from __future__ import annotations
import re
-from typing import Any, Callable
+from typing import Any, Callable, Final
import numpy as np
@@ -17,6 +17,26 @@
class OVMSAdapter(InferenceAdapter):
"""Inference adapter that allows working with models served by the OpenVINO Model Server"""
+ # OVMS model URL regex pattern
+ # Expected format:
:/v2/models/[/versions/]
+ _OVMS_MODEL_URL_PATTERN: Final = re.compile(
+ r"(.+)\/v2\/models\/([^\/]+)(?:(?:\/versions\/)(\d+))?(?:\/)?",
+ )
+
+ # Triton to NumPy precision mapping
+ _triton2np_precision: Final = {
+ "INT64": np.int64,
+ "UINT64": np.uint64,
+ "FLOAT": np.float32,
+ "UINT32": np.uint32,
+ "INT32": np.int32,
+ "HALF": np.float16,
+ "INT16": np.int16,
+ "INT8": np.int8,
+ "UINT8": np.uint8,
+ "FP32": np.float32,
+ }
+
def __init__(self, target_model: str):
"""
Initializes OVMS adapter.
@@ -26,7 +46,7 @@ def __init__(self, target_model: str):
"""
import tritonclient.http as httpclient
- service_url, self.model_name, self.model_version = _parse_model_arg(
+ service_url, self.model_name, self.model_version = self.parse_model_arg(
target_model,
)
self.client = httpclient.InferenceServerClient(service_url)
@@ -83,7 +103,7 @@ def infer_sync(self, dict_data: dict) -> dict:
Returns:
dict: model raw outputs.
"""
- inputs = _prepare_inputs(dict_data, self.inputs)
+ inputs = self._prepare_inputs(dict_data)
raw_result = self.client.infer(
model_name=self.model_name,
model_version=self.model_version,
@@ -98,7 +118,7 @@ def infer_sync(self, dict_data: dict) -> dict:
def infer_async(self, dict_data: dict, callback_data: Any):
"""A stub method imitating async inference with a blocking call."""
- inputs = _prepare_inputs(dict_data, self.inputs)
+ inputs = self._prepare_inputs(dict_data)
raw_result = self.client.infer(
model_name=self.model_name,
model_version=self.model_version,
@@ -166,65 +186,49 @@ def save_model(self, path: str, weights_path: str | None = None, version: str |
msg = "OVMSAdapter does not support saving a model"
raise NotImplementedError(msg)
+ @staticmethod
+ def is_ovms_model(target_model: str) -> bool:
+ """Checks if the given string is a valid OVMS model URL."""
+ if not isinstance(target_model, str):
+ return False
-_triton2np_precision = {
- "INT64": np.int64,
- "UINT64": np.uint64,
- "FLOAT": np.float32,
- "UINT32": np.uint32,
- "INT32": np.int32,
- "HALF": np.float16,
- "INT16": np.int16,
- "INT8": np.int8,
- "UINT8": np.uint8,
- "FP32": np.float32,
-}
-
-
-def _parse_model_arg(target_model: str):
- """Parses OVMS model URL."""
- if not isinstance(target_model, str):
- msg = "target_model must be str"
- raise TypeError(msg)
- # Expected format: :/models/[:]
- if not re.fullmatch(
- r"(\w+\.*\-*)*\w+:\d+\/v2/models\/[a-zA-Z0-9._-]+(\:\d+)*",
- target_model,
- ):
- msg = "invalid --model option format"
- raise ValueError(msg)
- service_url, _, _, model = target_model.split("/")
- model_spec = model.split(":")
- if len(model_spec) == 1:
- # model version not specified - use latest
- return service_url, model_spec[0], ""
- if len(model_spec) == 2:
- return service_url, model_spec[0], model_spec[1]
- msg = "Invalid target_model format"
- raise ValueError(msg)
-
-
-def _prepare_inputs(dict_data: dict, inputs_meta: dict[str, Metadata]):
- """Converts raw model inputs into OVMS-specific representation."""
- import tritonclient.http as httpclient
-
- inputs = []
- for input_name, input_data in dict_data.items():
- if input_name not in inputs_meta:
- msg = "Input data does not match model inputs"
+ return OVMSAdapter._OVMS_MODEL_URL_PATTERN.fullmatch(target_model) is not None
+
+ @staticmethod
+ def parse_model_arg(target_model: str):
+ """Parses OVMS model URL."""
+ if not isinstance(target_model, str):
+ msg = "target_model must be str"
+ raise TypeError(msg)
+
+ match = OVMSAdapter._OVMS_MODEL_URL_PATTERN.fullmatch(target_model)
+ if not match:
+ msg = "invalid --model option format"
raise ValueError(msg)
- input_info = inputs_meta[input_name]
- model_precision = _triton2np_precision[input_info.precision]
- if isinstance(input_data, np.ndarray) and input_data.dtype != model_precision:
- input_data = input_data.astype(model_precision)
- elif isinstance(input_data, list):
- input_data = np.array(input_data, dtype=model_precision)
-
- infer_input = httpclient.InferInput(
- input_name,
- input_data.shape,
- input_info.precision,
- )
- infer_input.set_data_from_numpy(input_data)
- inputs.append(infer_input)
- return inputs
+
+ return match.group(1), match.group(2), match.group(3) or ""
+
+ def _prepare_inputs(self, dict_data: dict) -> list:
+ """Converts raw model inputs into OVMS-specific representation."""
+ import tritonclient.http as httpclient
+
+ inputs = []
+ for input_name, input_data in dict_data.items():
+ if input_name not in self.inputs:
+ msg = "Input data does not match model inputs"
+ raise ValueError(msg)
+ input_info = self.inputs[input_name]
+ model_precision = self._triton2np_precision[input_info.precision]
+ if isinstance(input_data, np.ndarray) and input_data.dtype != model_precision:
+ input_data = input_data.astype(model_precision)
+ elif isinstance(input_data, list):
+ input_data = np.array(input_data, dtype=model_precision)
+
+ infer_input = httpclient.InferInput(
+ input_name,
+ input_data.shape,
+ input_info.precision,
+ )
+ infer_input.set_data_from_numpy(input_data)
+ inputs.append(infer_input)
+ return inputs
diff --git a/src/model_api/models/model.py b/src/model_api/models/model.py
index ccabdc87..47168dac 100644
--- a/src/model_api/models/model.py
+++ b/src/model_api/models/model.py
@@ -6,7 +6,6 @@
from __future__ import annotations
import logging as log
-import re
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, NoReturn, Type
@@ -173,9 +172,7 @@ def create_model(
inference_adapter: InferenceAdapter
if isinstance(model, InferenceAdapter):
inference_adapter = model
- elif isinstance(model, str) and re.compile(
- r"(\w+\.*\-*)*\w+:\d+\/v2/models\/[a-zA-Z0-9._-]+(\:\d+)*",
- ).fullmatch(model):
+ elif isinstance(model, str) and OVMSAdapter.is_ovms_model(model):
inference_adapter = OVMSAdapter(model)
else:
if core is None:
diff --git a/tests/unit/adapters/test_ovms_adapter.py b/tests/unit/adapters/test_ovms_adapter.py
new file mode 100644
index 00000000..fe308fc8
--- /dev/null
+++ b/tests/unit/adapters/test_ovms_adapter.py
@@ -0,0 +1,76 @@
+#
+# Copyright (C) 2025 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+#
+
+import pytest
+
+from model_api.adapters.ovms_adapter import OVMSAdapter
+
+
+class TestParseModelArg:
+ """Test cases for the parse_model_arg function."""
+
+ def test_valid_url_with_version(self):
+ """Test parsing a valid URL with version specified."""
+ target_model = "http://localhost:9000/v2/models/my_model/versions/123"
+ service_url, model_name, version = OVMSAdapter.parse_model_arg(target_model)
+
+ assert service_url == "http://localhost:9000"
+ assert model_name == "my_model"
+ assert version == "123"
+
+ def test_valid_url_without_version(self):
+ """Test parsing a valid URL without version specified."""
+ target_model = "http://localhost:9000/v2/models/345$%^!@#$model"
+ service_url, model_name, version = OVMSAdapter.parse_model_arg(target_model)
+
+ assert service_url == "http://localhost:9000"
+ assert model_name == "345$%^!@#$model"
+ assert version == ""
+
+ def test_valid_url_with_trailing_slash(self):
+ """Test parsing a valid URL with trailing slash."""
+ target_model = "http://localhost:9000/v2/models/my_model/"
+ service_url, model_name, version = OVMSAdapter.parse_model_arg(target_model)
+
+ assert service_url == "http://localhost:9000"
+ assert model_name == "my_model"
+ assert version == ""
+
+ def test_valid_url_with_version_and_trailing_slash(self):
+ """Test parsing a valid URL with version and trailing slash."""
+ target_model = "http://localhost:9000/v2/models/my_model/versions/456/"
+ service_url, model_name, version = OVMSAdapter.parse_model_arg(target_model)
+
+ assert service_url == "http://localhost:9000"
+ assert model_name == "my_model"
+ assert version == "456"
+
+ def test_valid_url_https(self):
+ """Test parsing a valid HTTPS URL."""
+ target_model = "https://example.com:8080/v2/models/test_model/versions/1"
+ service_url, model_name, version = OVMSAdapter.parse_model_arg(target_model)
+
+ assert service_url == "https://example.com:8080"
+ assert model_name == "test_model"
+ assert version == "1"
+
+ @pytest.mark.parametrize(
+ ("target_model", "description"),
+ [
+ ("http://localhost:9000/models/my_model", "missing v2/models path"),
+ ("http://localhost:9000/v2/models/my_model/version/123", "wrong versions format"),
+ ("http://localhost:9000/v2/models//versions/123", "empty model name"),
+ ("http://localhost:9000/v2/models/", "no model name"),
+ ("http://localhost:9000/v2", "incomplete URL"),
+ ("http://localhost:9000/v2/models/my_model/versions/latest", "non-numeric version"),
+ ("http://localhost:9000/v2/models/my_model/extra/path", "extra path"),
+ ("http://localhost:9000/v2/models/my_model/versions/", "no version specified"),
+ ("", "empty"),
+ ],
+ )
+ def test_invalid_url_formats(self, target_model, description):
+ """Test parsing various invalid URL formats."""
+ with pytest.raises(ValueError, match="invalid --model option format"):
+ OVMSAdapter.parse_model_arg(target_model)
diff --git a/tests/unit/test_utils.py b/tests/unit/adapters/test_utils.py
similarity index 100%
rename from tests/unit/test_utils.py
rename to tests/unit/adapters/test_utils.py
diff --git a/tests/unit/test_types.py b/tests/unit/models/test_types.py
similarity index 100%
rename from tests/unit/test_types.py
rename to tests/unit/models/test_types.py