Skip to content

Commit 569e68a

Browse files
authored
refactor: update model provider to use ModelFactory for instance mana… (#223)
* refactor: update model provider to use ModelFactory for instance management * refactor: integrate ModelFactory for model instance mapping in tests * refactor: remove unused imports and clean up test files
1 parent e2f6d04 commit 569e68a

5 files changed

Lines changed: 81 additions & 10 deletions

File tree

python/dify_plugin/core/plugin_registration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ def _resolve_model_providers(self):
320320
models[model_cls.model_type] = model_cls
321321

322322
model_factory = ModelFactory(provider, models)
323-
provider_instance = cls(provider, models) # type: ignore
323+
provider_instance = cls(provider, model_factory) # type: ignore
324324
self.models_mapping[provider.provider] = (
325325
provider,
326326
provider_instance,

python/dify_plugin/interfaces/model/__init__.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,30 @@
11
from abc import ABC, abstractmethod
2+
from typing import final
23

4+
from dify_plugin.core.model_factory import ModelFactory
35
from dify_plugin.entities.model import AIModelEntity, ModelType
46
from dify_plugin.entities.model.provider import ProviderEntity
57
from dify_plugin.interfaces.model.ai_model import AIModel
68

79

810
class ModelProvider(ABC):
911
provider_schema: ProviderEntity
10-
model_instance_map: dict[ModelType, AIModel]
12+
model_factory: ModelFactory
1113

14+
@final
1215
def __init__(
1316
self,
1417
provider_schemas: ProviderEntity,
15-
model_instance_map: dict[ModelType, AIModel],
18+
model_factory: ModelFactory,
1619
):
1720
"""
1821
Initialize model provider
1922
2023
:param provider_schemas: provider schemas
21-
:param model_instance_map: model instance map
24+
:param model_factory: model factory
2225
"""
2326
self.provider_schema = provider_schemas
24-
self.model_instance_map = model_instance_map
27+
self.model_factory = model_factory
2528

2629
@abstractmethod
2730
def validate_provider_credentials(self, credentials: dict) -> None:
@@ -71,7 +74,4 @@ def get_model_instance(self, model_type: ModelType) -> AIModel:
7174
:param model_type: model type defined in `ModelType`
7275
:return:
7376
"""
74-
if model_type in self.model_instance_map:
75-
return self.model_instance_map[model_type]
76-
77-
raise ValueError(f"Model instance not found for model type: {model_type}")
77+
return self.model_factory.get_instance(model_type)
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from unittest.mock import MagicMock
2+
3+
from dify_plugin.core.model_factory import ModelFactory
4+
from dify_plugin.entities import I18nObject
5+
from dify_plugin.entities.model import ModelType
6+
from dify_plugin.entities.model.provider import ProviderEntity
7+
from dify_plugin.interfaces.model import ModelProvider
8+
9+
10+
def test_construct_model_provider():
11+
"""
12+
Ensure ModelProvider constructor is intact and usable.
13+
This guards against overriding or changing __init__ signature.
14+
"""
15+
16+
class ProviderImpl(ModelProvider):
17+
def validate_provider_credentials(self, credentials: dict) -> None:
18+
pass
19+
20+
provider_schema = ProviderEntity(
21+
provider="test",
22+
label=I18nObject(en_US="test"),
23+
supported_model_types=[ModelType.LLM],
24+
configurate_methods=[],
25+
)
26+
27+
model_factory = MagicMock(spec=ModelFactory)
28+
29+
provider = ProviderImpl(provider_schemas=provider_schema, model_factory=model_factory)
30+
31+
assert provider is not None
32+
assert provider.get_provider_schema() == provider_schema
33+
assert provider.model_factory is model_factory
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from unittest.mock import MagicMock
2+
3+
from dify_plugin.core.model_factory import ModelFactory
4+
from dify_plugin.entities import I18nObject
5+
from dify_plugin.entities.model import ModelType
6+
from dify_plugin.entities.model.provider import ProviderEntity
7+
from dify_plugin.interfaces.model import ModelProvider
8+
9+
10+
def test_model_provider_get_model_instance_delegates_to_factory():
11+
"""
12+
Ensure ModelProvider.get_model_instance forwards to ModelFactory.get_instance.
13+
Constructor usage mirrors test_construct_tool.py style (inline subclass, minimal init).
14+
"""
15+
16+
class MockModelProvider(ModelProvider):
17+
def validate_provider_credentials(self, credentials: dict) -> None:
18+
pass
19+
20+
provider_schema = ProviderEntity(
21+
provider="test",
22+
label=I18nObject(en_US="test"),
23+
supported_model_types=[ModelType.LLM],
24+
configurate_methods=[],
25+
)
26+
27+
model_factory = MagicMock(spec=ModelFactory)
28+
expected_instance = object()
29+
model_factory.get_instance.return_value = expected_instance
30+
31+
provider = MockModelProvider(provider_schemas=provider_schema, model_factory=model_factory)
32+
33+
result = provider.get_model_instance(ModelType.LLM)
34+
assert result is expected_instance
35+
model_factory.get_instance.assert_called_once_with(ModelType.LLM)

python/tests/test_model_registry_get_model.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,10 @@ def mock_resolve_plugin_cls(self: PluginRegistration):
148148
supported_model_types=[ModelType.LLM],
149149
configurate_methods=[],
150150
),
151-
model_instance_map={},
151+
model_factory=ModelFactory(
152+
provider=provider_configuration,
153+
models={ModelType.LLM: MockLLM},
154+
),
152155
),
153156
ModelFactory(
154157
provider=provider_configuration,

0 commit comments

Comments
 (0)