Skip to content

Commit d0348d6

Browse files
authored
Merge pull request #219 from MorpheusAIs/test
fix: case-insensitive model name resolution - MAIN
2 parents 96d5655 + a8fa639 commit d0348d6

2 files changed

Lines changed: 6 additions & 7 deletions

File tree

src/core/direct_model_service.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@ def __init__(self, cache_duration_seconds: int = 300):
3333
cache_duration_seconds: Cache TTL in seconds (default: 300 = 5 minutes)
3434
"""
3535
self.cache_duration = cache_duration_seconds
36-
self._model_mapping: Dict[str, str] = {} # name -> blockchain_id
36+
self._model_mapping: Dict[str, str] = {} # lowercase name -> blockchain_id
3737
self._id_to_name: Dict[str, str] = {} # blockchain_id -> name
38-
self._model_mapping_type: Dict[str, str] = {} # name -> type
38+
self._model_mapping_type: Dict[str, str] = {} # lowercase name -> type
3939
self._blockchain_ids: set = set()
4040
self._cache_expiry: Optional[datetime] = None
4141
self._last_etag: Optional[str] = None
@@ -102,8 +102,7 @@ async def resolve_model_id(self, model_identifier: str) -> Optional[str]:
102102
if model_identifier in self._blockchain_ids:
103103
return model_identifier
104104

105-
# Check if it's a model name
106-
return self._model_mapping.get(model_identifier)
105+
return self._model_mapping.get(model_identifier.lower())
107106

108107
async def get_model_name_from_id(self, blockchain_id: str) -> Optional[str]:
109108
"""
@@ -214,9 +213,9 @@ def _update_cache(self, models: List[Dict], content_hash: str, etag: Optional[st
214213
model_type = model.get("ModelType")
215214

216215
if model_name and blockchain_id:
217-
new_mapping[model_name] = blockchain_id
216+
new_mapping[model_name.lower()] = blockchain_id
218217
new_id_to_name[blockchain_id] = model_name
219-
new_mapping_type[model_name] = model_type
218+
new_mapping_type[model_name.lower()] = model_type
220219
new_blockchain_ids.add(blockchain_id)
221220

222221
self._model_mapping = new_mapping

src/core/model_routing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ async def _get_default_model_id(self, type: Optional[str] = "LLM") -> str:
110110
if default_model in model_mapping:
111111
logger.info("Using configured default model",
112112
default_model=default_model,
113-
blockchain_id=model_mapping[default_model],
113+
blockchain_id=model_mapping[default_model.lower()],
114114
event_type="default_model_resolved")
115115
return model_mapping[default_model]
116116

0 commit comments

Comments
 (0)