Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 69 additions & 65 deletions src/model_api/adapters/ovms_adapter.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
#
# Copyright (C) 2020-2024 Intel Corporation
# Copyright (C) 2020-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

from __future__ import annotations

import re
from typing import Any, Callable
from typing import Any, Callable, Final

import numpy as np

Expand All @@ -17,6 +17,26 @@
class OVMSAdapter(InferenceAdapter):
"""Inference adapter that allows working with models served by the OpenVINO Model Server"""

# OVMS model URL regex pattern
# Expected format: <address>:<port>/v2/models/<model_name>[/versions/<model_version>]
_OVMS_MODEL_URL_PATTERN: Final = re.compile(
r"(.+)\/v2\/models\/([^\/]+)(?:(?:\/versions\/)(\d+))?(?:\/)?",
)

# Triton to NumPy precision mapping
_triton2np_precision: Final = {
"INT64": np.int64,
"UINT64": np.uint64,
"FLOAT": np.float32,
"UINT32": np.uint32,
"INT32": np.int32,
"HALF": np.float16,
"INT16": np.int16,
"INT8": np.int8,
"UINT8": np.uint8,
"FP32": np.float32,
}

def __init__(self, target_model: str):
"""
Initializes OVMS adapter.
Expand All @@ -26,7 +46,7 @@ def __init__(self, target_model: str):
"""
import tritonclient.http as httpclient

