diff --git a/src/core/direct_model_service.py b/src/core/direct_model_service.py index 4251b467..86c395cf 100644 --- a/src/core/direct_model_service.py +++ b/src/core/direct_model_service.py @@ -33,9 +33,9 @@ def __init__(self, cache_duration_seconds: int = 300): cache_duration_seconds: Cache TTL in seconds (default: 300 = 5 minutes) """ self.cache_duration = cache_duration_seconds - self._model_mapping: Dict[str, str] = {} # name -> blockchain_id + self._model_mapping: Dict[str, str] = {} # lowercase name -> blockchain_id self._id_to_name: Dict[str, str] = {} # blockchain_id -> name - self._model_mapping_type: Dict[str, str] = {} # name -> type + self._model_mapping_type: Dict[str, str] = {} # lowercase name -> type self._blockchain_ids: set = set() self._cache_expiry: Optional[datetime] = None self._last_etag: Optional[str] = None @@ -102,8 +102,7 @@ async def resolve_model_id(self, model_identifier: str) -> Optional[str]: if model_identifier in self._blockchain_ids: return model_identifier - # Check if it's a model name - return self._model_mapping.get(model_identifier) + return self._model_mapping.get(model_identifier.lower()) async def get_model_name_from_id(self, blockchain_id: str) -> Optional[str]: """ @@ -214,9 +213,9 @@ def _update_cache(self, models: List[Dict], content_hash: str, etag: Optional[st model_type = model.get("ModelType") if model_name and blockchain_id: - new_mapping[model_name] = blockchain_id + new_mapping[model_name.lower()] = blockchain_id new_id_to_name[blockchain_id] = model_name - new_mapping_type[model_name] = model_type + new_mapping_type[model_name.lower()] = model_type new_blockchain_ids.add(blockchain_id) self._model_mapping = new_mapping diff --git a/src/core/model_routing.py b/src/core/model_routing.py index 478ef130..977bd258 100644 --- a/src/core/model_routing.py +++ b/src/core/model_routing.py @@ -110,7 +110,7 @@ async def _get_default_model_id(self, type: Optional[str] = "LLM") -> str: if default_model in model_mapping: logger.info("Using configured default model", default_model=default_model, - blockchain_id=model_mapping[default_model], + blockchain_id=model_mapping[default_model.lower()], event_type="default_model_resolved") return model_mapping[default_model]