diff --git a/backend/agents/create_agent_info.py b/backend/agents/create_agent_info.py index 4246cac91..04d7150ad 100644 --- a/backend/agents/create_agent_info.py +++ b/backend/agents/create_agent_info.py @@ -1,4 +1,4 @@ -import threading +import threading import logging from typing import List, Optional from urllib.parse import urljoin @@ -450,6 +450,7 @@ async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int rerank = param_dict.get("rerank", False) rerank_model_name = param_dict.get("rerank_model_name", "") rerank_model = None + is_multimodal = bool(tool_config.params.pop("multimodal", False)) if rerank and rerank_model_name: rerank_model = get_rerank_model( tenant_id=tenant_id, model_name=rerank_model_name @@ -457,7 +458,9 @@ async def create_tool_config_list(agent_id, tenant_id, user_id, version_no: int tool_config.metadata = { "vdb_core": get_vector_db_core(), - "embedding_model": get_embedding_model(tenant_id=tenant_id), + "embedding_model": get_embedding_model( + tenant_id=tenant_id, is_multimodal=is_multimodal + ), "rerank_model": rerank_model, } elif tool_config.class_name in ["DifySearchTool", "DataMateSearchTool"]: diff --git a/backend/apps/file_management_app.py b/backend/apps/file_management_app.py index 50224c952..bb5cbb318 100644 --- a/backend/apps/file_management_app.py +++ b/backend/apps/file_management_app.py @@ -116,12 +116,13 @@ async def upload_files( @file_management_config_router.post("/process") async def process_files( - files: List[dict] = Body( - ..., description="List of file details to process, including path_or_url and filename"), - chunking_strategy: Optional[str] = Body("basic"), - index_name: str = Body(...), - destination: str = Body(...), - authorization: Optional[str] = Header(None) + files: Annotated[List[dict], Body( + ..., description="List of file details to process, including path_or_url and filename")], + index_name: Annotated[str, Body(...)], + destination: Annotated[str, Body(...)], + 
chunking_strategy: Annotated[Optional[str], Body(...)] = "basic", + model_id: Annotated[Optional[int], Body(...)] = None, + authorization: Annotated[Optional[str], Header()] = None ): """ Trigger data processing for a list of uploaded files. @@ -134,7 +135,8 @@ async def process_files( chunking_strategy=chunking_strategy, source_type=destination, index_name=index_name, - authorization=authorization + authorization=authorization, + model_id=model_id ) process_result = await trigger_data_process(files, process_params) diff --git a/backend/apps/model_managment_app.py b/backend/apps/model_managment_app.py index 0a5a04139..c04c577f5 100644 --- a/backend/apps/model_managment_app.py +++ b/backend/apps/model_managment_app.py @@ -33,7 +33,7 @@ from fastapi.responses import JSONResponse from fastapi.encoders import jsonable_encoder from http import HTTPStatus -from typing import List, Optional +from typing import Annotated, List, Optional from services.model_health_service import ( check_model_connectivity, verify_model_config_connectivity, @@ -297,7 +297,8 @@ async def get_llm_model_list(authorization: Optional[str] = Header(None)): @router.post("/healthcheck") async def check_model_health( - display_name: str = Query(..., description="Display name to check"), + display_name: Annotated[str, Query(..., description="Display name to check")], + model_type: Annotated[str, Query(..., description="...")], authorization: Optional[str] = Header(None) ): """Check and update model connectivity, returning the latest status. 
@@ -308,7 +309,7 @@ async def check_model_health( """ try: _, tenant_id = get_current_user_id(authorization) - result = await check_model_connectivity(display_name, tenant_id) + result = await check_model_connectivity(display_name, tenant_id, model_type) return JSONResponse(status_code=HTTPStatus.OK, content={ "message": "Successfully checked model connectivity", "data": result diff --git a/backend/apps/vectordatabase_app.py b/backend/apps/vectordatabase_app.py index 872b5387b..7f948e625 100644 --- a/backend/apps/vectordatabase_app.py +++ b/backend/apps/vectordatabase_app.py @@ -65,11 +65,13 @@ def create_new_index( # Extract optional fields from request body ingroup_permission = None group_ids = None - embedding_model_name = None + is_multimodal = False + embedding_model_name: Optional[str] = None if request: ingroup_permission = request.get("ingroup_permission") group_ids = request.get("group_ids") - embedding_model_name = request.get("embedding_model_name") + is_multimodal = request.get("is_multimodal", False) + embedding_model_name = request.get("embeddingModel") # Treat path parameter as user-facing knowledge base name for new creations return ElasticSearchService.create_knowledge_base( @@ -81,6 +83,7 @@ def create_new_index( ingroup_permission=ingroup_permission, group_ids=group_ids, embedding_model_name=embedding_model_name, + is_multimodal=is_multimodal, ) except Exception as e: raise HTTPException( @@ -124,6 +127,7 @@ async def update_index( knowledge_name = request.get("knowledge_name") ingroup_permission = request.get("ingroup_permission") group_ids = request.get("group_ids") + is_multimodal = request.get("is_multimodal") # Call service layer to update knowledge base result = ElasticSearchService.update_knowledge_base( @@ -131,6 +135,7 @@ async def update_index( knowledge_name=knowledge_name, ingroup_permission=ingroup_permission, group_ids=group_ids, + is_multimodal=is_multimodal, tenant_id=tenant_id, user_id=user_id, ) @@ -200,13 +205,23 @@ def 
create_index_documents( user_id, tenant_id = get_current_user_id(authorization) # Get the knowledge base record to retrieve the saved embedding model - knowledge_record = get_knowledge_record({'index_name': index_name}) + knowledge_record = get_knowledge_record( + {"index_name": index_name, "tenant_id": tenant_id} + ) saved_embedding_model_name = None if knowledge_record: saved_embedding_model_name = knowledge_record.get('embedding_model_name') - - # Use the saved model from knowledge base, fallback to tenant default if not set - embedding_model = get_embedding_model(tenant_id, saved_embedding_model_name) + is_multimodal = ( + True if knowledge_record and knowledge_record.get('is_multimodal') == 'Y' else False + ) + + # Use the saved model from knowledge base, fallback to tenant default if not set. + embedding_model = get_embedding_model( + tenant_id=tenant_id, + is_multimodal=is_multimodal, + model_name=saved_embedding_model_name, + strict_model_name=bool(saved_embedding_model_name), + ) return ElasticSearchService.index_documents( embedding_model=embedding_model, @@ -463,6 +478,7 @@ def update_chunk( chunk_request=payload, vdb_core=vdb_core, user_id=user_id, + tenant_id=tenant_id, ) return JSONResponse(status_code=HTTPStatus.OK, content=result) except ValueError as e: @@ -529,8 +545,17 @@ async def hybrid_search( """Run a hybrid (accurate + semantic) search across indices.""" try: _, tenant_id = get_current_user_id(authorization) + resolved_index_names: List[str] = [] + for requested_name in payload.index_names: + try: + resolved_name = get_index_name_by_knowledge_name( + requested_name, tenant_id + ) + except Exception: + resolved_name = requested_name + resolved_index_names.append(resolved_name) result = ElasticSearchService.search_hybrid( - index_names=payload.index_names, + index_names=resolved_index_names, query=payload.query, tenant_id=tenant_id, top_k=payload.top_k, diff --git a/backend/consts/const.py b/backend/consts/const.py index 223a1d00b..ee7aa63b1 
100644 --- a/backend/consts/const.py +++ b/backend/consts/const.py @@ -28,6 +28,10 @@ class VectorDatabaseType(str, Enum): # Data Processing Service Configuration DATA_PROCESS_SERVICE = os.getenv("DATA_PROCESS_SERVICE") CLIP_MODEL_PATH = os.getenv("CLIP_MODEL_PATH") +TABLE_TRANSFORMER_MODEL_PATH = os.getenv("TABLE_TRANSFORMER_MODEL_PATH") +UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH = os.getenv( + "UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH" +) # Upload Configuration @@ -115,6 +119,7 @@ class VectorDatabaseType(str, Enum): MINIO_SECRET_KEY = os.getenv("MINIO_SECRET_KEY") MINIO_REGION = os.getenv("MINIO_REGION") MINIO_DEFAULT_BUCKET = os.getenv("MINIO_DEFAULT_BUCKET") +S3_URL_PREFIX = "s3://" # Postgres Configuration diff --git a/backend/consts/model.py b/backend/consts/model.py index 707802957..c8acfa3d1 100644 --- a/backend/consts/model.py +++ b/backend/consts/model.py @@ -234,6 +234,7 @@ class ProcessParams(BaseModel): source_type: str index_name: str authorization: Optional[str] = None + model_id: Optional[int] = None class OpinionRequest(BaseModel): diff --git a/backend/data_process/ray_actors.py b/backend/data_process/ray_actors.py index 2fa590bec..b9fd982ae 100644 --- a/backend/data_process/ray_actors.py +++ b/backend/data_process/ray_actors.py @@ -1,11 +1,19 @@ +from io import BytesIO import logging import json from typing import Any, Dict, List, Optional import ray -from consts.const import RAY_ACTOR_NUM_CPUS, REDIS_BACKEND_URL, DEFAULT_EXPECTED_CHUNK_SIZE, DEFAULT_MAXIMUM_CHUNK_SIZE -from database.attachment_db import get_file_stream +from consts.const import ( + RAY_ACTOR_NUM_CPUS, + REDIS_BACKEND_URL, + DEFAULT_EXPECTED_CHUNK_SIZE, + DEFAULT_MAXIMUM_CHUNK_SIZE, + TABLE_TRANSFORMER_MODEL_PATH, + UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH, +) +from database.attachment_db import build_s3_url, get_file_stream, upload_fileobj from database.model_management_db import get_model_by_model_id from nexent.data_process import 
DataProcessCore @@ -58,50 +66,137 @@ def process_file( if task_id: params['task_id'] = task_id - # Get chunk size parameters from embedding model if model_id is provided - if model_id and tenant_id: - try: - # Get embedding model details directly by model_id - model_record = get_model_by_model_id( - model_id=model_id, tenant_id=tenant_id) - if model_record: - expected_chunk_size = model_record.get( - 'expected_chunk_size', DEFAULT_EXPECTED_CHUNK_SIZE) - maximum_chunk_size = model_record.get( - 'maximum_chunk_size', DEFAULT_MAXIMUM_CHUNK_SIZE) - model_name = model_record.get('display_name') - - # Pass chunk sizes to processing parameters - params['max_characters'] = maximum_chunk_size - params['new_after_n_chars'] = expected_chunk_size - - logger.info( - f"[RayActor] Using chunk sizes from embedding model '{model_name}' (ID: {model_id}): " - f"max_characters={maximum_chunk_size}, new_after_n_chars={expected_chunk_size}") - else: - logger.warning( - f"[RayActor] Embedding model with ID {model_id} not found for tenant '{tenant_id}', using default chunk sizes") - except Exception as e: + self._apply_model_chunk_sizes( + model_id=model_id, tenant_id=tenant_id, params=params) + self._apply_model_paths(params) + file_data = self._read_file_bytes(source) + + result = self._processor.file_process( + file_data=file_data, + filename=source, + chunking_strategy=chunking_strategy, + **params + ) + chunks, images_info = self._normalize_processor_result(result) + if images_info: + self._append_image_chunks( + source=source, chunks=chunks, images_info=images_info) + + chunks = self._validate_chunks(chunks, source) + if not chunks: + return [] + + logger.info( + f"[RayActor] Processing done: produced {len(chunks)} chunks for source='{source}'") + return chunks + + def _apply_model_paths(self, params: Dict[str, Any]) -> None: + params["table_transformer_model_path"] = TABLE_TRANSFORMER_MODEL_PATH + params[ + "unstructured_default_model_initialize_params_json_path" + ] = 
UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH + + def _apply_model_chunk_sizes( + self, + model_id: Optional[int], + tenant_id: Optional[str], + params: Dict[str, Any], + ) -> None: + if not (model_id and tenant_id): + return + + try: + model_record = get_model_by_model_id( + model_id=model_id, tenant_id=tenant_id) + if not model_record: logger.warning( - f"[RayActor] Failed to retrieve chunk sizes from embedding model ID {model_id}: {e}. Using default chunk sizes") + f"[RayActor] Embedding model with ID {model_id} not found for tenant '{tenant_id}', using default chunk sizes") + return + + expected_chunk_size = model_record.get( + 'expected_chunk_size', DEFAULT_EXPECTED_CHUNK_SIZE) + maximum_chunk_size = model_record.get( + 'maximum_chunk_size', DEFAULT_MAXIMUM_CHUNK_SIZE) + model_name = model_record.get('display_name') + model_type = model_record.get('model_type') + params['max_characters'] = maximum_chunk_size + params['new_after_n_chars'] = expected_chunk_size + if model_type: + params['model_type'] = model_type + + logger.info( + f"[RayActor] Using chunk sizes from embedding model '{model_name}' (ID: {model_id}): " + f"max_characters={maximum_chunk_size}, new_after_n_chars={expected_chunk_size}") + except Exception as e: + logger.warning( + f"[RayActor] Failed to retrieve chunk sizes from embedding model ID {model_id}: {e}. 
Using default chunk sizes") + + def _read_file_bytes(self, source: str) -> bytes: try: file_stream = get_file_stream(source) if file_stream is None: raise FileNotFoundError( f"Unable to fetch file from URL: {source}") - file_data = file_stream.read() + return file_stream.read() except Exception as e: logger.error(f"Failed to fetch file from {source}: {e}") raise - chunks = self._processor.file_process( - file_data=file_data, - filename=source, - chunking_strategy=chunking_strategy, - **params - ) + def _normalize_processor_result( + self, result: Any + ) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: + if isinstance(result, tuple) and len(result) == 2: + chunks, images_info = result + return chunks or [], images_info or [] + return result or [], [] + + def _append_image_chunks( + self, + source: str, + chunks: List[Dict[str, Any]], + images_info: List[Dict[str, Any]], + ) -> None: + folder = "images_in_attachments" + for index, image_data in enumerate(images_info): + if not isinstance(image_data, dict): + logger.warning( + f"[RayActor] Skipping image entry at index {index}: unexpected type {type(image_data)}" + ) + continue + if "image_bytes" not in image_data: + logger.warning( + f"[RayActor] Skipping image entry at index {index}: missing image_bytes" + ) + continue + + img_obj = BytesIO(image_data["image_bytes"]) + result = upload_fileobj( + file_obj=img_obj, + file_name=f"{index}.{image_data['image_format']}", + prefix=folder) + image_url = build_s3_url(result.get("object_name", "")) + + image_data["source_file"] = source + image_data["image_url"] = image_url + chunks.append({ + "content": json.dumps({ + "source_file": source, + "position": image_data["position"], + "image_url": image_url, + }), + "filename": source, + "metadata": { + "chunk_index": len(chunks) + index, + "process_source": "UniversalImageExtractor", + "image_url": image_url, + } + }) + + def _validate_chunks( + self, chunks: Any, source: str + ) -> List[Dict[str, Any]]: if chunks is None: 
logger.warning( f"[RayActor] file_process returned None for source='{source}'") @@ -114,9 +209,6 @@ def process_file( logger.warning( f"[RayActor] file_process returned empty list for source='{source}'") return [] - - logger.info( - f"[RayActor] Processing done: produced {len(chunks)} chunks for source='{source}'") return chunks def store_chunks_in_redis(self, redis_key: str, chunks: List[Dict[str, Any]]) -> bool: diff --git a/backend/database/attachment_db.py b/backend/database/attachment_db.py index 1faabac23..c8ed3a37d 100644 --- a/backend/database/attachment_db.py +++ b/backend/database/attachment_db.py @@ -2,9 +2,62 @@ import os import uuid from datetime import datetime -from typing import Any, BinaryIO, Dict, List, Optional +from typing import Any, BinaryIO, Dict, List, Optional, Tuple from .client import minio_client +from consts.const import S3_URL_PREFIX + + +def _normalize_object_and_bucket(object_name: str, bucket: Optional[str] = None) -> Tuple[str, Optional[str]]: + """ + Normalize object_name + bucket from supported URL styles. + + Supports: + - s3://bucket/key + - /bucket/key + - key (uses provided bucket or default bucket) + """ + if not object_name: + return object_name, bucket + + if object_name.startswith(S3_URL_PREFIX): + s3_path = object_name[len(S3_URL_PREFIX) :] + parts = s3_path.split("/", 1) + parsed_bucket = parts[0] if parts[0] else None + parsed_key = parts[1] if len(parts) > 1 else "" + return parsed_key, parsed_bucket or bucket + + if object_name.startswith("/"): + path = object_name.lstrip("/") + parts = path.split("/", 1) + parsed_bucket = parts[0] if parts[0] else None + parsed_key = parts[1] if len(parts) > 1 else "" + return parsed_key, parsed_bucket or bucket + + return object_name, bucket + + +def build_s3_url(object_name: str, bucket: Optional[str] = None) -> str: + """ + Build an s3://bucket/key style URL from an object name (or passthrough if already s3://). 
+ """ + if not object_name: + return "" + + if object_name.startswith(S3_URL_PREFIX): + return object_name + + if object_name.startswith("/"): + path = object_name.lstrip("/") + parts = path.split("/", 1) + if len(parts) == 2: + return f"{S3_URL_PREFIX}{parts[0]}/{parts[1]}" + return f"{S3_URL_PREFIX}{parts[0]}/" + + resolved_bucket = bucket or minio_client.default_bucket + if resolved_bucket: + return f"{S3_URL_PREFIX}{resolved_bucket}/{object_name}" + return f"{S3_URL_PREFIX}{object_name}" def generate_object_name(file_name: str, prefix: str = "attachments") -> str: @@ -165,7 +218,8 @@ def get_file_size_from_minio(object_name: str, bucket: Optional[str] = None) -> """ Get file size by object name """ - bucket = bucket or minio_client.storage_config.default_bucket + object_name, bucket = _normalize_object_and_bucket(object_name, bucket) + bucket = bucket or minio_client.default_bucket return minio_client.get_file_size(object_name, bucket) @@ -181,6 +235,7 @@ def file_exists(object_name: str, bucket: Optional[str] = None) -> bool: bool: True if file exists, False otherwise """ try: + object_name, bucket = _normalize_object_and_bucket(object_name, bucket) return minio_client.file_exists(object_name, bucket) except Exception: return False @@ -198,6 +253,8 @@ def copy_file(source_object: str, dest_object: str, bucket: Optional[str] = None Returns: Dict[str, Any]: Result containing success flag and error message (if any) """ + source_object, bucket = _normalize_object_and_bucket(source_object, bucket) + dest_object, bucket = _normalize_object_and_bucket(dest_object, bucket) success, result = minio_client.copy_file(source_object, dest_object, bucket) if success: return {"success": True, "object_name": result} @@ -242,8 +299,9 @@ def delete_file(object_name: str, bucket: Optional[str] = None) -> Dict[str, Any Returns: Dict[str, Any]: Delete result, containing success flag and error message (if any) """ + object_name, bucket = _normalize_object_and_bucket(object_name, 
bucket) if not bucket: - bucket = minio_client.storage_config.default_bucket + bucket = minio_client.default_bucket success, result = minio_client.delete_file(object_name, bucket) response = {"success": success, "object_name": object_name} @@ -265,6 +323,7 @@ def get_file_stream(object_name: str, bucket: Optional[str] = None) -> Optional[ Returns: Optional[BinaryIO]: Standard BinaryIO stream object, or None if failed """ + object_name, bucket = _normalize_object_and_bucket(object_name, bucket) success, result = minio_client.get_file_stream(object_name, bucket) if not success: return None diff --git a/backend/database/client.py b/backend/database/client.py index 9b0b97a52..8885ea694 100644 --- a/backend/database/client.py +++ b/backend/database/client.py @@ -89,6 +89,9 @@ def __init__(self): if MinioClient._initialized: return MinioClient._initialized = True + # Explicitly initialize attributes so external callers never hit missing-attribute errors. + self._storage_client = None + self.storage_config = None def _ensure_initialized(self): """Lazily initialize the storage client on first use.""" @@ -108,6 +111,23 @@ def _ensure_initialized(self): return True return False + @property + def default_bucket(self) -> Optional[str]: + """ + Resolve default bucket safely for callers that need bucket info. + Falls back to configured constant when lazy init has not run yet. + """ + try: + self._ensure_initialized() + except Exception: + # Keep this accessor resilient; operational methods can still raise + # detailed storage errors when invoked. 
+ pass + + if getattr(self, "storage_config", None) is not None: + return self.storage_config.default_bucket + return MINIO_DEFAULT_BUCKET + def upload_file( self, file_path: str, diff --git a/backend/database/db_models.py b/backend/database/db_models.py index bc95a5e68..688743343 100644 --- a/backend/database/db_models.py +++ b/backend/database/db_models.py @@ -280,6 +280,7 @@ class KnowledgeRecord(TableBase): group_ids = Column(String, doc="Knowledge base group IDs list") ingroup_permission = Column( String(30), doc="In-group permission: EDIT, READ_ONLY, PRIVATE") + is_multimodal = Column(String(1), default="N", doc="Whether it is multimodal. Optional values: Y/N") class TenantConfig(TableBase): diff --git a/backend/database/knowledge_db.py b/backend/database/knowledge_db.py index df42e1888..d69392418 100644 --- a/backend/database/knowledge_db.py +++ b/backend/database/knowledge_db.py @@ -52,6 +52,7 @@ def create_knowledge_record(query: Dict[str, Any]) -> Dict[str, Any]: "knowledge_name": knowledge_name, "group_ids": convert_list_to_string(group_ids) if isinstance(group_ids, list) else group_ids, "ingroup_permission": query.get("ingroup_permission"), + "is_multimodal": 'Y' if query.get("is_multimodal") else 'N' } # For backward compatibility: if caller explicitly provides index_name, @@ -178,6 +179,9 @@ def update_knowledge_record(query: Dict[str, Any]) -> bool: if query.get("group_ids") is not None: record.group_ids = query["group_ids"] + if query.get("is_multimodal"): + record.is_multimodal = 'Y' if query["is_multimodal"] else 'N' + # Update timestamp and user if query.get("user_id"): record.updated_by = query["user_id"] @@ -254,6 +258,11 @@ def get_knowledge_record(query: Optional[Dict[str, Any]] = None) -> Dict[str, An db_query = db_query.filter( KnowledgeRecord.tenant_id == query['tenant_id']) + if 'is_multimodal' in query: + db_query = db_query.filter( + KnowledgeRecord.is_multimodal == query['is_multimodal'] + ) + result = db_query.first() if result: @@ 
-361,14 +370,25 @@ def get_index_name_by_knowledge_name(knowledge_name: str, tenant_id: str) -> str """ try: with get_db_session() as session: + # First try resolving by user-facing knowledge_name. result = session.query(KnowledgeRecord).filter( KnowledgeRecord.knowledge_name == knowledge_name, KnowledgeRecord.tenant_id == tenant_id, KnowledgeRecord.delete_flag != 'Y' ).first() - if result: return result.index_name + + # Backward/forward compatibility: if caller already passes internal index_name, + # accept it directly by resolving on index_name as well. + index_result = session.query(KnowledgeRecord).filter( + KnowledgeRecord.index_name == knowledge_name, + KnowledgeRecord.tenant_id == tenant_id, + KnowledgeRecord.delete_flag != 'Y' + ).first() + if index_result: + return index_result.index_name + raise ValueError( f"Knowledge base '{knowledge_name}' not found for the current tenant" ) diff --git a/backend/database/model_management_db.py b/backend/database/model_management_db.py index cb1c6c69f..61753f52f 100644 --- a/backend/database/model_management_db.py +++ b/backend/database/model_management_db.py @@ -170,7 +170,7 @@ def get_model_records(filters: Optional[Dict[str, Any]], tenant_id: str) -> List return result_list -def get_model_by_display_name(display_name: str, tenant_id: str) -> Optional[Dict[str, Any]]: +def get_model_by_display_name(display_name: str, tenant_id: str, model_type: str = None) -> Optional[Dict[str, Any]]: """ Get a model record by display name @@ -179,6 +179,11 @@ def get_model_by_display_name(display_name: str, tenant_id: str) -> Optional[Dic tenant_id: """ filters = {'display_name': display_name} + + if model_type in ["multiEmbedding", "multi_embedding"]: + filters['model_type'] = "multi_embedding" + elif model_type == "embedding": + filters['model_type'] = "embedding" records = get_model_records(filters, tenant_id) if not records: @@ -203,7 +208,7 @@ def get_models_by_display_name(display_name: str, tenant_id: str) -> List[Dict[s 
return get_model_records(filters, tenant_id) -def get_model_id_by_display_name(display_name: str, tenant_id: str) -> Optional[int]: +def get_model_id_by_display_name(display_name: str, tenant_id: str, model_type: str = None) -> Optional[int]: """ Get a model ID by display name @@ -214,7 +219,7 @@ def get_model_id_by_display_name(display_name: str, tenant_id: str) -> Optional[ Returns: Optional[int]: Model ID """ - model = get_model_by_display_name(display_name, tenant_id) + model = get_model_by_display_name(display_name, tenant_id, model_type) return model["model_id"] if model else None diff --git a/backend/services/config_sync_service.py b/backend/services/config_sync_service.py index 9fe50813a..c484ca23f 100644 --- a/backend/services/config_sync_service.py +++ b/backend/services/config_sync_service.py @@ -99,7 +99,7 @@ async def save_config_impl(config, tenant_id, user_id): config_key = get_env_key(model_type) + "_ID" model_id = get_model_id_by_display_name( - model_display_name, tenant_id) + model_display_name, tenant_id, model_type=model_type) handle_model_config(tenant_id, user_id, config_key, model_id, tenant_config_dict) diff --git a/backend/services/data_process_service.py b/backend/services/data_process_service.py index 2b222a584..17e64a697 100644 --- a/backend/services/data_process_service.py +++ b/backend/services/data_process_service.py @@ -255,6 +255,17 @@ async def load_image(self, image_url: str) -> Optional[Image.Image]: async def _load_image(self, session: aiohttp.ClientSession, path: str) -> Optional[Image.Image]: """Internal method to load an image from various sources""" try: + if path.startswith('s3://'): + # Fetch from MinIO using s3://bucket/key + file_stream = get_file_stream(object_name=path) + if file_stream is None: + raise FileNotFoundError( + f"Unable to fetch file from URL: {path}") + file_data = file_stream.read() + image_based64_str = base64.b64encode( + file_data).decode('utf-8') + path = 
f"data:image/jpeg;base64,{image_based64_str}" + # Check if input is base64 encoded if path.startswith('data:image'): # Extract the base64 data after the comma @@ -463,6 +474,8 @@ async def create_batch_tasks_impl(self, authorization: Optional[str], request: B chunking_strategy = source_config.get('chunking_strategy') index_name = source_config.get('index_name') original_filename = source_config.get('original_filename') + embedding_model_id = source_config.get('embedding_model_id') + tenant_id = source_config.get('tenant_id') # Validate required fields if not source: @@ -481,7 +494,9 @@ async def create_batch_tasks_impl(self, authorization: Optional[str], request: B source_type=source_type, chunking_strategy=chunking_strategy, index_name=index_name, - original_filename=original_filename + original_filename=original_filename, + embedding_model_id=embedding_model_id, + tenant_id=tenant_id ).set(queue='process_q'), forward.s( index_name=index_name, @@ -559,7 +574,7 @@ async def process_uploaded_text_file(self, file_content: bytes, filename: str, c } async def convert_office_to_pdf_impl(self, object_name: str, pdf_object_name: str) -> None: - """Full conversion pipeline: download → convert → upload → validate → cleanup. + """Full conversion pipeline: download -> convert -> upload -> validate -> cleanup. All five steps run inside data-process so that LibreOffice only needs to be installed in this container. 
diff --git a/backend/services/datamate_service.py b/backend/services/datamate_service.py index 776e0eb1d..26e777eba 100644 --- a/backend/services/datamate_service.py +++ b/backend/services/datamate_service.py @@ -51,7 +51,8 @@ async def _create_datamate_knowledge_records(knowledge_base_ids: List[str], "tenant_id": tenant_id, "user_id": user_id, # Use datamate as embedding model name - "embedding_model_name": embedding_model_names[i] + "embedding_model_name": embedding_model_names[i], + "is_multimodal": False, } # Run synchronous database operation in executor to avoid blocking diff --git a/backend/services/model_health_service.py b/backend/services/model_health_service.py index 9214a1ffa..5b8e27f07 100644 --- a/backend/services/model_health_service.py +++ b/backend/services/model_health_service.py @@ -128,10 +128,10 @@ async def _perform_connectivity_check( return connectivity -async def check_model_connectivity(display_name: str, tenant_id: str) -> dict: +async def check_model_connectivity(display_name: str, tenant_id: str, model_type: str = None) -> dict: try: # Query the database using display_name and tenant context from app layer - model = get_model_by_display_name(display_name, tenant_id=tenant_id) + model = get_model_by_display_name(display_name, tenant_id=tenant_id, model_type=model_type) if not model: raise LookupError(f"Model configuration not found for {display_name}") diff --git a/backend/services/tool_configuration_service.py b/backend/services/tool_configuration_service.py index d7240db26..53a131013 100644 --- a/backend/services/tool_configuration_service.py +++ b/backend/services/tool_configuration_service.py @@ -152,6 +152,10 @@ def get_local_tools() -> List[ToolInfo]: else: param_info["default"] = param.default.default param_info["optional"] = True + if getattr(param.default, "json_schema_extra", None): + optional_override = param.default.json_schema_extra.get("optional") + if optional_override is not None: + param_info["optional"] = 
optional_override init_params_list.append(param_info) @@ -682,6 +686,8 @@ def _validate_local_tool( if not tool_class: raise NotFoundException(f"Tool class not found for {tool_name}") + runtime_inputs = dict(inputs or {}) + # Parse instantiation parameters first instantiation_params = params or {} # Get signature and extract default values for all parameters @@ -704,7 +710,14 @@ def _validate_local_tool( instantiation_params[param_name] = param.default if tool_name == "knowledge_base_search": - embedding_model = get_embedding_model(tenant_id=tenant_id) + # Compatibility: historically index_names might be sent in runtime inputs. + # knowledge_base_search now treats index_names as init params (tool config), + # not forward() inputs. + if "index_names" in runtime_inputs and "index_names" not in instantiation_params: + instantiation_params["index_names"] = runtime_inputs.pop("index_names") + + is_multimodal = instantiation_params.pop("multimodal", False) + embedding_model = get_embedding_model(tenant_id=tenant_id, is_multimodal=is_multimodal) vdb_core = get_vector_db_core() # Get rerank configuration @@ -760,7 +773,18 @@ def _validate_local_tool( else: tool_instance = tool_class(**instantiation_params) - result = tool_instance.forward(**(inputs or {})) + # Only pass declared runtime inputs to forward() to avoid unexpected kwargs. 
+ declared_inputs = getattr(tool_class, "inputs", {}) or {} + allowed_input_names = ( + set(declared_inputs.keys()) if isinstance(declared_inputs, dict) else set() + ) + filtered_runtime_inputs = ( + {k: v for k, v in runtime_inputs.items() if k in allowed_input_names} + if allowed_input_names + else runtime_inputs + ) + + result = tool_instance.forward(**filtered_runtime_inputs) return result except Exception as e: logger.error(f"Local tool validation failed for {tool_name}: {e}") diff --git a/backend/services/vectordatabase_service.py b/backend/services/vectordatabase_service.py index 5639103de..b75d8183c 100644 --- a/backend/services/vectordatabase_service.py +++ b/backend/services/vectordatabase_service.py @@ -28,7 +28,7 @@ from consts.const import DATAMATE_URL, ES_API_KEY, ES_HOST, LANGUAGE, VectorDatabaseType, IS_SPEED_MODE, PERMISSION_EDIT, PERMISSION_READ from consts.model import ChunkCreateRequest, ChunkUpdateRequest -from database.attachment_db import delete_file +from database.attachment_db import delete_file, get_file_stream from database.knowledge_db import ( create_knowledge_record, delete_knowledge_record, @@ -176,7 +176,80 @@ def check_knowledge_base_exist_impl(knowledge_name: str, vdb_core: VectorDatabas return {"status": "available"} -def get_embedding_model(tenant_id: str, model_name: Optional[str] = None): +def _build_embedding_from_config(model_config: Dict[str, Any]) -> Optional[BaseEmbedding]: + model_type = model_config.get("model_type", "") + if model_type == "embedding": + return OpenAICompatibleEmbedding( + api_key=model_config.get("api_key", ""), + base_url=model_config.get("base_url", ""), + model_name=get_model_name_from_config(model_config) or "", + embedding_dim=model_config.get("max_tokens", 1024), + ssl_verify=model_config.get("ssl_verify", True), + ) + if model_type == "multi_embedding": + return JinaEmbedding( + api_key=model_config.get("api_key", ""), + base_url=model_config.get("base_url", ""), + 
model_name=get_model_name_from_config(model_config) or "", + embedding_dim=model_config.get("max_tokens", 1024), + ssl_verify=model_config.get("ssl_verify", True), + ) + return None + + +def _find_model_record( + tenant_id: str, + is_multimodal: bool, + model_name: str, +) -> Optional[Dict[str, Any]]: + model_type = "multi_embedding" if is_multimodal else "embedding" + models = get_model_records({"model_type": model_type}, tenant_id) + for model in models: + model_display_name = ( + f"{model.get('model_repo')}/{model['model_name']}" + if model.get("model_repo") + else model["model_name"] + ) + if model_display_name == model_name: + return model + return None + + +def _build_embedding_from_record( + model_record: Dict[str, Any], + is_multimodal: bool, +) -> BaseEmbedding: + model_config = { + "model_repo": model_record.get("model_repo", ""), + "model_name": model_record["model_name"], + "api_key": model_record.get("api_key", ""), + "base_url": model_record.get("base_url", ""), + "model_type": "embedding", + "max_tokens": model_record.get("max_tokens", 1024), + "ssl_verify": model_record.get("ssl_verify", True), + } + if not is_multimodal: + return OpenAICompatibleEmbedding( + api_key=model_config.get("api_key", ""), + base_url=model_config.get("base_url", ""), + model_name=get_model_name_from_config(model_config) or "", + embedding_dim=model_config.get("max_tokens", 1024), + ssl_verify=model_config.get("ssl_verify", True), + ) + return JinaEmbedding( + api_key=model_config.get("api_key", ""), + base_url=model_config.get("base_url", ""), + model_name=get_model_name_from_config(model_config) or "", + embedding_dim=model_config.get("max_tokens", 1024), + ssl_verify=model_config.get("ssl_verify", True), + ) + +def get_embedding_model( + tenant_id: str, + is_multimodal: bool = False, + model_name: Optional[str] = None, + strict_model_name: bool = False, +): """ Get the embedding model for the tenant, optionally using a specific model name. 
@@ -188,58 +261,50 @@ def get_embedding_model(tenant_id: str, model_name: Optional[str] = None): Returns: Embedding model instance or None """ + if model_name is None and (isinstance(is_multimodal, str) or is_multimodal is None): + model_name = is_multimodal + is_multimodal = False # If model_name is provided, try to find it in the tenant's models if model_name: try: - models = get_model_records({"model_type": "embedding"}, tenant_id) - for model in models: - model_display_name = model.get("model_repo") + "/" + model["model_name"] if model.get("model_repo") else model["model_name"] - if model_display_name == model_name: - # Found the model, create embedding instance - model_config = { - "model_repo": model.get("model_repo", ""), - "model_name": model["model_name"], - "api_key": model.get("api_key", ""), - "base_url": model.get("base_url", ""), - "model_type": "embedding", - "max_tokens": model.get("max_tokens", 1024), - "ssl_verify": model.get("ssl_verify", True), - } - return OpenAICompatibleEmbedding( - api_key=model_config.get("api_key", ""), - base_url=model_config.get("base_url", ""), - model_name=get_model_name_from_config(model_config) or "", - embedding_dim=model_config.get("max_tokens", 1024), - ssl_verify=model_config.get("ssl_verify", True), - ) + model_record = _find_model_record( + tenant_id=tenant_id, + is_multimodal=is_multimodal, + model_name=model_name, + ) + if model_record: + return _build_embedding_from_record( + model_record=model_record, + is_multimodal=is_multimodal, + ) except Exception as e: logger.warning(f"Failed to get embedding model by name {model_name}: {e}") + if strict_model_name: + raise ValueError( + f"Embedding model '{model_name}' is not configured for current tenant" + ) # Fall back to default embedding model (current behavior) model_config = tenant_config_manager.get_model_config( - key="EMBEDDING_ID", tenant_id=tenant_id) - - model_type = model_config.get("model_type", "") - - if model_type == "embedding": - # Get the es core 
- return OpenAICompatibleEmbedding( - api_key=model_config.get("api_key", ""), - base_url=model_config.get("base_url", ""), - model_name=get_model_name_from_config(model_config) or "", - embedding_dim=model_config.get("max_tokens", 1024), - ssl_verify=model_config.get("ssl_verify", True), - ) - elif model_type == "multi_embedding": - return JinaEmbedding( - api_key=model_config.get("api_key", ""), - base_url=model_config.get("base_url", ""), - model_name=get_model_name_from_config(model_config) or "", - embedding_dim=model_config.get("max_tokens", 1024), - ssl_verify=model_config.get("ssl_verify", True), + key="MULTI_EMBEDDING_ID" if is_multimodal else "EMBEDDING_ID", + tenant_id=tenant_id, + ) + return _build_embedding_from_config(model_config) + + +def _resolve_embedding_model( + tenant_id: str, + is_multimodal: bool, + embedding_model_name: Optional[str], +) -> Optional[BaseEmbedding]: + if embedding_model_name: + return get_embedding_model( + tenant_id, + is_multimodal=is_multimodal, + model_name=embedding_model_name, + strict_model_name=True, ) - else: - return None + return get_embedding_model(tenant_id, is_multimodal=is_multimodal) def get_rerank_model(tenant_id: str, model_name: Optional[str] = None): @@ -406,6 +471,7 @@ async def full_delete_knowledge_base(index_name: str, vdb_core: VectorDatabaseCo @staticmethod def create_index( + embedding_model: BaseEmbedding, index_name: str = Path(..., description="Name of the index to create"), embedding_dim: Optional[int] = Query( @@ -419,15 +485,24 @@ def create_index( try: if vdb_core.check_index_exists(index_name): raise Exception(f"Index {index_name} already exists") - embedding_model = get_embedding_model(tenant_id) + if not embedding_model: + embedding_model = get_embedding_model(tenant_id) success = vdb_core.create_index(index_name, embedding_dim=embedding_dim or ( embedding_model.embedding_dim if embedding_model else 1024)) if not success: raise Exception(f"Failed to create index {index_name}") - 
knowledge_data = {"index_name": index_name, - "created_by": user_id, - "tenant_id": tenant_id, - "embedding_model_name": embedding_model.model} + is_multimodal = ( + True + if embedding_model and getattr(embedding_model, "model_type", None) == "multimodal" + else False + ) + knowledge_data = { + "index_name": index_name, + "created_by": user_id, + "tenant_id": tenant_id, + "embedding_model_name": embedding_model.model, + "is_multimodal": is_multimodal, + } create_knowledge_record(knowledge_data) return {"status": "success", "message": f"Index {index_name} created successfully"} except Exception as e: @@ -443,6 +518,7 @@ def create_knowledge_base( ingroup_permission: Optional[str] = None, group_ids: Optional[List[int]] = None, embedding_model_name: Optional[str] = None, + is_multimodal: bool = False, ): """ Create a new knowledge base with a user-facing name and an internal Elasticsearch index name. @@ -468,7 +544,18 @@ def create_knowledge_base( """ try: # Get embedding model - use user-selected model if provided, otherwise use tenant default - embedding_model = get_embedding_model(tenant_id, embedding_model_name) + embedding_model = get_embedding_model( + tenant_id=tenant_id, + is_multimodal=is_multimodal, + model_name=embedding_model_name, + ) + + # If caller did not provide an explicit flag, infer multimodal from model metadata. 
+ resolved_is_multimodal = is_multimodal or ( + True + if embedding_model and getattr(embedding_model, "model_type", None) == "multimodal" + else False + ) # Determine the embedding model name to save: use user-provided name if available, # otherwise use the model's display name @@ -483,6 +570,7 @@ def create_knowledge_base( "user_id": user_id, "tenant_id": tenant_id, "embedding_model_name": saved_embedding_model_name, + "is_multimodal": resolved_is_multimodal, } # Add group permission and group IDs if provided @@ -519,6 +607,7 @@ def update_knowledge_base( knowledge_name: Optional[str] = None, ingroup_permission: Optional[str] = None, group_ids: Optional[List[int]] = None, + is_multimodal: bool = False, tenant_id: Optional[str] = None, user_id: Optional[str] = None, ) -> bool: @@ -549,6 +638,7 @@ def update_knowledge_base( update_data = { "index_name": index_name, "updated_by": user_id, + "is_multimodal": is_multimodal, } if knowledge_name is not None: @@ -784,6 +874,7 @@ def list_indices( # knowledge source and ingroup permission from DB record "knowledge_sources": record["knowledge_sources"], "ingroup_permission": record["ingroup_permission"], + "is_multimodal": record.get("is_multimodal"), "tenant_id": record.get("tenant_id"), # Update time for sorting and display "update_time": record.get("update_time"), @@ -882,12 +973,27 @@ def index_documents( "author": author, "date": date, "content": text, - "process_source": "Unstructured", + "process_source": metadata.get("process_source", "Unstructured"), "file_size": file_size, "create_time": create_time, "languages": metadata.get("languages", []), "embedding_model_name": embedding_model_name } + + image_url = metadata.get("image_url", "") + if len(image_url) > 0: + # Fetch image bytes from MinIO (supports s3://bucket/key or /bucket/key) + try: + file_stream = get_file_stream( + object_name=image_url) + if file_stream is None: + raise FileNotFoundError( + f"Unable to fetch file from URL: {image_url}") + 
document["image_bytes"] = file_stream.read() + except Exception as e: + logger.error( + f"Failed to fetch file from {image_url}: {e}") + raise documents.append(document) @@ -908,8 +1014,9 @@ def index_documents( 'tenant_id') if knowledge_record else None if tenant_id: + model_type = "EMBEDDING_ID" if embedding_model.model_type == "text" else "MULTI_EMBEDDING_ID" model_config = tenant_config_manager.get_model_config( - key="EMBEDDING_ID", tenant_id=tenant_id) + key=model_type, tenant_id=tenant_id) embedding_batch_size = model_config.get("chunk_batch", 10) if embedding_batch_size is None: embedding_batch_size = 10 @@ -1552,6 +1659,7 @@ def create_chunk( try: # Get knowledge base's embedding model name embedding_model_name = None + is_multimodal = False if tenant_id: try: knowledge_record = get_knowledge_record({ @@ -1559,6 +1667,11 @@ def create_chunk( "tenant_id": tenant_id }) embedding_model_name = knowledge_record.get("embedding_model_name") if knowledge_record else None + is_multimodal = ( + True + if knowledge_record and knowledge_record.get("is_multimodal") == "Y" + else False + ) except Exception as e: logger.warning(f"Failed to get embedding model name for index {index_name}: {e}") @@ -1566,7 +1679,16 @@ def create_chunk( embedding_vector = None if chunk_request.content: try: - embedding_model = get_embedding_model(tenant_id, embedding_model_name) if tenant_id else None + embedding_model = ( + get_embedding_model( + tenant_id=tenant_id, + is_multimodal=is_multimodal, + model_name=embedding_model_name, + strict_model_name=bool(embedding_model_name), + ) + if tenant_id + else None + ) if embedding_model: embeddings = embedding_model.get_embeddings(chunk_request.content) if embeddings and len(embeddings) > 0: @@ -1577,6 +1699,8 @@ def create_chunk( else: logger.warning(f"No embedding model available for index {index_name}") except Exception as e: + if embedding_model_name: + raise logger.warning(f"Failed to generate embedding for chunk: {e}") # Build chunk 
payload @@ -1617,6 +1741,7 @@ def update_chunk( chunk_request: ChunkUpdateRequest, vdb_core: VectorDatabaseCore = Depends(get_vector_db_core), user_id: Optional[str] = None, + tenant_id: Optional[str] = None, ): """ Update a chunk document. @@ -1625,6 +1750,37 @@ def update_chunk( update_fields = chunk_request.dict( exclude_unset=True, exclude={"metadata"}) metadata = chunk_request.metadata or {} + + if "content" in update_fields and update_fields.get("content"): + embedding_model_name = None + is_multimodal = False + if tenant_id: + knowledge_record = get_knowledge_record( + {"index_name": index_name, "tenant_id": tenant_id} + ) + embedding_model_name = ( + knowledge_record.get("embedding_model_name") + if knowledge_record + else None + ) + is_multimodal = bool( + knowledge_record and knowledge_record.get("is_multimodal") == "Y" + ) + + embedding_model = get_embedding_model( + tenant_id=tenant_id, + is_multimodal=is_multimodal, + model_name=embedding_model_name, + strict_model_name=bool(embedding_model_name), + ) + embeddings = embedding_model.get_embeddings( + update_fields["content"] + ) + if embeddings and len(embeddings) > 0: + update_fields["embedding"] = embeddings[0] + if embedding_model_name: + update_fields["embedding_model_name"] = embedding_model_name + update_payload = ElasticSearchService._build_chunk_payload( base_fields={ **update_fields, @@ -1700,7 +1856,23 @@ def search_hybrid( if weight_accurate < 0 or weight_accurate > 1: raise ValueError("weight_accurate must be between 0 and 1") - embedding_model = get_embedding_model(tenant_id) + embedding_model_name = None + is_multimodal = False + for index_name in index_names: + knowledge_record = get_knowledge_record( + {"index_name": index_name, "tenant_id": tenant_id} + ) + if knowledge_record: + embedding_model_name = knowledge_record.get("embedding_model_name") + is_multimodal = knowledge_record.get("is_multimodal") == "Y" + break + + embedding_model = get_embedding_model( + tenant_id=tenant_id, + 
is_multimodal=is_multimodal, + model_name=embedding_model_name, + strict_model_name=bool(embedding_model_name), + ) if not embedding_model: raise ValueError( "No embedding model configured for the current tenant") diff --git a/backend/utils/file_management_utils.py b/backend/utils/file_management_utils.py index 7d31a74bb..f98e5be9f 100644 --- a/backend/utils/file_management_utils.py +++ b/backend/utils/file_management_utils.py @@ -15,7 +15,6 @@ from consts.model import ProcessParams from database.attachment_db import get_file_size_from_minio from utils.auth_utils import get_current_user_id -from utils.config_utils import tenant_config_manager logger = logging.getLogger("file_management_utils") @@ -45,18 +44,13 @@ async def trigger_data_process(files: List[dict], process_params: ProcessParams) if not files: return None - # Get chunking size according to the embedding model - embedding_model_id = None + # Get tenant_id from authorization for downstream task processing + embedding_model_id = process_params.model_id tenant_id = None try: _, tenant_id = get_current_user_id(process_params.authorization) - # Get embedding model ID from tenant config - tenant_config = tenant_config_manager.load_config(tenant_id) - embedding_model_id_str = tenant_config.get("EMBEDDING_ID") if tenant_config else None - if embedding_model_id_str: - embedding_model_id = int(embedding_model_id_str) except Exception as e: - logger.warning(f"Failed to get embedding model ID for tenant: {e}") + logger.warning(f"Failed to get tenant_id from authorization: {e}") # Build headers with authorization headers = { @@ -105,6 +99,7 @@ async def trigger_data_process(files: List[dict], process_params: ProcessParams) "index_name": process_params.index_name, "original_filename": file_details.get("filename"), "embedding_model_id": embedding_model_id, + "is_multimodal": is_multimodal, "tenant_id": tenant_id } sources.append(source) diff --git a/docker/deploy.sh b/docker/deploy.sh index e30e6e75a..233c14604 100755 
--- a/docker/deploy.sh +++ b/docker/deploy.sh @@ -17,6 +17,7 @@ DEPLOY_OPTIONS_FILE="$SCRIPT_DIR/deploy.options" MODE_CHOICE_SAVED="" VERSION_CHOICE_SAVED="" IS_MAINLAND_SAVED="" +DOWNLOAD_MODELS="N" ENABLE_TERMINAL_SAVED="N" TERMINAL_MOUNT_DIR_SAVED="${TERMINAL_MOUNT_DIR:-}" APP_VERSION="" @@ -79,6 +80,58 @@ is_windows_env() { return 1 } +detect_os_type() { + # Return: windows | mac | linux | unknown + local os_name + os_name=$(uname -s 2>/dev/null | tr '[:upper:]' '[:lower:]') + case "$os_name" in + mingw*|msys*|cygwin*) + echo "windows" + ;; + darwin*) + echo "mac" + ;; + linux*) + echo "linux" + ;; + *) + echo "unknown" + ;; + esac + return 0 +} + +format_path_for_env() { + # Convert path to OS-specific format for .env values + local input_path="$1" + local os_type + os_type=$(detect_os_type) + + if [[ "$os_type" = "windows" ]]; then + if command -v cygpath >/dev/null 2>&1; then + cygpath -w "$input_path" + return 0 + fi + + if [[ "$input_path" =~ ^/([a-zA-Z])/(.*)$ ]]; then + local drive="${BASH_REMATCH[1]}" + local rest="${BASH_REMATCH[2]}" + rest="${rest//\//\\}" + printf "%s:\\%s" "$(echo "$drive" | tr '[:lower:]' '[:upper:]')" "$rest" + return 0 + fi + fi + + printf "%s" "$input_path" +} + +escape_backslashes() { + # Escape backslashes for safe writing into .env or JSON + local input_path="$1" + printf "%s" "$input_path" | sed 's/\\/\\\\/g' + return 0 +} + is_port_in_use() { # Check if a TCP port is already in use (Linux/macOS/Windows Git Bash) local port="$1" @@ -266,6 +319,7 @@ persist_deploy_options() { echo "MODE_CHOICE=\"${MODE_CHOICE_SAVED}\"" echo "VERSION_CHOICE=\"${VERSION_CHOICE_SAVED}\"" echo "IS_MAINLAND=\"${IS_MAINLAND_SAVED}\"" + echo "DOWNLOAD_MODELS=\"${DOWNLOAD_MODELS}\"" echo "ENABLE_TERMINAL=\"${ENABLE_TERMINAL_SAVED}\"" echo "TERMINAL_MOUNT_DIR=\"${TERMINAL_MOUNT_DIR_SAVED}\"" } > "$DEPLOY_OPTIONS_FILE" @@ -528,6 +582,229 @@ select_deployment_mode() { echo "" } + +# Model download selection +select_model_download() { + echo "" + + local 
input_choice="" + read -r -p "Do you want to download AI model files (table-transformer and yolox)? [Y/N] (default: N): " input_choice + echo "" + + if [[ $input_choice =~ ^[Yy]$ ]]; then + DOWNLOAD_MODELS="Y" + echo "INFO: Model download will be performed." + else + DOWNLOAD_MODELS="N" + echo "INFO: Skipping model download." + fi + echo "----------------------------------------" + echo "" + return 0 +} + +# kerry + +download_and_config_models() { + if [[ "$DOWNLOAD_MODELS" != "Y" ]]; then + echo "INFO: Model download skipped by user choice." + return 0 + fi + + echo "INFO: Downloading AI model files (this may take a while)..." + + local env_file_dir="$SCRIPT_DIR" + local env_file_path="$env_file_dir/.env" + local original_dir="$(pwd)" + + MODEL_ROOT="$ROOT_DIR/model" + mkdir -p "$MODEL_ROOT" + echo "INFO: Model directory: $MODEL_ROOT" + + export HF_ENDPOINT="https://hf-mirror.com" + + command -v git >/dev/null || { echo "ERROR: git is required but not found." >&2; return 1; } + + # ========================================== + # 1. Table Transformer (table-structure recognition) + echo "INFO: Downloading table-transformer-structure-recognition..." + + TT_MODEL_DIR_NAME="table-transformer-structure-recognition" + TT_MODEL_DIR_PATH="$MODEL_ROOT/$TT_MODEL_DIR_NAME" + MODEL_SAFETENSORS_FILE="model.safetensors" + TT_MODEL_FILE_CHECK="$TT_MODEL_DIR_PATH/$MODEL_SAFETENSORS_FILE" + + cd "$MODEL_ROOT" || return 1 + + if [[ -d "$TT_MODEL_DIR_PATH" ]] && [[ -f "$TT_MODEL_FILE_CHECK" ]]; then + FILE_SIZE=$(stat -c%s "$TT_MODEL_FILE_CHECK" 2>/dev/null || stat -f%z "$TT_MODEL_FILE_CHECK" 2>/dev/null) + if [[ "$FILE_SIZE" -gt 1000000 ]]; then + echo "INFO: Table Transformer already exists." + else + echo "WARN: Existing model file looks incomplete, re-downloading..." + rm -rf "$TT_MODEL_DIR_NAME" + fi + fi + + if [[ ! -f "$TT_MODEL_FILE_CHECK" ]]; then + if [[ -d "$TT_MODEL_DIR_NAME" ]]; then + echo "WARN: Removing existing directory before re-download..." 
+ rm -rf "$TT_MODEL_DIR_NAME" + fi + + echo "INFO: Step 1/2: Clone repo (skip LFS files)..." + if ! GIT_LFS_SKIP_SMUDGE=1 git clone "$HF_ENDPOINT/microsoft/$TT_MODEL_DIR_NAME" "$TT_MODEL_DIR_NAME"; then + echo "ERROR: Failed to clone repository." >&2 + cd "$original_dir" + return 1 + fi + + cd "$TT_MODEL_DIR_NAME" || return 1 + + echo "INFO: Step 2/2: Download model.safetensors..." + LARGE_FILE_URL="$HF_ENDPOINT/microsoft/$TT_MODEL_DIR_NAME/resolve/main/$MODEL_SAFETENSORS_FILE" + + if command -v curl &> /dev/null; then + curl -L -o "$MODEL_SAFETENSORS_FILE" "$LARGE_FILE_URL" --progress-bar + elif command -v wget &> /dev/null; then + wget "$LARGE_FILE_URL" -O "$MODEL_SAFETENSORS_FILE" + else + echo "ERROR: curl or wget is required to download model files." >&2 + cd "$MODEL_ROOT"; rm -rf "$TT_MODEL_DIR_NAME"; cd "$original_dir"; return 1 + fi + + if [[ ! -f "$MODEL_SAFETENSORS_FILE" ]]; then + echo "ERROR: $MODEL_SAFETENSORS_FILE download failed." >&2 + cd "$MODEL_ROOT"; rm -rf "$TT_MODEL_DIR_NAME"; cd "$original_dir"; return 1 + fi + + FILE_SIZE=$(stat -c%s "$MODEL_SAFETENSORS_FILE" 2>/dev/null || stat -f%z "$MODEL_SAFETENSORS_FILE" 2>/dev/null) + if [[ "$FILE_SIZE" -lt 1000000 ]]; then + echo "ERROR: $MODEL_SAFETENSORS_FILE seems too small (size: $FILE_SIZE bytes)." >&2 + cd "$MODEL_ROOT"; rm -rf "$TT_MODEL_DIR_NAME"; cd "$original_dir"; return 1 + fi + + echo "INFO: $MODEL_SAFETENSORS_FILE downloaded (size: $(du -h "$MODEL_SAFETENSORS_FILE" | cut -f1))" + cd "$MODEL_ROOT" + fi + + echo "INFO: Table Transformer OK" + + # ========================================== + # 2. 
YOLOX (layout detection model) + echo "INFO: Downloading yolox_l0.05.onnx" + + YOLOX_MODEL_FILE="$MODEL_ROOT/yolox_l0.05.onnx" + MIN_YOLOX_SIZE=50000000 + + NEED_DOWNLOAD=false + + if [[ -f "$YOLOX_MODEL_FILE" ]]; then + CURRENT_SIZE=$(stat -c%s "$YOLOX_MODEL_FILE" 2>/dev/null || stat -f%z "$YOLOX_MODEL_FILE" 2>/dev/null) + if [[ "$CURRENT_SIZE" -lt "$MIN_YOLOX_SIZE" ]]; then + echo "WARN: Existing YOLOX file looks incomplete (size: $(numfmt --to=iec-i --suffix=B $CURRENT_SIZE 2>/dev/null || echo $CURRENT_SIZE)). Re-downloading..." + NEED_DOWNLOAD=true + else + echo "INFO: YOLOX already exists." + fi + else + NEED_DOWNLOAD=true + fi + + if [[ "$NEED_DOWNLOAD" = true ]]; then + ONNX_URL="$HF_ENDPOINT/unstructuredio/yolo_x_layout/resolve/main/yolox_l0.05.onnx" + + if command -v curl &> /dev/null; then + echo "INFO: Downloading with curl (supports resume -C -)..." + if curl -L -C - -o "$YOLOX_MODEL_FILE" "$ONNX_URL" --progress-bar; then + echo "INFO: curl download completed" + else + echo "ERROR: curl download failed." >&2 + cd "$original_dir" + return 1 + fi + elif command -v wget &> /dev/null; then + echo "INFO: Downloading with wget (supports resume -c)..." + wget -c "$ONNX_URL" -O "$YOLOX_MODEL_FILE" + else + echo "ERROR: curl or wget is required to download model files." >&2 + cd "$original_dir" + return 1 + fi + + if [[ -f "$YOLOX_MODEL_FILE" ]]; then + FINAL_SIZE=$(stat -c%s "$YOLOX_MODEL_FILE" 2>/dev/null || stat -f%z "$YOLOX_MODEL_FILE" 2>/dev/null) + if [[ "$FINAL_SIZE" -lt "$MIN_YOLOX_SIZE" ]]; then + echo "ERROR: YOLOX file seems too small (size: $FINAL_SIZE bytes)." >&2 + cd "$original_dir" + return 1 + else + echo "INFO: YOLOX downloaded (size: $(numfmt --to=iec-i --suffix=B $FINAL_SIZE 2>/dev/null || echo $FINAL_SIZE))" + fi + else + echo "ERROR: YOLOX download failed: file not found." >&2 + cd "$original_dir" + return 1 + fi + fi + + echo "INFO: YOLOX OK" + + # ========================================== + # 3. 
config.json + CONFIG_FILE="$MODEL_ROOT/config.json" + YOLOX_ABS_PATH=$(cd "$(dirname "$YOLOX_MODEL_FILE")" && pwd)/$(basename "$YOLOX_MODEL_FILE") + YOLOX_OS_PATH=$(format_path_for_env "$YOLOX_ABS_PATH") + YOLOX_CONFIG_PATH=$(escape_backslashes "$YOLOX_OS_PATH") + + cat > "$CONFIG_FILE" < /dev/null; then + update_env_var "TABLE_TRANSFORMER_MODEL_PATH" "$TT_MODEL_DIR_ENV_PATH" + update_env_var "UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH" "$CONFIG_FILE_ENV_PATH" + else + sed -i.bak "/^TABLE_TRANSFORMER_MODEL_PATH=/d" "$env_file_path" 2>/dev/null || true + echo "TABLE_TRANSFORMER_MODEL_PATH="$TT_MODEL_DIR_ENV_PATH"" >> "$env_file_path" + + sed -i.bak "/^UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH=/d" "$env_file_path" 2>/dev/null || true + echo "UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH="$CONFIG_FILE_ENV_PATH"" >> "$env_file_path" + rm -f "$env_file_path.bak" 2>/dev/null + fi + + echo "INFO: Environment file updated" + cd "$original_dir" +} + clean() { export MINIO_ACCESS_KEY= export MINIO_SECRET_KEY= @@ -600,6 +877,13 @@ prepare_directory_and_data() { create_dir_with_permission "$ROOT_DIR/minio" 775 create_dir_with_permission "$ROOT_DIR/redis" 775 + echo "📦 Check the status of model configuration..." + download_and_config_models || { + echo "⚠️ A warning occurred during the model configuration step, but subsequent deployment will proceed..." + # Do not exit here; the user may choose N or prefer to continue after a download failure. + } + echo "" + cp -rn volumes $ROOT_DIR chmod -R 775 $ROOT_DIR/volumes echo " 📁 Directory $ROOT_DIR/volumes has been created and permissions set to 775." 
@@ -1057,6 +1341,8 @@ main_deploy() { select_terminal_tool || { echo "❌ Terminal tool container configuration failed"; exit 1; } choose_image_env || { echo "❌ Image environment setup failed"; exit 1; } + select_model_download || { echo "❌ Model download failed"; exit 1;} + # Set NEXENT_MCP_DOCKER_IMAGE in .env file if [ -n "${NEXENT_MCP_DOCKER_IMAGE:-}" ]; then update_env_var "NEXENT_MCP_DOCKER_IMAGE" "${NEXENT_MCP_DOCKER_IMAGE}" @@ -1142,7 +1428,7 @@ docker_compose_command="" case $version_type in "v1") echo "Detected Docker Compose V1, version: $version_number" - # The version ​​v1.28.0​​ is the minimum requirement in Docker Compose v1 that explicitly supports interpolation syntax with default values like ${VAR:-default} + # The version 1.28.0 is the minimum requirement in Docker Compose v1 for default interpolation syntax. if [[ $version_number < "1.28.0" ]]; then echo "Warning: V1 version is too old, consider upgrading to V2" exit 1 diff --git a/docker/init.sql b/docker/init.sql index 26c345b69..6d274c8e1 100644 --- a/docker/init.sql +++ b/docker/init.sql @@ -213,6 +213,7 @@ CREATE TABLE IF NOT EXISTS "knowledge_record_t" ( "embedding_model_name" varchar(200) COLLATE "pg_catalog"."default", "group_ids" varchar, "ingroup_permission" varchar(30), + "is_multimodal" varchar(1) COLLATE "pg_catalog"."default" DEFAULT 'N'::character varying, "create_time" timestamp(0) DEFAULT CURRENT_TIMESTAMP, "update_time" timestamp(0) DEFAULT CURRENT_TIMESTAMP, "delete_flag" varchar(1) COLLATE "pg_catalog"."default" DEFAULT 'N'::character varying, @@ -230,6 +231,7 @@ COMMENT ON COLUMN "knowledge_record_t"."knowledge_sources" IS 'Knowledge base so COMMENT ON COLUMN "knowledge_record_t"."embedding_model_name" IS 'Embedding model name, used to record the embedding model used by the knowledge base'; COMMENT ON COLUMN "knowledge_record_t"."group_ids" IS 'Knowledge base group IDs list'; COMMENT ON COLUMN "knowledge_record_t"."ingroup_permission" IS 'In-group permission: EDIT, READ_ONLY, 
PRIVATE'; +COMMENT ON COLUMN "knowledge_record_t"."is_multimodal" IS 'whether it is multimodal'; COMMENT ON COLUMN "knowledge_record_t"."create_time" IS 'Creation time, audit field'; COMMENT ON COLUMN "knowledge_record_t"."update_time" IS 'Update time, audit field'; COMMENT ON COLUMN "knowledge_record_t"."delete_flag" IS 'When deleted by user frontend, delete flag will be set to true, achieving soft delete effect. Optional values Y/N'; diff --git a/docker/sql/v1.8.1_0306_add_is_multimodal_to_knowledge_record_t.sql b/docker/sql/v1.8.1_0306_add_is_multimodal_to_knowledge_record_t.sql new file mode 100644 index 000000000..d5b14bfbb --- /dev/null +++ b/docker/sql/v1.8.1_0306_add_is_multimodal_to_knowledge_record_t.sql @@ -0,0 +1,5 @@ +-- Add is_multimodal column to knowledge_record_t table +ALTER TABLE nexent.knowledge_record_t +ADD COLUMN IF NOT EXISTS is_multimodal varchar(1) DEFAULT 'N'; + +COMMENT ON COLUMN nexent.knowledge_record_t.is_multimodal IS 'whether it is multimodal'; diff --git a/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx b/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx index c97536b92..481d71920 100644 --- a/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx +++ b/frontend/app/[locale]/agents/components/agentConfig/ToolManagement.tsx @@ -100,7 +100,8 @@ export default function ToolManagement({ // Use tool list hook for data management const { availableTools } = useToolList(); - const { isVlmAvailable, isEmbeddingAvailable } = useConfig(); + const { isVlmAvailable, isEmbeddingAvailable, isMultiEmbeddingAvailable } = useConfig(); + const isEmbeddingOrMultiAvailable = isEmbeddingAvailable || isMultiEmbeddingAvailable; // Prefetch knowledge bases for KB tools const { prefetchKnowledgeBases } = usePrefetchKnowledgeBases(); @@ -363,7 +364,10 @@ export default function ToolManagement({ tool.id ); const isDisabledDueToVlm = isToolDisabledDueToVlm(tool.name, isVlmAvailable); - const 
isDisabledDueToEmbedding = isToolDisabledDueToEmbedding(tool.name, isEmbeddingAvailable); + const isDisabledDueToEmbedding = isToolDisabledDueToEmbedding( + tool.name, + isEmbeddingOrMultiAvailable + ); const isDisabled = isDisabledDueToVlm || isDisabledDueToEmbedding || isReadOnly; // Tooltip priority: permission > VLM > Embedding const tooltipTitle = isReadOnly @@ -468,7 +472,10 @@ export default function ToolManagement({ {group.tools.map((tool) => { const isSelected = originalSelectedToolIdsSet.has(tool.id); const isDisabledDueToVlm = isToolDisabledDueToVlm(tool.name, isVlmAvailable); - const isDisabledDueToEmbedding = isToolDisabledDueToEmbedding(tool.name, isEmbeddingAvailable); + const isDisabledDueToEmbedding = isToolDisabledDueToEmbedding( + tool.name, + isEmbeddingOrMultiAvailable + ); const isDisabled = isDisabledDueToVlm || isDisabledDueToEmbedding || isReadOnly; // Tooltip priority: permission > VLM > Embedding const tooltipTitle = isReadOnly diff --git a/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx b/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx index d09a06039..12a312c72 100644 --- a/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx +++ b/frontend/app/[locale]/agents/components/agentConfig/tool/ToolConfigModal.tsx @@ -1,4 +1,4 @@ -"use client"; +"use client"; import { useState, useEffect, useCallback, useMemo, useRef } from "react"; import { useTranslation } from "react-i18next"; @@ -30,6 +30,10 @@ import { API_ENDPOINTS } from "@/services/api"; import knowledgeBaseService from "@/services/knowledgeBaseService"; import log from "@/lib/logger"; import { isZhLocale, getLocalizedDescription } from "@/lib/utils"; +import { + isEmbeddingModelCompatible as isEmbeddingModelCompatibleBase, + isMultimodalConstraintMismatch as isMultimodalConstraintMismatchBase, +} from "@/lib/knowledgeBaseCompatibility"; export interface ToolConfigModalProps { isOpen: boolean; @@ -459,6 
+463,86 @@ export default function ToolConfigModal({ } }, [configData]); + const currentMultiEmbeddingModel = useMemo(() => { + try { + const modelConfig = configData?.models; + return ( + modelConfig?.multiEmbedding?.modelName || + modelConfig?.multiEmbedding?.displayName || + null + ); + } catch { + return null; + } + }, [configData]); + + const hasEmbeddingModel = Boolean(currentEmbeddingModel); + const hasMultiEmbeddingModel = Boolean(currentMultiEmbeddingModel); + const canToggleMultimodalParam = hasEmbeddingModel && hasMultiEmbeddingModel; + const forcedMultimodalValue = useMemo(() => { + if (!hasEmbeddingModel && hasMultiEmbeddingModel) { + return true; + } + if (hasEmbeddingModel && !hasMultiEmbeddingModel) { + return false; + } + return null; + }, [hasEmbeddingModel, hasMultiEmbeddingModel]); + + const toolMultimodal = useMemo(() => { + const multimodalParam = currentParams.find( + (param) => param.name === "multimodal" + ); + const value = multimodalParam?.value; + if (typeof value === "boolean") { + return value; + } + if (typeof value === "string") { + const normalized = value.trim().toLowerCase(); + if (["true", "1", "yes", "y"].includes(normalized)) return true; + if (["false", "0", "no", "n"].includes(normalized)) return false; + } + return null; + }, [currentParams]); + + useEffect(() => { + if (tool?.name !== "knowledge_base_search") return; + if (forcedMultimodalValue === null) return; + + const index = currentParams.findIndex( + (param) => param.name === "multimodal" + ); + if (index < 0) return; + + const param = currentParams[index]; + if (param.value === forcedMultimodalValue) return; + + const updatedParams = [...currentParams]; + updatedParams[index] = { ...param, value: forcedMultimodalValue }; + setCurrentParams(updatedParams); + + const fieldName = `param_${index}`; + form.setFieldValue(fieldName, forcedMultimodalValue); + }, [tool?.name, forcedMultimodalValue, currentParams, form]); + + const isMultimodalConstraintMismatch = useCallback( 
+ (kb: KnowledgeBase) => { + return isMultimodalConstraintMismatchBase(kb, toolMultimodal); + }, + [toolMultimodal] + ); + + const isEmbeddingModelCompatible = useCallback( + (kb: KnowledgeBase) => { + return isEmbeddingModelCompatibleBase( + kb, + currentEmbeddingModel, + currentMultiEmbeddingModel + ); + }, + [currentEmbeddingModel, currentMultiEmbeddingModel] + ); + // Check if a knowledge base can be selected const canSelectKnowledgeBase = useCallback( (kb: KnowledgeBase): boolean => { @@ -469,9 +553,16 @@ export default function ToolConfigModal({ return false; } + if (kb.source === "nexent") { + if (isMultimodalConstraintMismatch(kb)) { + return false; + } + return isEmbeddingModelCompatible(kb); + } + return true; }, - [currentEmbeddingModel] + [isEmbeddingModelCompatible, isMultimodalConstraintMismatch] ); // Track whether this is the first time opening the modal (reset when modal closes) @@ -1290,7 +1381,7 @@ export default function ToolConfigModal({ })} options={options.map((option) => ({ value: option, - label: option, + label: String(option), }))} /> ); @@ -1684,6 +1775,8 @@ export default function ToolConfigModal({ syncLoading={kbLoading} isSelectable={canSelectKnowledgeBase} currentEmbeddingModel={currentEmbeddingModel} + currentMultiEmbeddingModel={currentMultiEmbeddingModel} + toolMultimodal={toolMultimodal} difyConfig={ toolKbType === "dify_search" ? 
difyConfig diff --git a/frontend/app/[locale]/agents/components/agentConfig/tool/ToolTestPanel.tsx b/frontend/app/[locale]/agents/components/agentConfig/tool/ToolTestPanel.tsx index f2bcc7f9e..99e21e8f8 100644 --- a/frontend/app/[locale]/agents/components/agentConfig/tool/ToolTestPanel.tsx +++ b/frontend/app/[locale]/agents/components/agentConfig/tool/ToolTestPanel.tsx @@ -133,12 +133,12 @@ export default function ToolTestPanel({ // Check if this is the index_names parameter and KB selection is enabled const isIndexNamesParam = paramName === "index_names" && toolRequiresKbSelection; + if (isIndexNamesParam) { + // index_names is provided by KB selector config, no need to duplicate in input params. + return; + } - if (isIndexNamesParam && selectedKbIds.length > 0) { - // Use the selected KB IDs from configParams as default - parameterValues[paramName] = selectedKbIds; - formValues[`param_${paramName}`] = selectedKbIds; - } else if ( + if ( paramInfo && typeof paramInfo === "object" && paramInfo.default != null @@ -211,25 +211,6 @@ export default function ToolTestPanel({ if (!idsMatch) { form.setFieldValue(fieldName, selectedKbIds); - - // Also update the parameter values - if (selectedKbIds.length > 0) { - setParameterValues((prev) => ({ - ...prev, - index_names: selectedKbIds, - })); - // Update manual JSON input while preserving other values - setManualJsonInput((prev) => { - try { - const parsed = JSON.parse(prev); - parsed.index_names = selectedKbIds; - return JSON.stringify(parsed, null, 2); - } catch { - // If JSON is invalid, keep the current value - return prev; - } - }); - } } }, [selectedKbIds, toolRequiresKbSelection, form]); diff --git a/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx b/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx index a5e7d52d1..d64873b5d 100644 --- a/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx +++ b/frontend/app/[locale]/knowledges/KnowledgeBaseConfiguration.tsx @@ -7,6 +7,7 @@ 
import {
 useRef,
 useLayoutEffect,
 useCallback,
+ useMemo,
} from "react";
 
 import { useTranslation } from "react-i18next";
@@ -45,6 +46,37 @@ import {
 } from "./contexts/DocumentContext";
 import { useUIContext, UIProvider } from "./contexts/UIStateContext";
 
+const EMBEDDING_MODEL_OPTION_DELIMITER = "::";
+const normalizeEmbeddingModelType = (type: string) =>
+ (type || "").trim().toLowerCase();
+
+const toEmbeddingModelOptionValue = (displayName: string, type: string) =>
+ `${displayName}${EMBEDDING_MODEL_OPTION_DELIMITER}${type}`;
+
+const parseEmbeddingModelOptionValue = (value: string) => {
+ const normalizedValue = (value || "").trim();
+ const delimiterIndex = normalizedValue.lastIndexOf(
+ EMBEDDING_MODEL_OPTION_DELIMITER
+ );
+ if (delimiterIndex >= 0) {
+ const displayName = normalizedValue.slice(0, delimiterIndex);
+ const type = normalizedValue.slice(
+ delimiterIndex + EMBEDDING_MODEL_OPTION_DELIMITER.length
+ );
+ return {
+ displayName: displayName || "",
+ type: (type || "").trim(),
+ isMultimodal:
+ normalizeEmbeddingModelType(type || "") === "multi_embedding",
+ };
+ }
+ return {
+ displayName: normalizedValue || "",
+ type: "",
+ isMultimodal: false,
+ };
+};
+
 // EmptyState component defined directly in this file
 interface EmptyStateProps {
 icon?: React.ReactNode | string;
@@ -55,7 +87,7 @@
 }
 
 const EmptyState: React.FC = ({
- icon = "📋",
 title,
 description,
 action,
@@ -129,8 +161,7 @@ function DataConfig({ isActive }: DataConfigProps) {
 const { token } = theme.useToken();
 
 // Get available embedding models for knowledge base creation
- const { availableEmbeddingModels } = useModelList({ enabled: true });
-
+ const { models } = useModelList({ enabled: true });
 // Clear cache when component initializes
 useEffect(() => {
 localStorage.removeItem("preloaded_kb_data");
@@ -197,11 +228,41 @@ function DataConfig({ isActive }: DataConfigProps) {
 const [modelFilter, setModelFilter] = useState([]);
 
 const contentRef = 
useRef(null);
 
- // Open warning modal when single Embedding model is not configured (ignore multi-embedding)
+ const availableEmbeddingModels = useMemo(() => {
+ return models.filter(
+ (model) =>
+ (model.type === "embedding" || model.type === "multi_embedding") &&
+ model.connect_status === "available"
+ );
+ }, [models]);
+
+ const resolveEmbeddingModelId = useCallback(
+ ({
+ displayName,
+ isMultimodal,
+ }: {
+ displayName?: string;
+ isMultimodal?: boolean;
+ }) => {
+ const normalizedDisplayName = (displayName || "").trim();
+ if (!normalizedDisplayName) return undefined;
+
+ const modelType = isMultimodal ? "multi_embedding" : "embedding";
+ return availableEmbeddingModels.find(
+ (model) =>
+ model.displayName === normalizedDisplayName && model.type === modelType
+ )?.id;
+ },
+ [availableEmbeddingModels]
+ );
+
+ // Open warning modal only when neither embedding nor multi-embedding is configured.
 useEffect(() => {
- const singleEmbeddingModelName = modelConfig?.embedding?.modelName;
- setShowEmbeddingWarning(!singleEmbeddingModelName);
- }, [modelConfig?.embedding?.modelName]);
+ const singleEmbeddingModelName = modelConfig?.embedding?.modelName?.trim();
+ const multiEmbeddingModelName =
+ modelConfig?.multiEmbedding?.modelName?.trim();
+ setShowEmbeddingWarning(!singleEmbeddingModelName && !multiEmbeddingModelName);
+ }, [modelConfig?.embedding?.modelName, modelConfig?.multiEmbedding?.modelName]);
 
 // Add event listener for selecting new knowledge base
 useEffect(() => {
@@ -369,11 +430,11 @@
 // Directly call fetchKnowledgeBases to update knowledge base list data
 await fetchKnowledgeBases(false, true);
 } catch (error) {
- log.error("获取知识库最新数据失败:", error);
+ log.error("获取知识库最新数据失败:", error);
 }
 }, 100);
 } catch (error) {
- log.error("获取文档列表失败:", error);
+ log.error("获取文档列表失败:", error);
 message.error(t("knowledgeBase.message.getDocumentsFailed"));
 docDispatch({
 type: "ERROR",
@@ -618,11 +679,30 @@ 
function DataConfig({ isActive }: DataConfigProps) { setNewKbName(defaultName); setNewKbIngroupPermission("READ_ONLY"); setNewKbGroupIds([]); - // Set default embedding model - prioritize config's default model, fall back to first available model - const configModel = modelConfig?.embedding?.modelName; - const defaultModel = configModel || (availableEmbeddingModels.length > 0 - ? availableEmbeddingModels[0].displayName - : ""); + // Set default embedding model: + // 1) configured embedding model, 2) configured multimodal model, 3) first available option. + const configEmbeddingModel = modelConfig?.embedding?.modelName?.trim() || ""; + const configMultiEmbeddingModel = + modelConfig?.multiEmbedding?.modelName?.trim() || ""; + const preferredModel = [ + { modelName: configEmbeddingModel, type: "embedding" }, + { modelName: configMultiEmbeddingModel, type: "multi_embedding" }, + ].find( + ({ modelName, type }) => + !!modelName && + availableEmbeddingModels.some( + (model) => model.displayName === modelName && model.type === type + ) + ); + const defaultModel = + (preferredModel && + toEmbeddingModelOptionValue(preferredModel.modelName, preferredModel.type)) || + (availableEmbeddingModels[0] + ? 
toEmbeddingModelOptionValue( + availableEmbeddingModels[0].displayName, + availableEmbeddingModels[0].type + ) + : ""); setNewKbEmbeddingModel(defaultModel); setIsCreatingMode(true); setHasClickedUpload(false); // Reset upload button click state @@ -681,13 +761,22 @@ function DataConfig({ isActive }: DataConfigProps) { return; } + const parsedSelectedModel = + parseEmbeddingModelOptionValue(newKbEmbeddingModel); + const isMultimodal = parsedSelectedModel.isMultimodal; + const selectedModelId = resolveEmbeddingModelId({ + displayName: parsedSelectedModel.displayName, + isMultimodal: parsedSelectedModel.isMultimodal, + }); + const newKB = await createKnowledgeBase( newKbName.trim(), t("knowledgeBase.description.default"), "elasticsearch", newKbIngroupPermission, newKbGroupIds, - newKbEmbeddingModel + parsedSelectedModel.displayName, + isMultimodal ); if (!newKB) { @@ -702,7 +791,7 @@ function DataConfig({ isActive }: DataConfigProps) { setHasClickedUpload(false); setNewlyCreatedKbId(newKB.id); // Mark this KB as newly created - await uploadDocuments(newKB.id, filesToUpload); + await uploadDocuments(newKB.id, filesToUpload, selectedModelId); setUploadFiles([]); knowledgeBasePollingService @@ -738,7 +827,12 @@ function DataConfig({ isActive }: DataConfigProps) { } try { - await uploadDocuments(kbId, filesToUpload); + const activeKbModelId = resolveEmbeddingModelId({ + displayName: kbState.activeKnowledgeBase?.embeddingModel, + isMultimodal: kbState.activeKnowledgeBase?.is_multimodal, + }); + + await uploadDocuments(kbId, filesToUpload, activeKbModelId); setUploadFiles([]); knowledgeBasePollingService.triggerKnowledgeBaseListUpdate(true); @@ -887,7 +981,7 @@ function DataConfig({ isActive }: DataConfigProps) { = ({ knowledgeBaseId, documents, getFileIcon, - currentEmbeddingModel = null, - knowledgeBaseEmbeddingModel = "", + currentEmbeddingModel, + knowledgeBaseEmbeddingModel, onChunkCountChange, permission, }) => { @@ -128,55 +128,31 @@ const DocumentChunk: React.FC = 
({ setTooltipResetKey((prev) => prev + 1); }, []); + const effectiveIndexName = knowledgeBaseId || knowledgeBaseName; + + const hasKnowledgeBaseModel = + Boolean(knowledgeBaseEmbeddingModel) && + knowledgeBaseEmbeddingModel !== "unknown"; + const hasCurrentModel = Boolean(currentEmbeddingModel); + // Determine if embedding models mismatch (specific condition for tooltip) const isEmbeddingModelMismatch = React.useMemo(() => { - if (!currentEmbeddingModel || !knowledgeBaseEmbeddingModel) { + if (!hasKnowledgeBaseModel) { return false; } - if (knowledgeBaseEmbeddingModel === "unknown") { - return false; - } - return currentEmbeddingModel !== knowledgeBaseEmbeddingModel; - }, [currentEmbeddingModel, knowledgeBaseEmbeddingModel]); + return !hasCurrentModel || currentEmbeddingModel !== knowledgeBaseEmbeddingModel; + }, [ + currentEmbeddingModel, + hasCurrentModel, + hasKnowledgeBaseModel, + knowledgeBaseEmbeddingModel, + ]); // Determine if in read-only mode (embedding model mismatch OR user has READ_ONLY permission) // Note: isReadOnlyMode is broader, includes model mismatch and other conditions const isReadOnlyMode = React.useMemo(() => { - // Check if user has READ_ONLY permission - if (permission === "READ_ONLY") { - return true; - } - if (!currentEmbeddingModel || !knowledgeBaseEmbeddingModel) { - return false; - } - if (knowledgeBaseEmbeddingModel === "unknown") { - return false; - } - return currentEmbeddingModel !== knowledgeBaseEmbeddingModel; - }, [currentEmbeddingModel, knowledgeBaseEmbeddingModel, permission]); - - // Determine if search should be disabled (only when embedding model mismatch, NOT for READ_ONLY permission) - // This allows READ_ONLY users to still perform search - const isSearchDisabled = React.useMemo(() => { - if (!currentEmbeddingModel || !knowledgeBaseEmbeddingModel) { - return false; - } - if (knowledgeBaseEmbeddingModel === "unknown") { - return false; - } - return currentEmbeddingModel !== knowledgeBaseEmbeddingModel; - }, 
[currentEmbeddingModel, knowledgeBaseEmbeddingModel]); - - // Disabled tooltip message when embedding model mismatch - const disabledTooltipMessage = React.useMemo(() => { - if (isEmbeddingModelMismatch && currentEmbeddingModel && knowledgeBaseEmbeddingModel && knowledgeBaseEmbeddingModel !== "unknown") { - return t("document.chunk.tooltip.disabledDueToModelMismatch", { - currentModel: currentEmbeddingModel, - knowledgeBaseModel: knowledgeBaseEmbeddingModel - }); - } - return ""; - }, [isEmbeddingModelMismatch, currentEmbeddingModel, knowledgeBaseEmbeddingModel, t]); + return permission === "READ_ONLY" || isEmbeddingModelMismatch; + }, [permission, isEmbeddingModelMismatch]); // Set active document when documents change useEffect(() => { @@ -201,14 +177,14 @@ const DocumentChunk: React.FC = ({ // Load chunks for active document with server-side pagination const loadChunks = React.useCallback(async () => { - if (!knowledgeBaseName || !activeDocumentKey) { + if (!effectiveIndexName || !activeDocumentKey) { return; } setLoading(true); try { const result = await knowledgeBaseService.previewChunksPaginated( - knowledgeBaseName, + effectiveIndexName, pagination.page, pagination.pageSize, activeDocumentKey @@ -240,7 +216,7 @@ const DocumentChunk: React.FC = ({ setLoading(false); } }, [ - knowledgeBaseName, + effectiveIndexName, activeDocumentKey, pagination.page, pagination.pageSize, @@ -321,16 +297,7 @@ const DocumentChunk: React.FC = ({ return; } - // Check embedding model consistency before searching - if (isEmbeddingModelMismatch && currentEmbeddingModel && knowledgeBaseEmbeddingModel && knowledgeBaseEmbeddingModel !== "unknown") { - message.error(t("document.chunk.error.searchFailed", { - currentModel: currentEmbeddingModel, - knowledgeBaseModel: knowledgeBaseEmbeddingModel - })); - return; - } - - if (!knowledgeBaseName) { + if (!effectiveIndexName) { message.error(t("document.chunk.error.searchFailed")); return; } @@ -340,7 +307,7 @@ const DocumentChunk: React.FC = 
({ try { const response = await knowledgeBaseService.hybridSearch( - knowledgeBaseId, + effectiveIndexName, trimmedValue, { topK: pagination.pageSize, @@ -352,11 +319,14 @@ const DocumentChunk: React.FC = ({ return { id: item.id || "", content: item.content || "", - path_or_url: item.path_or_url, + path_or_url: item.path_or_url || item.url || item.pathOrUrl, filename: item.filename, create_time: item.create_time, score: item.score, // Preserve search score for display - source_type: item.source_type, // Preserve source type for display + source_type: + item.source_type === "local" || item.source_type === "minio" + ? "file" + : item.source_type, // Preserve source type for display }; }); @@ -373,16 +343,12 @@ const DocumentChunk: React.FC = ({ setChunkSearchLoading(false); } }, [ - knowledgeBaseName, - knowledgeBaseId, + effectiveIndexName, message, pagination.pageSize, resetChunkSearch, searchValue, t, - isEmbeddingModelMismatch, - currentEmbeddingModel, - knowledgeBaseEmbeddingModel, ]); const refreshChunks = React.useCallback(async () => { @@ -454,7 +420,7 @@ const DocumentChunk: React.FC = ({ }; const handleChunkSubmit = async () => { - if (!knowledgeBaseName) { + if (!effectiveIndexName) { message.error(t("document.chunk.error.loadFailed")); return; } @@ -463,26 +429,12 @@ const DocumentChunk: React.FC = ({ return; } - // Check embedding model consistency before creating chunk - if (chunkModalMode === "create") { - if (knowledgeBaseEmbeddingModel && - knowledgeBaseEmbeddingModel !== "unknown" && - currentEmbeddingModel && - currentEmbeddingModel !== knowledgeBaseEmbeddingModel) { - message.error(t("document.chunk.error.createFailed", { - currentModel: currentEmbeddingModel, - knowledgeBaseModel: knowledgeBaseEmbeddingModel - })); - return; - } - } - try { const values = await chunkForm.validateFields(); setChunkSubmitting(true); if (chunkModalMode === "create") { const filenamePayload = values.filename?.trim() || undefined; - await 
knowledgeBaseService.createChunk(knowledgeBaseName, { + await knowledgeBaseService.createChunk(effectiveIndexName, { content: values.content, filename: filenamePayload, path_or_url: activeDocumentKey, @@ -503,7 +455,7 @@ const DocumentChunk: React.FC = ({ return; } await knowledgeBaseService.updateChunk( - knowledgeBaseName, + effectiveIndexName, editingChunk.id, { content: values.content, @@ -541,7 +493,7 @@ const DocumentChunk: React.FC = ({ message.error(t("document.chunk.error.missingChunkId")); return; } - if (!knowledgeBaseName) { + if (!effectiveIndexName) { message.error(t("document.chunk.error.deleteFailed")); return; } @@ -556,7 +508,7 @@ const DocumentChunk: React.FC = ({ danger: true, onOk: async () => { try { - await knowledgeBaseService.deleteChunk(knowledgeBaseName, chunk.id); + await knowledgeBaseService.deleteChunk(effectiveIndexName, chunk.id); message.success(t("document.chunk.success.delete")); forceCloseTooltips(); // Update chunk count immediately for better UX @@ -761,11 +713,11 @@ const DocumentChunk: React.FC = ({
{chunk.source_type === "datamate" - ? t("document.chunk.source.datamate", "来源: Datamate") + ? t("document.chunk.source.datamate", "\u6765\u6e90: Datamate") : chunk.source_type === "file" || chunk.source_type === "minio" || chunk.source_type === "local" - ? t("document.chunk.source.nexent", "来源: Nexent") + ? t("document.chunk.source.nexent", "\u6765\u6e90: Nexent") : ""}
@@ -805,57 +757,37 @@ const DocumentChunk: React.FC = ({ {/* Search and Add Button Bar */}
- {/* Wrap search input with tooltip when model mismatch */} - {isEmbeddingModelMismatch ? ( - - - setSearchValue(e.target.value)} - onPressEnter={() => { + setSearchValue(e.target.value)} + onPressEnter={() => { + void handleSearch(); + }} + style={{ width: 320 }} + suffix={ +
+ {searchValue && ( +
- } - /> - )} +
+ } + />
{/* Create Chunk button - hide when user has READ_ONLY permission */} {!isReadOnlyMode && ( @@ -864,7 +796,6 @@ const DocumentChunk: React.FC = ({ type="text" icon={} onClick={openCreateChunkModal} - disabled={isEmbeddingModelMismatch} > )} diff --git a/frontend/app/[locale]/knowledges/components/document/DocumentList.tsx b/frontend/app/[locale]/knowledges/components/document/DocumentList.tsx index 06940d9f0..50a845142 100644 --- a/frontend/app/[locale]/knowledges/components/document/DocumentList.tsx +++ b/frontend/app/[locale]/knowledges/components/document/DocumentList.tsx @@ -79,6 +79,8 @@ interface DocumentListProps { availableEmbeddingModels?: ModelOption[]; selectedEmbeddingModel?: string; onEmbeddingModelChange?: (value: string) => void; + isMultimodal?: boolean; + onMultimodalChange?: (value: boolean) => void; permission?: string; // User's permission for this knowledge base (READ_ONLY, EDIT, etc.) // Upload related props @@ -122,6 +124,8 @@ const DocumentListContainer = forwardRef( availableEmbeddingModels, selectedEmbeddingModel, onEmbeddingModelChange, + isMultimodal = false, + onMultimodalChange, permission, // Upload related props @@ -238,6 +242,8 @@ const DocumentListContainer = forwardRef( // Determine if user has read-only permission const isReadOnlyMode = permission === "READ_ONLY"; + const canToggleMultimodal = + isCreatingMode && typeof onMultimodalChange === "function"; // Permission options with icons shown inside dropdown const permissionOptions = [ @@ -503,11 +509,29 @@ const DocumentListContainer = forwardRef( onChange={onEmbeddingModelChange} style={{ minWidth: 200, justifyContent: "center", alignItems: "flex-end" }} placeholder={t("knowledgeBase.create.embeddingModelPlaceholder") || "Select embedding model"} - options={(availableEmbeddingModels || []).map((model) => ({ - value: model.displayName, - label: model.displayName, - disabled: model.connect_status === "unavailable", - }))} + allowClear={false} + options={[ + { + label: 
t("modelConfig.option.embeddingModel"), + options: (availableEmbeddingModels || []) + .filter((model) => model.type === "embedding") + .map((model) => ({ + value: `${model.displayName}::${model.type}`, + label: model.displayName, + disabled: model.connect_status === "unavailable", + })), + }, + { + label: t("modelConfig.option.multiEmbeddingModel"), + options: (availableEmbeddingModels || []) + .filter((model) => model.type === "multi_embedding") + .map((model) => ({ + value: `${model.displayName}::${model.type}`, + label: model.displayName, + disabled: model.connect_status === "unavailable", + })), + }, + ].filter((group) => group.options.length > 0)} /> )} {/* User groups multi-select */} @@ -615,7 +639,7 @@ const DocumentListContainer = forwardRef(
; isLoading?: boolean; syncLoading?: boolean; onClick: (kb: KnowledgeBase) => void; @@ -56,7 +59,7 @@ interface KnowledgeBaseListProps { const KnowledgeBaseList: React.FC = ({ knowledgeBases, activeKnowledgeBase, - currentEmbeddingModel, + configuredEmbeddingModels = [], isLoading = false, syncLoading = false, onClick, @@ -127,6 +130,34 @@ const KnowledgeBaseList: React.FC = ({ return `knowledgeBase.ingroup.permission.${permission || "DEFAULT"}`; }; + const configuredModelTypesByName = useMemo(() => { + const map = new Map>(); + configuredEmbeddingModels.forEach((model) => { + const modelName = (model.displayName || "").trim(); + const modelType = (model.type || "").trim().toLowerCase(); + if (!modelName) return; + if (modelType !== "embedding" && modelType !== "multi_embedding") return; + if (!map.has(modelName)) { + map.set(modelName, new Set()); + } + map.get(modelName)!.add(modelType); + }); + return map; + }, [configuredEmbeddingModels]); + + const isModelMismatch = (kb: KnowledgeBase) => { + if (kb.embeddingModel === "unknown") return false; + if (kb.source === "datamate") return false; + const modelTypes = configuredModelTypesByName.get( + (kb.embeddingModel || "").trim() + ); + return !modelTypes; + }; + + const hasIndexedDocumentsAndChunks = (kb: KnowledgeBase) => { + return (kb.documentCount || 0) > 0 && (kb.chunkCount || 0) > 0; + }; + // Search and filter states const [searchKeyword, setSearchKeyword] = useState(""); const [selectedSources, setSelectedSources] = useState([]); @@ -579,6 +610,21 @@ const KnowledgeBaseList: React.FC = ({ })} )} + {kb.is_multimodal && + hasIndexedDocumentsAndChunks(kb) && ( + + multimodal + + )} + {isModelMismatch(kb) && ( + + {t("knowledgeBase.tag.modelMismatch")} + + )} {/* User group tags - only show when not PRIVATE */} diff --git a/frontend/app/[locale]/knowledges/contexts/DocumentContext.tsx b/frontend/app/[locale]/knowledges/contexts/DocumentContext.tsx index b956dd919..63d9ad1c2 100644 --- 
a/frontend/app/[locale]/knowledges/contexts/DocumentContext.tsx +++ b/frontend/app/[locale]/knowledges/contexts/DocumentContext.tsx @@ -112,7 +112,7 @@ export const DocumentContext = createContext<{ state: DocumentState; dispatch: React.Dispatch; fetchDocuments: (kbId: string, forceRefresh?: boolean, kbSource?: string) => Promise; - uploadDocuments: (kbId: string, files: File[]) => Promise; + uploadDocuments: (kbId: string, files: File[], modelId?: number) => Promise; deleteDocument: (kbId: string, docId: string) => Promise; }>({ state: { @@ -202,11 +202,11 @@ export const DocumentProvider: React.FC = ({ children }) }, [state.loadingKbIds, state.documentsMap, t]); // Upload documents to a knowledge base - const uploadDocuments = useCallback(async (kbId: string, files: File[]) => { + const uploadDocuments = useCallback(async (kbId: string, files: File[], modelId?: number) => { dispatch({ type: DOCUMENT_ACTION_TYPES.SET_UPLOADING, payload: true }); try { - await knowledgeBaseService.uploadDocuments(kbId, files); + await knowledgeBaseService.uploadDocuments(kbId, files, undefined, modelId); // Set loading state before fetching latest documents dispatch({ type: DOCUMENT_ACTION_TYPES.SET_LOADING_DOCUMENTS, payload: true }); @@ -265,4 +265,4 @@ export const DocumentProvider: React.FC = ({ children }) {children} ); -}; \ No newline at end of file +}; diff --git a/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx b/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx index 5985c4b08..947cac8aa 100644 --- a/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx +++ b/frontend/app/[locale]/knowledges/contexts/KnowledgeBaseContext.tsx @@ -110,7 +110,8 @@ export const KnowledgeBaseContext = createContext<{ source?: string, ingroup_permission?: string, group_ids?: number[], - embeddingModel?: string + embeddingModel?: string, + is_multimodal?: boolean ) => Promise; deleteKnowledgeBase: (id: string) => Promise; selectKnowledgeBase: (id: 
string) => void; @@ -125,6 +126,7 @@ export const KnowledgeBaseContext = createContext<{ selectedIds: [], activeKnowledgeBase: null, currentEmbeddingModel: null, + currentMultiEmbeddingModel: null, isLoading: false, syncLoading: false, error: null, @@ -159,6 +161,7 @@ export const KnowledgeBaseProvider: React.FC = ({ selectedIds: [], activeKnowledgeBase: null, currentEmbeddingModel: null, + currentMultiEmbeddingModel: null, isLoading: false, syncLoading: false, error: null, @@ -168,11 +171,6 @@ export const KnowledgeBaseProvider: React.FC = ({ // Check if knowledge base is selectable - memoized with useCallback const isKnowledgeBaseSelectable = useCallback( (kb: KnowledgeBase): boolean => { - // If no current embedding model is set, not selectable - if (!state.currentEmbeddingModel) { - return false; - } - // Check if knowledge base has content (documents or chunks) const hasContent = (kb.documentCount || 0) > 0 || (kb.chunkCount || 0) > 0; @@ -187,22 +185,46 @@ export const KnowledgeBaseProvider: React.FC = ({ return true; } - // For local knowledge bases, only selectable when model exactly matches current model - return ( - kb.embeddingModel === "unknown" || - kb.embeddingModel === state.currentEmbeddingModel - ); + if (kb.embeddingModel === "unknown") { + return true; + } + + const currentEmbeddingModel = state.currentEmbeddingModel?.trim() || ""; + const currentMultiEmbeddingModel = + modelConfig?.multiEmbedding?.modelName?.trim() || ""; + + if (kb.is_multimodal) { + // Multimodal KB is selectable as long as current multimodal model is configured. + return !!currentMultiEmbeddingModel; + } + + // Text KB is selectable as long as current embedding model is configured. 
+ return !!currentEmbeddingModel; }, - [state.currentEmbeddingModel] + [modelConfig?.multiEmbedding?.modelName, state.currentEmbeddingModel] ); // Check if knowledge base has model mismatch (for display purposes) - // Note: Always return false to remove model mismatch restrictions const hasKnowledgeBaseModelMismatch = useCallback( (kb: KnowledgeBase): boolean => { - return false; + if (kb.embeddingModel === "unknown") { + return false; + } + if (kb.source === "datamate") { + return false; + } + + if (kb.is_multimodal) { + const multiEmbeddingModel = + modelConfig?.multiEmbedding?.modelName?.trim() || ""; + // Only show warning when the required current model is not configured. + return !multiEmbeddingModel; + } + + // Only show warning when the required current model is not configured. + return !state.currentEmbeddingModel; }, - [] + [modelConfig?.multiEmbedding?.modelName, state.currentEmbeddingModel] ); // Load knowledge base data (supports force fetch from server and load selected status) - optimized with useCallback @@ -311,17 +333,31 @@ export const KnowledgeBaseProvider: React.FC = ({ source: string = "elasticsearch", ingroup_permission?: string, group_ids?: number[], - embeddingModel?: string + embeddingModel?: string, + is_multimodal?: boolean ) => { try { + const selectedEmbeddingModel = embeddingModel?.trim() || ""; + const defaultMultiEmbeddingModel = + modelConfig?.multiEmbedding?.modelName?.trim() || ""; + const resolvedIsMultimodal = + typeof is_multimodal === "boolean" + ? is_multimodal + : !!defaultMultiEmbeddingModel && + selectedEmbeddingModel === defaultMultiEmbeddingModel; + const fallbackEmbeddingModel = resolvedIsMultimodal + ? 
defaultMultiEmbeddingModel + : state.currentEmbeddingModel || ""; + const resolvedEmbeddingModel = + selectedEmbeddingModel || fallbackEmbeddingModel; const newKB = await knowledgeBaseService.createKnowledgeBase({ name, description, source, - // Use provided embeddingModel if available, otherwise fall back to current model or default - embeddingModel: embeddingModel || state.currentEmbeddingModel || "", + embeddingModel: resolvedEmbeddingModel, ingroup_permission, group_ids, + is_multimodal: resolvedIsMultimodal, }); return newKB; } catch (error) { @@ -333,7 +369,7 @@ export const KnowledgeBaseProvider: React.FC = ({ return null; } }, - [state.currentEmbeddingModel, t] + [modelConfig?.multiEmbedding?.modelName, state.currentEmbeddingModel, t] ); // Delete knowledge base - memoized with useCallback @@ -609,6 +645,7 @@ export const KnowledgeBaseProvider: React.FC = ({ selectKnowledgeBase, setActiveKnowledgeBase, isKnowledgeBaseSelectable, + hasKnowledgeBaseModelMismatch, refreshKnowledgeBaseData, refreshKnowledgeBaseDataWithDataMate, ] diff --git a/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx b/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx index 7cbf5192e..b0d86d2b8 100644 --- a/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx +++ b/frontend/app/[locale]/models/components/model/ModelAddDialog.tsx @@ -1,4 +1,4 @@ -import { useMemo, useState, useCallback, useEffect } from "react"; +import { useMemo, useState, useCallback, useEffect } from "react"; import { useTranslation } from "react-i18next"; import { Modal, Select, Input, Button, Switch, Tooltip, App } from "antd"; diff --git a/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx b/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx index 5e498e8de..7b8479385 100644 --- a/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx +++ b/frontend/app/[locale]/models/components/model/ModelEditDialog.tsx @@ -1,4 +1,4 @@ -import { 
useState, useEffect } from 'react' +import { useState, useEffect } from 'react' import { useTranslation } from 'react-i18next' import { Modal, Input, Button, App } from "antd"; @@ -480,4 +480,4 @@ export const ProviderConfigEditDialog = ({
) -} \ No newline at end of file +} diff --git a/frontend/app/[locale]/models/components/modelConfig.tsx b/frontend/app/[locale]/models/components/modelConfig.tsx index e20e74876..5e91f71f1 100644 --- a/frontend/app/[locale]/models/components/modelConfig.tsx +++ b/frontend/app/[locale]/models/components/modelConfig.tsx @@ -527,6 +527,7 @@ export const ModelConfigSection = forwardRef< try { const isConnected = await modelService.verifyCustomModel( modelName, + modelType, signal ); @@ -603,7 +604,7 @@ export const ModelConfigSection = forwardRef< throttleTimerRef.current = setTimeout(async () => { try { // Use modelService to verify model - const isConnected = await modelService.verifyCustomModel(displayName); + const isConnected = await modelService.verifyCustomModel(displayName, modelType); // Update model status updateModelStatus( diff --git a/frontend/app/[locale]/tenant-resources/components/resources/ModelList.tsx b/frontend/app/[locale]/tenant-resources/components/resources/ModelList.tsx index 42ca403e2..260e83b3b 100644 --- a/frontend/app/[locale]/tenant-resources/components/resources/ModelList.tsx +++ b/frontend/app/[locale]/tenant-resources/components/resources/ModelList.tsx @@ -1,4 +1,4 @@ -"use client"; +"use client"; import React, { useState } from "react"; import { useTranslation } from "react-i18next"; @@ -92,7 +92,7 @@ export default function ModelList({ tenantId }: { tenantId: string | null }) { }; // Handle checking model connectivity - const handleCheckConnectivity = async (displayName: string) => { + const handleCheckConnectivity = async (displayName: string, modelType: string) => { if (!tenantId) { message.error(t("tenantResources.tenants.tenantIdRequired")); return; @@ -100,7 +100,7 @@ export default function ModelList({ tenantId }: { tenantId: string | null }) { setCheckingConnectivity((prev) => new Set(prev).add(displayName)); try { - const isConnected = await modelService.verifyCustomModel(displayName); + const isConnected = await 
modelService.verifyCustomModel(displayName, modelType); if (isConnected) { message.success(t("tenantResources.models.connectivitySuccess")); } else { @@ -194,7 +194,7 @@ export default function ModelList({ tenantId }: { tenantId: string | null }) {