service_url, self.model_name, self.model_version = _parse_model_arg(
service_url, self.model_name, self.model_version = self.parse_model_arg(
target_model,
)
self.client = httpclient.InferenceServerClient(service_url)
Expand Down Expand Up @@ -83,7 +103,7 @@ def infer_sync(self, dict_data: dict) -> dict:
Returns:
dict: model raw outputs.
"""
inputs = _prepare_inputs(dict_data, self.inputs)
inputs = self._prepare_inputs(dict_data)
raw_result = self.client.infer(
model_name=self.model_name,
model_version=self.model_version,
Expand All @@ -98,7 +118,7 @@ def infer_sync(self, dict_data: dict) -> dict:

def infer_async(self, dict_data: dict, callback_data: Any):
"""A stub method imitating async inference with a blocking call."""
inputs = _prepare_inputs(dict_data, self.inputs)
inputs = self._prepare_inputs(dict_data)
raw_result = self.client.infer(
model_name=self.model_name,
model_version=self.model_version,
Expand Down Expand Up @@ -166,65 +186,49 @@ def save_model(self, path: str, weights_path: str | None = None, version: str |
msg = "OVMSAdapter does not support saving a model"
raise NotImplementedError(msg)

@staticmethod
def is_ovms_model(target_model: str) -> bool:
"""Checks if the given string is a valid OVMS model URL."""
if not isinstance(target_model, str):
return False

_triton2np_precision = {
"INT64": np.int64,
"UINT64": np.uint64,
"FLOAT": np.float32,
"UINT32": np.uint32,
"INT32": np.int32,
"HALF": np.float16,
"INT16": np.int16,
"INT8": np.int8,
"UINT8": np.uint8,
"FP32": np.float32,
}


def _parse_model_arg(target_model: str):
"""Parses OVMS model URL."""
if not isinstance(target_model, str):
msg = "target_model must be str"
raise TypeError(msg)
# Expected format: <address>:<port>/models/<model_name>[:<model_version>]
if not re.fullmatch(
r"(\w+\.*\-*)*\w+:\d+\/v2/models\/[a-zA-Z0-9._-]+(\:\d+)*",
target_model,
):
msg = "invalid --model option format"
raise ValueError(msg)
service_url, _, _, model = target_model.split("/")
model_spec = model.split(":")
if len(model_spec) == 1:
# model version not specified - use latest
return service_url, model_spec[0], ""
if len(model_spec) == 2:
return service_url, model_spec[0], model_spec[1]
msg = "Invalid target_model format"
raise ValueError(msg)


def _prepare_inputs(dict_data: dict, inputs_meta: dict[str, Metadata]):
"""Converts raw model inputs into OVMS-specific representation."""
import tritonclient.http as httpclient

inputs = []
for input_name, input_data in dict_data.items():
if input_name not in inputs_meta:
msg = "Input data does not match model inputs"
return OVMSAdapter._OVMS_MODEL_URL_PATTERN.fullmatch(target_model) is not None

@staticmethod
def parse_model_arg(target_model: str):
"""Parses OVMS model URL."""
if not isinstance(target_model, str):
msg = "target_model must be str"
raise TypeError(msg)

match = OVMSAdapter._OVMS_MODEL_URL_PATTERN.fullmatch(target_model)
if not match:
msg = "invalid --model option format"
raise ValueError(msg)
input_info = inputs_meta[input_name]
model_precision = _triton2np_precision[input_info.precision]
if isinstance(input_data, np.ndarray) and input_data.dtype != model_precision:
input_data = input_data.astype(model_precision)
elif isinstance(input_data, list):
input_data = np.array(input_data, dtype=model_precision)

infer_input = httpclient.InferInput(
input_name,
input_data.shape,
input_info.precision,
)
infer_input.set_data_from_numpy(input_data)
inputs.append(infer_input)
return inputs

return match.group(1), match.group(2), match.group(3) or ""

def _prepare_inputs(self, dict_data: dict) -> list:
"""Converts raw model inputs into OVMS-specific representation."""
import tritonclient.http as httpclient

inputs = []
for input_name, input_data in dict_data.items():
if input_name not in self.inputs:
msg = "Input data does not match model inputs"
raise ValueError(msg)
input_info = self.inputs[input_name]
model_precision = self._triton2np_precision[input_info.precision]
if isinstance(input_data, np.ndarray) and input_data.dtype != model_precision:
input_data = input_data.astype(model_precision)
elif isinstance(input_data, list):
input_data = np.array(input_data, dtype=model_precision)

infer_input = httpclient.InferInput(
input_name,
input_data.shape,
input_info.precision,
)
infer_input.set_data_from_numpy(input_data)
inputs.append(infer_input)
return inputs
5 changes: 1 addition & 4 deletions src/model_api/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from __future__ import annotations

import logging as log
import re
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, NoReturn, Type

Expand Down Expand Up @@ -173,9 +172,7 @@ def create_model(
inference_adapter: InferenceAdapter
if isinstance(model, InferenceAdapter):
inference_adapter = model
elif isinstance(model, str) and re.compile(
r"(\w+\.*\-*)*\w+:\d+\/v2/models\/[a-zA-Z0-9._-]+(\:\d+)*",
).fullmatch(model):
elif isinstance(model, str) and OVMSAdapter.is_ovms_model(model):
inference_adapter = OVMSAdapter(model)
else:
if core is None:
Expand Down
76 changes: 76 additions & 0 deletions tests/unit/adapters/test_ovms_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

import pytest

from model_api.adapters.ovms_adapter import OVMSAdapter


class TestParseModelArg:
"""Test cases for the parse_model_arg function."""

def test_valid_url_with_version(self):
"""Test parsing a valid URL with version specified."""
target_model = "http://localhost:9000/v2/models/my_model/versions/123"
service_url, model_name, version = OVMSAdapter.parse_model_arg(target_model)

assert service_url == "http://localhost:9000"
assert model_name == "my_model"
assert version == "123"

def test_valid_url_without_version(self):
"""Test parsing a valid URL without version specified."""
target_model = "http://localhost:9000/v2/models/345$%^!@#$model"
service_url, model_name, version = OVMSAdapter.parse_model_arg(target_model)

assert service_url == "http://localhost:9000"
assert model_name == "345$%^!@#$model"
assert version == ""

def test_valid_url_with_trailing_slash(self):
"""Test parsing a valid URL with trailing slash."""
target_model = "http://localhost:9000/v2/models/my_model/"
service_url, model_name, version = OVMSAdapter.parse_model_arg(target_model)

assert service_url == "http://localhost:9000"
assert model_name == "my_model"
assert version == ""

def test_valid_url_with_version_and_trailing_slash(self):
"""Test parsing a valid URL with version and trailing slash."""
target_model = "http://localhost:9000/v2/models/my_model/versions/456/"
service_url, model_name, version = OVMSAdapter.parse_model_arg(target_model)

assert service_url == "http://localhost:9000"
assert model_name == "my_model"
assert version == "456"

def test_valid_url_https(self):
"""Test parsing a valid HTTPS URL."""
target_model = "https://example.com:8080/v2/models/test_model/versions/1"
service_url, model_name, version = OVMSAdapter.parse_model_arg(target_model)

assert service_url == "https://example.com:8080"
assert model_name == "test_model"
assert version == "1"

@pytest.mark.parametrize(
("target_model", "description"),
[
("http://localhost:9000/models/my_model", "missing v2/models path"),
("http://localhost:9000/v2/models/my_model/version/123", "wrong versions format"),
("http://localhost:9000/v2/models//versions/123", "empty model name"),
("http://localhost:9000/v2/models/", "no model name"),
("http://localhost:9000/v2", "incomplete URL"),
("http://localhost:9000/v2/models/my_model/versions/latest", "non-numeric version"),
("http://localhost:9000/v2/models/my_model/extra/path", "extra path"),
("http://localhost:9000/v2/models/my_model/versions/", "no version specified"),
("", "empty"),
],
)
def test_invalid_url_formats(self, target_model, description):
"""Test parsing various invalid URL formats."""
with pytest.raises(ValueError, match="invalid --model option format"):
OVMSAdapter.parse_model_arg(target_model)
File renamed without changes.
File renamed without changes.