diff --git a/AGENTS.md b/AGENTS.md index 7798227b1..a631eb50f 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -8,7 +8,7 @@ -When users ask you to perform tasks, check if any of the available skills below can help complete the task more effectively. Skills provide specialized capabilities and domain knowledge. +When users ask to perform tasks, check if any of the available skills below can help complete the task more effectively. Skills provide specialized capabilities and domain knowledge. How to use skills: - Invoke: `npx openskills read ` (run in your shell) @@ -40,3 +40,129 @@ Usage notes: + +--- + +## Project Overview + +Nexent is a zero-code platform for auto-generating AI agents. Monorepo with: +- `backend/` - FastAPI HTTP API +- `sdk/nexent/` - Core agent framework (pip package) +- `frontend/` - Next.js web UI +- `docker/` & `k8s/` - Deployment configs + +--- + +## Developer Commands + +### Backend (Python 3.10) + +```bash +# Setup +cd backend && uv sync --extra data-process --extra test + +# Install SDK for development +cd backend && uv pip install -e "../sdk[dev]" +``` + +### Run Tests + +```bash +# From project root, with backend venv activated +source backend/.venv/bin/activate && python test/run_all_test.py + +# Single test file +pytest test/backend/apps/test_agent_app.py -v +``` + +### Frontend (Next.js) + +```bash +cd frontend +npm run dev # Development server +npm run check-all # type-check + lint + format + build +``` + +### Docker Deployment + +```bash +cd docker +cp .env.example .env # Fill required configs +bash deploy.sh # Interactive deployment +``` + +--- + +## Architecture + +### Environment Variables + +**Single source of truth**: `backend/consts/const.py` + +- NO direct `os.getenv()` / `os.environ.get()` outside this file +- SDK (`sdk/nexent/`) NEVER reads env vars - accepts config via parameters +- Services read from `consts.const` and pass to SDK + +### Backend Layer Structure + +| Layer | Path | Responsibility | 
+|-------|------|----------------| +| Apps | `backend/apps/` | HTTP boundary: parse input, call services, map exceptions to HTTP | +| Services | `backend/services/` | Business logic orchestration, raise domain exceptions | +| Consts | `backend/consts/` | Env vars (`const.py`), exceptions (`exceptions.py`), error codes | + +**Exception flow**: Services raise domain exceptions → Apps map to HTTP status codes + +--- + +## Database Migrations + +**Location**: `docker/sql/*.sql` (versioned migration scripts) + +**Critical rule**: When adding columns/tables via migration script: +- Update `docker/init.sql` (Docker Compose fresh deploy) +- Update `k8s/helm/nexent/charts/nexent-common/files/init.sql` (K8s fresh deploy) + +**Version**: Tracked in `backend/consts/const.py` as `APP_VERSION` + +--- + +## Testing Conventions + +- pytest only (no unittest) +- Mock at import site with fully-qualified path: + ```python + mocker.patch("backend.services.agent_service.AgentService.run", return_value={...}) + ``` +- Async tests: `@pytest.mark.asyncio` +- Test structure: `test/backend/` and `test/sdk/` + +--- + +## Code Style + +- English-only comments and docstrings (enforced by `.cursor/rules/english_comments.mdc`) +- Import order: stdlib → third-party → project +- Line length: 119 (sdk ruff config) + +--- + +## Key Files + +| File | Purpose | +|------|---------| +| `backend/consts/const.py` | All env var definitions, APP_VERSION | +| `backend/consts/exceptions.py` | Domain exceptions (AgentRunException, LimitExceededError, etc.) 
| +| `docker/init.sql` | Database schema for Docker Compose | +| `k8s/helm/.../init.sql` | Database schema for Kubernetes | +| `test/run_all_test.py` | Test runner with coverage | + +--- + +## Reference Files + +Existing instruction files with detailed rules: +- `CLAUDE.md` - Backend architecture, env var management, app/service layer rules +- `.cursor/rules/environment_variable.mdc` - Env var centralization +- `.cursor/rules/pytest_unit_test_rules.mdc` - Testing patterns +- `.cursor/rules/english_comments.mdc` - Comment language enforcement \ No newline at end of file diff --git a/backend/agents/create_agent_info.py b/backend/agents/create_agent_info.py index 17eb17484..0f6591a54 100644 --- a/backend/agents/create_agent_info.py +++ b/backend/agents/create_agent_info.py @@ -8,8 +8,20 @@ from nexent.core.utils.observer import MessageObserver from nexent.core.agents.agent_model import AgentRunInfo, ModelConfig, AgentConfig, ToolConfig, ExternalA2AAgentConfig, AgentHistory, AgentVerificationConfig from nexent.core.agents.agent_context import ContextManagerConfig +from nexent.core.models.capacity_resolver import ( + ModelCapacitySnapshot, + ProviderCapabilityUnknown, + ResolverError, + resolve_capacity, +) +from nexent.core.models.capacity_budget import ( + RequestBudgetOverrides, + SafeInputBudgetCalculator, +) from nexent.memory.memory_service import search_memory_in_levels +from consts.capability_profiles import CATALOG as CAPABILITY_CATALOG + from services.file_management_service import get_llm_model, validate_urls_access from services.vectordatabase_service import ( ElasticSearchService, @@ -44,6 +56,212 @@ logger.setLevel(logging.DEBUG) +# Safe fallback for context-manager token_threshold when no capacity is known. +# Used only when the resolver fails (uncataloged model with no operator-supplied +# hard capacity). Sized to cover the typical 32K-context band shared by the +# majority of production LLMs (GPT-3.5 16K, GLM-4 32K, Qwen2 32K, Llama 3 +# 32K, etc.). 
def _dominant_capacity_source(field_sources: dict) -> Optional[str]:
    """Collapse per-field capacity sources into one representative label.

    Args:
        field_sources: Mapping of capacity field name to the label of the
            source that supplied it. Falsy labels (None, "") are ignored.

    Returns:
        The first label from the precedence list (operator, profile,
        provider_candidate, legacy, unknown) that occurs among the values.
        When none of those occur, the first non-empty value in mapping
        order. None when every value is falsy or the mapping is empty.
    """
    precedence = ("operator", "profile", "provider_candidate", "legacy", "unknown")
    present = [label for label in field_sources.values() if label]
    # Unrecognised labels are still reported rather than dropped, so the
    # fallback is the first non-empty value when nothing in `precedence` hits.
    fallback = present[0] if present else None
    return next((label for label in precedence if label in present), fallback)
def _resolve_safe_input_budget(
    *,
    capacity_snapshot: Optional[ModelCapacitySnapshot],
    tenant_id: str,
    agent_requested_output_tokens: Optional[int],
    request_requested_output_tokens: Optional[int],
) -> Optional[dict]:
    """Compute the safe-input budget snapshot for one agent run.

    Args:
        capacity_snapshot: Resolved model capacity, or None when capacity
            could not be resolved.
        tenant_id: Tenant whose capacity reserve policy is looked up.
        agent_requested_output_tokens: Output allowance configured on the
            agent; passed to the calculator as ``requested_output_tokens``.
        request_requested_output_tokens: Per-request output allowance; when
            present it is wrapped in ``RequestBudgetOverrides``.

    Returns:
        None when ``capacity_snapshot`` is None. Otherwise the calculator
        result converted by ``_safe_input_budget_for_monitoring``.
    """
    if capacity_snapshot is None:
        return None

    overrides = (
        RequestBudgetOverrides(requested_output_tokens=request_requested_output_tokens)
        if request_requested_output_tokens is not None
        else None
    )

    # The reserve source label reflects only whether the agent configured an
    # output allowance; a request-level override does not change it here.
    if agent_requested_output_tokens is None:
        reserve_source = "model_default"
    else:
        reserve_source = "agent"

    calculator = SafeInputBudgetCalculator()
    reserve_policy = tenant_config_manager.get_capacity_reserve_policy(tenant_id)
    budget = calculator.calculate_safe_input_budget(
        capacity_snapshot=capacity_snapshot,
        reserve_policy=reserve_policy,
        request_overrides=overrides,
        requested_output_tokens=agent_requested_output_tokens,
        output_reserve_source=reserve_source,
    )

    log_args = (
        tenant_id,
        budget.model_name,
        budget.requested_output_tokens,
        budget.soft_input_budget_tokens,
        budget.hard_input_budget_tokens,
        budget.fingerprint,
        list(budget.warnings),
    )
    logger.info(
        "W2 safe input budget resolved: tenant_id=%s model=%s requested_output_tokens=%s "
        "soft_input_budget_tokens=%s hard_input_budget_tokens=%s fingerprint=%s warnings=%s",
        *log_args,
    )
    return _safe_input_budget_for_monitoring(budget)
def _warn_missing_capacity_once(
    model_info: Optional[dict],
    provider: str,
    model_id_str: str,
    detail: Optional[str] = None,
) -> None:
    """Emit the "capacity not configured" WARNING at most once per model.

    The dedup key is the row's ``model_id`` when ``model_info`` is a dict
    that carries one, otherwise the string ``"<provider>/<model_id_str>"``.
    Keys already present in ``_CAPACITY_WARNING_EMITTED`` are skipped.

    Args:
        model_info: Model record dict, or anything else when unavailable.
        provider: Provider name used in the message and the fallback key.
        model_id_str: Model name used in the message and the fallback key.
        detail: Optional explanation; when truthy it replaces the default
            reason text in the message.
    """
    db_model_id = None
    if isinstance(model_info, dict):
        db_model_id = model_info.get("model_id")

    if db_model_id is None:
        dedup_key = f"{provider}/{model_id_str}"
    else:
        dedup_key = db_model_id

    # Membership test and insertion happen under one lock acquisition so two
    # concurrent first-time callers cannot both decide to log. The log call
    # itself stays outside the lock.
    with _CAPACITY_WARNING_LOCK:
        already_warned = dedup_key in _CAPACITY_WARNING_EMITTED
        if not already_warned:
            _CAPACITY_WARNING_EMITTED.add(dedup_key)
    if already_warned:
        return

    if detail:
        reason = f"resolver error: {detail}"
    else:
        reason = "no context_window_tokens or max_output_tokens configured"

    logger.warning(
        "Output token cap and budget consistency check are not enforced for "
        "model '%s' (model_id=%s, provider=%s) because %s. "
        "To enable enforcement, open the Nexent model management UI, edit "
        "this model, and fill in 'Context window tokens' and 'Max output "
        "tokens'. Falling back to a default context threshold of %s tokens.",
        model_id_str,
        db_model_id,
        provider,
        reason,
        _TOKEN_THRESHOLD_LEGACY_FALLBACK,
    )
get_model_by_model_id(model_id_to_use, tenant_id=tenant_id) model_name = model_info["display_name"] if model_info is not None else "main_model" - if model_info is not None and model_info.get("max_tokens"): - model_max_tokens = model_info["max_tokens"] + # W1 step 6: derive input budget via ModelCapacityResolver instead of + # treating model_info["max_tokens"] (a deprecated output cap) as a + # context threshold. Falls back to a safe constant when capacity is + # unknown during the migration window. + input_budget, capacity_snapshot, resolved_capacity_snapshot = ( + _resolve_input_budget(model_info) + ) else: model_name = "main_model" + input_budget = _TOKEN_THRESHOLD_LEGACY_FALLBACK + capacity_snapshot = None + resolved_capacity_snapshot = None + + requested_output_tokens = agent_info.get("requested_output_tokens") + safe_input_budget_snapshot = _resolve_safe_input_budget( + capacity_snapshot=resolved_capacity_snapshot, + tenant_id=tenant_id, + agent_requested_output_tokens=requested_output_tokens, + request_requested_output_tokens=request_requested_output_tokens, + ) + if safe_input_budget_snapshot is not None: + soft_input_budget_tokens = safe_input_budget_snapshot["soft_input_budget_tokens"] + hard_input_budget_tokens = safe_input_budget_snapshot["hard_input_budget_tokens"] + context_token_threshold = soft_input_budget_tokens + else: + soft_input_budget_tokens = 0 + hard_input_budget_tokens = 0 + context_token_threshold = input_budget logger.info( "Agent main LLM: agent_id=%s, model_id=%s, display_name=%s, model_name=%s", @@ -623,7 +875,9 @@ async def create_agent_config( ) cm_config = ContextManagerConfig( enabled=enable_context_manager, - token_threshold=model_max_tokens, + token_threshold=context_token_threshold, + soft_input_budget_tokens=soft_input_budget_tokens, + hard_input_budget_tokens=hard_input_budget_tokens, ) agent_config = AgentConfig( name="undefined" if agent_info["name"] is None else agent_info["name"], @@ -636,12 +890,15 @@ async def 
create_agent_config( ), tools=tool_list + _get_skill_script_tools(agent_id, tenant_id, version_no), max_steps=agent_info.get("max_steps", 15), + requested_output_tokens=requested_output_tokens, model_name=model_name, provide_run_summary=agent_info.get("provide_run_summary", False), managed_agents=managed_agents, external_a2a_agents=external_a2a_agents, context_manager_config=cm_config, context_components=context_components, + capacity_snapshot=capacity_snapshot, + safe_input_budget_snapshot=safe_input_budget_snapshot, verification_config=AgentVerificationConfig.model_validate(agent_info.get("verification_config") or {}), ) return agent_config @@ -1054,6 +1311,7 @@ async def create_agent_run_info( is_debug: bool = False, override_version_no: int | None = None, override_model_id: int | None = None, + requested_output_tokens: int | None = None, tool_params: Optional[ToolParamsRequest | Dict[str, Any]] = None, ): # Determine which version_no to use based on is_debug flag @@ -1086,6 +1344,8 @@ async def create_agent_run_info( } if override_model_id is not None: create_config_kwargs["override_model_id"] = override_model_id + if requested_output_tokens is not None: + create_config_kwargs["request_requested_output_tokens"] = requested_output_tokens agent_config = await create_agent_config(**create_config_kwargs, tool_params=tool_params) @@ -1141,6 +1401,12 @@ async def create_agent_run_info( agent_config=agent_config, mcp_host=mcp_host, history=converted_history, - stop_event=threading.Event() + stop_event=threading.Event(), + capacity_snapshot=getattr(agent_config, "capacity_snapshot", None), + safe_input_budget_snapshot=getattr( + agent_config, + "safe_input_budget_snapshot", + None, + ), ) return agent_run_info diff --git a/backend/apps/model_managment_app.py b/backend/apps/model_managment_app.py index 53dfebb02..a92937e12 100644 --- a/backend/apps/model_managment_app.py +++ b/backend/apps/model_managment_app.py @@ -16,7 +16,10 @@ from consts.model import ( 
def _capacity_suggestion_response_to_model(result) -> ModelCapacitySuggestionResponse:
    """Convert a capacity-suggestion service result into the API response model.

    Args:
        result: Object returned by the suggestion service. Its
            ``suggestions`` attribute may be None; ``match_kind`` and a
            truthy ``match_confidence`` expose their wire value via ``.value``.

    Returns:
        A ``ModelCapacitySuggestionResponse`` carrying the same fields, with
        ``suggestions`` copied into ``CapacitySuggestionFields`` when present.
    """
    suggestion_field_names = (
        "context_window_tokens",
        "max_input_tokens",
        "max_output_tokens",
        "default_output_reserve_tokens",
        "tokenizer_family",
    )

    source_fields = result.suggestions
    if source_fields is None:
        suggestion_fields = None
    else:
        suggestion_fields = CapacitySuggestionFields(
            **{name: getattr(source_fields, name) for name in suggestion_field_names}
        )

    match_kind = result.match_kind.value
    confidence = result.match_confidence
    match_confidence = confidence.value if confidence else None

    return ModelCapacitySuggestionResponse(
        suggestions=suggestion_fields,
        match_kind=match_kind,
        match_confidence=match_confidence,
        match_explanation=result.match_explanation,
        suggested_provider=result.suggested_provider,
        canonical_model_name=result.canonical_model_name,
        capability_profile_version=result.capability_profile_version,
        capacity_source_on_accept=result.capacity_source_on_accept,
    )
def _capacity_suggestion_for_model_request(request: ModelRequest) -> Optional[dict]:
    """Build a best-effort capacity suggestion for a connectivity-check request.

    The suggestion is optional enrichment: callers attach the result to an
    already-computed connectivity payload, so a failure here must never
    replace that payload with an error.

    Args:
        request: The model configuration being verified. ``model_factory``
            is forwarded as the provider hint.

    Returns:
        The suggestion response as a plain dict, or None when the feature
        flag is off or the suggestion could not be produced.
    """
    if not CAPACITY_SUGGESTION_ENABLED:
        return None

    try:
        suggestion_request = ModelCapacitySuggestionRequest(
            model_name=request.model_name,
            base_url=request.base_url,
            provider_hint=request.model_factory,
            api_key=request.api_key,
            model_type=request.model_type,
        )
        return _suggest_capacity_for_request(suggestion_request).model_dump()
    except ValueError as exc:
        # Expected: the request does not satisfy the suggestion schema
        # (for example an empty model name).
        logger.debug("Capacity suggestion unavailable for connectivity request: %s", exc)
        return None
    except Exception as exc:
        # Unexpected failure inside the suggestion service. Swallowed on
        # purpose so the connectivity result still reaches the client;
        # logged at WARNING so it stays visible to operators.
        logger.warning("Capacity suggestion failed for connectivity request: %s", exc)
        return None
@router.get("/capacity-coverage")
async def get_model_capacity_coverage(authorization: Optional[str] = Header(None)):
    """Return bare-capacity LLM/VLM coverage for the current tenant.

    The payload is wrapped in the shared ``{message, data}`` envelope.

    Args:
        authorization: Raw ``Authorization`` header used to resolve the
            caller's tenant.

    Returns:
        A 200 ``JSONResponse`` whose ``data`` is the JSON-encoded coverage
        result for the tenant.

    Raises:
        HTTPException: Re-raised unchanged when raised while handling the
            request; any other failure is reported as a 500 with the
            original exception attached as the cause.
    """
    try:
        _, tenant_id = get_current_user_id(authorization)
        result = get_capacity_coverage(tenant_id)
        return JSONResponse(status_code=HTTPStatus.OK, content={
            "message": "Successfully retrieved model capacity coverage",
            "data": jsonable_encoder(result),
        })
    except HTTPException:
        raise
    except Exception as e:
        # Use the module logger (not the root logger) and keep the traceback
        # so the failure can be diagnosed from backend logs.
        logger.exception("Failed to get model capacity coverage: %s", e)
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=str(e)
        ) from e
@@ -338,6 +448,11 @@ async def check_temporary_model_health(request: ModelRequest): """ try: result = await verify_model_config_connectivity(request.model_dump()) + result["capacity_suggestion"] = ( + _capacity_suggestion_for_model_request(request) + if result.get("connectivity") is True + else None + ) return JSONResponse(status_code=HTTPStatus.OK, content={ "message": "Successfully verified model connectivity", "data": result diff --git a/backend/consts/capability_profiles.py b/backend/consts/capability_profiles.py new file mode 100644 index 000000000..d6f30f4dd --- /dev/null +++ b/backend/consts/capability_profiles.py @@ -0,0 +1,162 @@ +"""Day-one capability profile catalog for ModelCapacityResolver. + +Source of truth: W1 ADR at +`doc/working/context-management-workstreams/W1_ADR_Capability_Catalog_Storage_and_Fingerprint.md`. + +This module owns the approved catalog data. The SDK resolver +(`sdk/nexent/core/models/capacity_resolver.py`) takes the catalog as a parameter; +it does not import this module directly. Backend services read CATALOG here and +pass it through to the resolver. + +Changes to entries: bump the per-entry `capability_profile_version` integer +suffix AND `CATALOG_REVISION` in one PR. Numerical values must be re-verified +against provider documentation at PR merge time. 
+""" +from __future__ import annotations + +import logging +from typing import Dict + +from nexent.core.models.capacity_resolver import CapabilityProfile, ProfileKey + +logger = logging.getLogger(__name__) + + +CATALOG_REVISION = "2026-06-23.4" + + +CATALOG: Dict[ProfileKey, CapabilityProfile] = { + ("openai", "gpt-4o"): CapabilityProfile( + provider="openai", + model_name="gpt-4o", + capability_profile_version="openai/gpt-4o@1", + window_shape="combined", + context_window_tokens=128_000, + max_output_tokens=16_384, + default_output_reserve_tokens=4_096, + tokenizer_family="o200k_base", + ), + ("openai", "gpt-4.1"): CapabilityProfile( + provider="openai", + model_name="gpt-4.1", + capability_profile_version="openai/gpt-4.1@1", + window_shape="combined", + context_window_tokens=1_000_000, + max_output_tokens=32_768, + default_output_reserve_tokens=8_192, + tokenizer_family="o200k_base", + ), + ("dashscope", "qwen-plus"): CapabilityProfile( + provider="dashscope", + model_name="qwen-plus", + capability_profile_version="dashscope/qwen-plus@1", + window_shape="combined", + context_window_tokens=131_072, + max_output_tokens=16_384, + default_output_reserve_tokens=4_096, + tokenizer_family="qwen", + ), + ("dashscope", "qwen-turbo"): CapabilityProfile( + provider="dashscope", + model_name="qwen-turbo", + capability_profile_version="dashscope/qwen-turbo@1", + window_shape="combined", + context_window_tokens=1_000_000, + max_output_tokens=16_384, + default_output_reserve_tokens=4_096, + tokenizer_family="qwen", + ), + # Sources cross-checked 2026-06-23: + # https://help.aliyun.com/zh/model-studio/models (Bailian model catalog) + # https://llm-stats.com/models/qwen3.7-max (1.0M input, 65.5K output) + ("dashscope", "qwen3.7-max"): CapabilityProfile( + provider="dashscope", + model_name="qwen3.7-max", + capability_profile_version="dashscope/qwen3.7-max@1", + window_shape="combined", + context_window_tokens=1_000_000, + max_output_tokens=65_536, + 
default_output_reserve_tokens=8_192, + tokenizer_family="qwen", + ), + ("dashscope", "glm-5.1"): CapabilityProfile( + provider="dashscope", + model_name="glm-5.1", + capability_profile_version="dashscope/glm-5.1@1", + window_shape="combined", + context_window_tokens=200_000, + max_output_tokens=131_072, + default_output_reserve_tokens=8_192, + tokenizer_family="chatglm", + ), + ("silicon", "Qwen/Qwen3.6-27B"): CapabilityProfile( + provider="silicon", + model_name="Qwen/Qwen3.6-27B", + capability_profile_version="silicon/qwen3.6-27b@1", + window_shape="combined", + context_window_tokens=262_144, + max_output_tokens=65_536, + default_output_reserve_tokens=8_192, + tokenizer_family="qwen", + ), + ("silicon", "Pro/moonshotai/Kimi-K2.6"): CapabilityProfile( + provider="silicon", + model_name="Pro/moonshotai/Kimi-K2.6", + capability_profile_version="silicon/kimi-k2.6@1", + window_shape="combined", + context_window_tokens=262_144, + max_output_tokens=131_072, + default_output_reserve_tokens=8_192, + tokenizer_family="moonshot", + ), + # DeepSeek official platform. Verified 2026-06-23 against + # https://api-docs.deepseek.com/zh-cn/quick_start/pricing + # (context 1M, max output 384K for both v4 models). Re-verify at PR + # merge time per the file header rule. + # + # `deepseek-chat` and `deepseek-reasoner` will be deprecated at + # 2026-07-24 23:59 (Beijing). Per DeepSeek docs they alias to + # `deepseek-v4-flash` non-thinking and thinking modes respectively, + # so their capacity profile mirrors `deepseek-v4-flash`. Remove these + # two entries after the deprecation date. 
+ ("deepseek", "deepseek-chat"): CapabilityProfile( + provider="deepseek", + model_name="deepseek-chat", + capability_profile_version="deepseek/deepseek-chat@2", + window_shape="combined", + context_window_tokens=1_000_000, + max_output_tokens=384_000, + default_output_reserve_tokens=8_192, + tokenizer_family="deepseek", + ), + ("deepseek", "deepseek-reasoner"): CapabilityProfile( + provider="deepseek", + model_name="deepseek-reasoner", + capability_profile_version="deepseek/deepseek-reasoner@2", + window_shape="combined", + context_window_tokens=1_000_000, + max_output_tokens=384_000, + default_output_reserve_tokens=8_192, + tokenizer_family="deepseek", + ), + ("deepseek", "deepseek-v4-flash"): CapabilityProfile( + provider="deepseek", + model_name="deepseek-v4-flash", + capability_profile_version="deepseek/deepseek-v4-flash@1", + window_shape="combined", + context_window_tokens=1_000_000, + max_output_tokens=384_000, + default_output_reserve_tokens=8_192, + tokenizer_family="deepseek", + ), + ("deepseek", "deepseek-v4-pro"): CapabilityProfile( + provider="deepseek", + model_name="deepseek-v4-pro", + capability_profile_version="deepseek/deepseek-v4-pro@1", + window_shape="combined", + context_window_tokens=1_000_000, + max_output_tokens=384_000, + default_output_reserve_tokens=8_192, + tokenizer_family="deepseek", + ), +} diff --git a/backend/consts/const.py b/backend/consts/const.py index 574d550c0..11ca7f70e 100644 --- a/backend/consts/const.py +++ b/backend/consts/const.py @@ -168,6 +168,12 @@ class VectorDatabaseType(str, Enum): # Response flag when system prompts are withheld from non-ASSET_OWNER callers. AGENT_PROMPTS_HIDDEN_FLAG = "prompts_hidden" +# W11 capacity suggestion rollout flags. 
# Spellings that switch a capacity rollout flag on. Values are compared after
# strip() + lower(), so surrounding whitespace or a stray carriage return in
# the environment value does not silently disable the feature.
_CAPACITY_FLAG_TRUE_VALUES = ("true", "1", "yes", "on")

# Enables capacity suggestions; defaults to on when the variable is unset.
CAPACITY_SUGGESTION_ENABLED = os.getenv(
    "CAPACITY_SUGGESTION_ENABLED", "true").strip().lower() in _CAPACITY_FLAG_TRUE_VALUES
# Enables capacity visibility; defaults to on when the variable is unset.
CAPACITY_VISIBILITY_ENABLED = os.getenv(
    "CAPACITY_VISIBILITY_ENABLED", "true").strip().lower() in _CAPACITY_FLAG_TRUE_VALUES
Optional[str] = None + capacity_source_on_accept: Optional[Literal["operator"]] = None + + +class CapacityCoverageBareModel(BaseModel): + model_id: int + model_name: str + model_factory: Optional[str] = None + model_type: Literal["llm", "vlm", "vlm2", "vlm3"] + max_tokens: Optional[int] = None + suggestion_available: bool = False + + +class CapacityCoverageResponse(BaseModel): + total_llm_vlm: int + bare_count: int + bare_models: List[CapacityCoverageBareModel] = Field(default_factory=list) class ProviderModelRequest(BaseModel): @@ -256,6 +306,7 @@ class AgentRequest(BaseModel): minio_files: Optional[List[Dict[str, Any]]] = None agent_id: Optional[int] = None model_id: Optional[int] = None + requested_output_tokens: Optional[int] = Field(default=None, gt=0) version_no: Optional[int] = None is_debug: Optional[bool] = False tool_params: Optional[ToolParamsRequest] = None @@ -492,6 +543,7 @@ class AgentInfoRequest(BaseModel): model_name: Optional[str] = None model_id: Optional[int] = None max_steps: Optional[int] = Field(default=None, ge=1, le=30) + requested_output_tokens: Optional[int] = Field(default=None, gt=0) provide_run_summary: Optional[bool] = None duty_prompt: Optional[str] = None constraint_prompt: Optional[str] = None @@ -591,6 +643,7 @@ class ExportAndImportAgentInfo(BaseModel): business_description: str author: Optional[str] = None max_steps: int + requested_output_tokens: Optional[int] = Field(default=None, gt=0) provide_run_summary: bool verification_config: Optional[Dict[str, Any]] = None duty_prompt: Optional[str] = None diff --git a/backend/database/agent_db.py b/backend/database/agent_db.py index 533659b0f..9bac87381 100644 --- a/backend/database/agent_db.py +++ b/backend/database/agent_db.py @@ -237,6 +237,7 @@ def create_agent(agent_info, tenant_id: str, user_id: str): "group_ids": new_agent.group_ids, "is_new": new_agent.is_new, "enable_context_manager": new_agent.enable_context_manager, + "requested_output_tokens": 
new_agent.requested_output_tokens, "verification_config": new_agent.verification_config, "greeting_message": new_agent.greeting_message, "example_questions": new_agent.example_questions, @@ -273,8 +274,13 @@ def update_agent(agent_id, agent_info, user_id, version_no: int = 0): if not agent: raise ValueError("ag_tenant_agent_t Agent not found") - for key, value in filter_property(agent_info.__dict__, AgentInfo).items(): - if value is None: + agent_data = dict(agent_info.__dict__) + fields_set = getattr(agent_info, "model_fields_set", None) + if fields_set is not None and "requested_output_tokens" not in fields_set: + agent_data.pop("requested_output_tokens", None) + + for key, value in filter_property(agent_data, AgentInfo).items(): + if value is None and key != "requested_output_tokens": continue if key == "group_ids": value = convert_list_to_string(value) diff --git a/backend/database/db_models.py b/backend/database/db_models.py index 5450b5f74..7aa56a6f0 100644 --- a/backend/database/db_models.py +++ b/backend/database/db_models.py @@ -188,6 +188,20 @@ class ModelRecord(TableBase): Integer, doc="Request timeout in seconds for this model. Default is 120 seconds.") concurrency_limit = Column( Integer, doc="Maximum concurrent requests for this model. Default is null (unlimited).") + context_window_tokens = Column( + Integer, doc="Total combined input/output context window in tokens, when the provider uses a combined window. Nullable.") + max_input_tokens = Column( + Integer, doc="Provider hard input-token limit when distinct from the combined window. Nullable.") + max_output_tokens = Column( + Integer, doc="Provider-supported or operator-configured completion-output cap. Replaces the ambiguous LLM meaning of max_tokens. Nullable.") + default_output_reserve_tokens = Column( + Integer, doc="Default output allowance reserved per request before constructing input context. 
Nullable.") + tokenizer_family = Column( + String(100), doc="Token-counting strategy or provider/model tokenizer identifier mapped via tokenizer_registry. Nullable.") + capacity_source = Column( + String(100), doc="Source of the persisted capacity value. Optional values: operator, profile, provider_candidate, legacy, unknown.") + capability_profile_version = Column( + String(100), doc="Version of the approved provider/model capability profile used by the request, e.g. openai/gpt-4o@1.") class ModelMonitoringRecord(SimpleTableBase): @@ -237,6 +251,69 @@ class ModelMonitoringRecord(SimpleTableBase): input_tokens = Column(Integer, doc="Number of input tokens") output_tokens = Column(Integer, doc="Number of output tokens") total_tokens = Column(Integer, doc="Total tokens (input + output)") + context_window_tokens = Column( + Integer, doc="Resolved total combined model context window for this request" + ) + default_output_reserve_tokens = Column( + Integer, doc="Default output allowance reserved before input context construction" + ) + capability_profile_version = Column( + String(100), doc="Version of the resolved capacity profile for this request" + ) + capacity_source = Column( + String(100), doc="Dominant source of resolved capacity fields for this request" + ) + requested_output_tokens = Column( + Integer, doc="Output tokens requested or reserved during capacity resolution" + ) + provider_input_limit_tokens = Column( + Integer, doc="Resolved provider input-token limit used by context management" + ) + tokenizer_family = Column( + String(100), doc="Tokenizer family used for request token counting" + ) + counting_mode = Column( + String(20), doc="Token counting mode for the request: exact or estimated" + ) + unknown_capabilities = Column( + JSONB, doc="Structured list of capacity capabilities unknown at resolution time" + ) + capacity_fingerprint = Column( + String(64), doc="Fingerprint of the resolved model capacity snapshot" + ) + budget_fingerprint = Column( + 
String(64), doc="Fingerprint of the resolved W2 safe input budget snapshot" + ) + budget_w1_fingerprint = Column( + String(64), doc="W1 capacity fingerprint consumed by the W2 budget snapshot" + ) + budget_requested_output_tokens = Column( + Integer, doc="W2 trusted requested output tokens used at dispatch" + ) + budget_output_reserve_source = Column( + String(32), doc="Source of the W2 requested output token reserve" + ) + budget_provider_input_limit_tokens = Column( + Integer, doc="Provider input limit after applying the W2 output reserve" + ) + budget_uncertainty_reserve_tokens = Column( + Integer, doc="Additional W2 uncertainty reserve deducted from input budget" + ) + budget_uncertainty_reserve_basis = Column( + String(64), doc="Basis used for the W2 uncertainty reserve" + ) + budget_soft_limit_ratio = Column( + Float, doc="W2 soft input budget ratio" + ) + budget_soft_input_budget_tokens = Column( + Integer, doc="W2 soft input budget where proactive compression begins" + ) + budget_hard_input_budget_tokens = Column( + Integer, doc="W2 hard input budget consumed by W3 final fit" + ) + budget_warnings = Column( + JSONB, doc="Structured W2 budget warnings active for this request" + ) generation_rate = Column( Float, doc="Token generation rate (tokens per second)") is_streaming = Column( @@ -333,6 +410,13 @@ class AgentInfo(TableBase): current_version_no = Column(Integer, nullable=True, doc="Current published version number. NULL means no version published yet") ingroup_permission = Column(String(30), doc="In-group permission: EDIT, READ_ONLY, PRIVATE") enable_context_manager = Column(Boolean, default=False, doc="Whether to enable context management (compression) for this agent") + requested_output_tokens = Column( + Integer, + doc=( + "Per-agent override for W2 requested_output_tokens. NULL means " + "inherit the resolved model-level default." 
+ ), + ) verification_config = Column(JSONB, doc="Layered ReAct self-verification configuration") greeting_message = Column(Text, doc="Agent greeting message displayed on chat initial screen") example_questions = Column(JSONB, doc="List of example questions for starting a conversation with this agent") diff --git a/backend/services/agent_service.py b/backend/services/agent_service.py index 643d1995e..5ffc8bbcf 100644 --- a/backend/services/agent_service.py +++ b/backend/services/agent_service.py @@ -1109,6 +1109,7 @@ async def get_creating_sub_agent_info_impl(authorization: str = Header(None)): "model_name": agent_info["model_name"], "model_id": agent_info.get("model_id"), "max_steps": agent_info["max_steps"], + "requested_output_tokens": agent_info.get("requested_output_tokens"), "business_description": agent_info["business_description"], "duty_prompt": agent_info.get("duty_prompt"), "constraint_prompt": agent_info.get("constraint_prompt"), @@ -1116,12 +1117,52 @@ async def get_creating_sub_agent_info_impl(authorization: str = Header(None)): "sub_agent_id_list": query_sub_agents_id_list(main_agent_id=sub_agent_id, tenant_id=tenant_id)} +def _validate_requested_output_tokens_for_agent( + request: AgentInfoRequest, + tenant_id: str, +) -> None: + requested_output_tokens = request.requested_output_tokens + if requested_output_tokens is None: + return + + model_id = request.model_id + if model_id is None and request.agent_id is not None: + try: + existing_agent = search_agent_info_by_agent_id( + agent_id=request.agent_id, + tenant_id=tenant_id, + version_no=request.version_no, + ) + model_id = existing_agent.get("model_id") + except Exception as exc: + logger.warning( + "Could not resolve existing agent model for requested_output_tokens validation: %s", + exc, + ) + + if model_id is None: + return + + model_info = get_model_by_model_id(model_id, tenant_id=tenant_id) + max_output_tokens = model_info.get("max_output_tokens") if model_info else None + if 
max_output_tokens is not None and requested_output_tokens > max_output_tokens: + raise AppException( + ErrorCode.COMMON_PARAMETER_INVALID, + ( + "requested_output_tokens cannot exceed the selected model " + f"max_output_tokens ({max_output_tokens})" + ), + ) + + async def update_agent_info_impl(request: AgentInfoRequest, authorization: str = Header(None)): user_id, tenant_id, _ = get_current_user_info(authorization) if request.example_questions is not None and len(request.example_questions) > 6: raise AppException(ErrorCode.COMMON_PARAMETER_INVALID, "example_questions cannot exceed 6 items") + _validate_requested_output_tokens_for_agent(request, tenant_id) + prompt_template_id, prompt_template_name = get_prompt_template_summary( template_id=request.prompt_template_id, tenant_id=tenant_id, @@ -1147,6 +1188,7 @@ async def update_agent_info_impl(request: AgentInfoRequest, authorization: str = "prompt_template_id": prompt_template_id, "prompt_template_name": prompt_template_name, "max_steps": request.max_steps, + "requested_output_tokens": request.requested_output_tokens, "provide_run_summary": request.provide_run_summary, "verification_config": request.verification_config, "duty_prompt": request.duty_prompt, @@ -1673,6 +1715,7 @@ async def export_agent_by_agent_id( business_description=agent_info["business_description"], author=agent_info.get("author"), max_steps=agent_info["max_steps"], + requested_output_tokens=agent_info.get("requested_output_tokens"), provide_run_summary=agent_info["provide_run_summary"], verification_config=agent_info.get("verification_config"), duty_prompt=agent_info.get( @@ -1828,6 +1871,7 @@ async def import_agent_by_agent_id( "prompt_template_id": import_agent_info.prompt_template_id or SYSTEM_PROMPT_TEMPLATE_ID, "prompt_template_name": import_agent_info.prompt_template_name or SYSTEM_PROMPT_TEMPLATE_NAME, "max_steps": import_agent_info.max_steps, + "requested_output_tokens": import_agent_info.requested_output_tokens, "provide_run_summary": 
import_agent_info.provide_run_summary, "verification_config": getattr(import_agent_info, "verification_config", None), "duty_prompt": import_agent_info.duty_prompt, @@ -2197,6 +2241,7 @@ async def prepare_agent_run( is_debug=agent_request.is_debug, override_version_no=agent_request.version_no, override_model_id=agent_request.model_id, + requested_output_tokens=agent_request.requested_output_tokens, tool_params=agent_request.tool_params, ) diff --git a/backend/services/model_capacity_suggestion_service.py b/backend/services/model_capacity_suggestion_service.py new file mode 100644 index 000000000..723f0fd8e --- /dev/null +++ b/backend/services/model_capacity_suggestion_service.py @@ -0,0 +1,292 @@ +import re +from dataclasses import dataclass +from enum import Enum +from typing import Any, Mapping, Optional + +from consts.const import CAPACITY_SUGGESTION_ENABLED + + +ProfileKey = tuple[str, str] +CapabilityProfileLike = Any + + +class CapacitySuggestionMatchKind(str, Enum): + CATALOG_EXACT = "catalog_exact" + CATALOG_FUZZY = "catalog_fuzzy" + PROVIDER_DISCOVERY = "provider_discovery" + NONE = "none" + + +class CapacitySuggestionConfidence(str, Enum): + HIGH = "high" + MEDIUM = "medium" + LOW = "low" + + +@dataclass(frozen=True) +class CapacitySuggestionFields: + context_window_tokens: Optional[int] = None + max_input_tokens: Optional[int] = None + max_output_tokens: Optional[int] = None + default_output_reserve_tokens: Optional[int] = None + tokenizer_family: Optional[str] = None + + +@dataclass(frozen=True) +class CapacitySuggestionResult: + suggestions: Optional[CapacitySuggestionFields] + match_kind: CapacitySuggestionMatchKind + match_confidence: Optional[CapacitySuggestionConfidence] + match_explanation: str + suggested_provider: Optional[str] = None + canonical_model_name: Optional[str] = None + capability_profile_version: Optional[str] = None + capacity_source_on_accept: Optional[str] = None + + +# Substring patterns matched against the lower-cased base_url. 
Order matters: +# `in` returns the first hit, so place more-specific patterns before broader +# ones (e.g. `dashscope` before `aliyuncs`). Patterns mirror frontend +# PROVIDER_HINTS in `frontend/const/modelConfig.ts` so backend provider-by-URL +# detection stays consistent with the icon the user sees in the UI. +HOST_PROVIDER_PATTERNS = ( + ("dashscope", "dashscope"), + ("aliyuncs", "dashscope"), + ("siliconflow", "silicon"), + ("silicon", "silicon"), + ("modelengine", "modelengine"), + ("openai", "openai"), + ("deepseek", "deepseek"), + ("jina", "jina"), + ("tokenpony", "tokenpony"), + ("bytedance", "volcengine"), +) + +SUPPORTED_SUGGESTION_MODEL_TYPES = {"llm", "vlm", "vlm2", "vlm3"} + + +def pick_provider_from_base_url(base_url: Optional[str]) -> Optional[str]: + # Match the entire lower-cased base_url, mirroring the frontend + # detectProviderFromUrl helper. Substring `in` check, first hit wins. + if not base_url: + return None + + lowered = base_url.lower() + for pattern, provider in HOST_PROVIDER_PATTERNS: + if pattern in lowered: + return provider + return None + + +def _normalize_provider(provider: Optional[str]) -> Optional[str]: + if provider is None: + return None + normalized = provider.strip().lower() + if normalized in {"", "openai-api-compatible"}: + return None + if normalized == "siliconflow": + return "silicon" + return normalized + + +def normalize_model_name(model_name: str) -> str: + return re.sub(r"[-_./\s]+", "", model_name.strip().lower()) + + +def _normalize_catalog_exact_name(model_name: str) -> str: + return model_name.strip().lower() + + +def _profile_to_suggestion(profile: CapabilityProfileLike) -> CapacitySuggestionFields: + return CapacitySuggestionFields( + context_window_tokens=profile.context_window_tokens, + max_input_tokens=profile.max_input_tokens, + max_output_tokens=profile.max_output_tokens, + default_output_reserve_tokens=profile.default_output_reserve_tokens, + tokenizer_family=profile.tokenizer_family, + ) + + +def 
_result_from_profile( + provider: str, + model_name: str, + profile: CapabilityProfileLike, + match_kind: CapacitySuggestionMatchKind, +) -> CapacitySuggestionResult: + confidence = ( + CapacitySuggestionConfidence.HIGH + if match_kind == CapacitySuggestionMatchKind.CATALOG_EXACT + else CapacitySuggestionConfidence.MEDIUM + ) + return CapacitySuggestionResult( + suggestions=_profile_to_suggestion(profile), + match_kind=match_kind, + match_confidence=confidence, + match_explanation=f"Matched approved catalog profile {profile.capability_profile_version}", + suggested_provider=provider, + canonical_model_name=model_name, + capability_profile_version=profile.capability_profile_version, + capacity_source_on_accept="operator", + ) + + +def _none_result(explanation: str) -> CapacitySuggestionResult: + return CapacitySuggestionResult( + suggestions=None, + match_kind=CapacitySuggestionMatchKind.NONE, + match_confidence=None, + match_explanation=explanation, + ) + + +def _provider_catalog( + catalog: Mapping[ProfileKey, CapabilityProfileLike], + provider: str, +) -> dict[ProfileKey, CapabilityProfileLike]: + return { + (catalog_provider, catalog_model): profile + for (catalog_provider, catalog_model), profile in catalog.items() + if catalog_provider == provider + } + + +def _unique_final_segment_match( + model_name: str, + catalog: Mapping[ProfileKey, CapabilityProfileLike], + provider: str, +) -> Optional[tuple[ProfileKey, CapabilityProfileLike]]: + requested = normalize_model_name(model_name) + matches: list[tuple[ProfileKey, CapabilityProfileLike]] = [] + for key, profile in _provider_catalog(catalog, provider).items(): + catalog_model = key[1] + final_segment = catalog_model.split("/")[-1] + if normalize_model_name(final_segment) == requested: + matches.append((key, profile)) + + if len(matches) == 1: + return matches[0] + return None + + +def _fuzzy_catalog_match( + model_name: str, + catalog: Mapping[ProfileKey, CapabilityProfileLike], + provider: str, +) -> 
Optional[tuple[ProfileKey, CapabilityProfileLike]]: + requested = normalize_model_name(model_name) + matches: list[tuple[ProfileKey, CapabilityProfileLike]] = [] + for key, profile in _provider_catalog(catalog, provider).items(): + if normalize_model_name(key[1]) == requested: + matches.append((key, profile)) + + if len(matches) == 1: + return matches[0] + + return _unique_final_segment_match(model_name, catalog, provider) + + +def _unique_catalog_provider_for_model( + model_name: str, + catalog: Mapping[ProfileKey, CapabilityProfileLike], +) -> Optional[str]: + requested = normalize_model_name(model_name) + providers = { + provider + for provider, catalog_model in catalog.keys() + if normalize_model_name(catalog_model) == requested + or normalize_model_name(catalog_model.split("/")[-1]) == requested + } + if len(providers) == 1: + return next(iter(providers)) + return None + + +def pick_provider( + provider_hint: Optional[str], + base_url: Optional[str], + model_name: str, + catalog: Optional[Mapping[ProfileKey, CapabilityProfileLike]] = None, +) -> Optional[str]: + active_catalog = catalog if catalog is not None else _get_default_catalog() + explicit_provider = _normalize_provider(provider_hint) + if explicit_provider: + return explicit_provider + + inferred_provider = pick_provider_from_base_url(base_url) + if inferred_provider: + return inferred_provider + + return _unique_catalog_provider_for_model(model_name, active_catalog) + + +def _get_default_catalog() -> Mapping[ProfileKey, CapabilityProfileLike]: + from consts.capability_profiles import CATALOG + + return CATALOG + + +def suggest_capacity( + model_name: str, + base_url: Optional[str] = None, + provider_hint: Optional[str] = None, + model_type: Optional[str] = None, + api_key: Optional[str] = None, + catalog: Optional[Mapping[ProfileKey, CapabilityProfileLike]] = None, + enabled: bool = CAPACITY_SUGGESTION_ENABLED, +) -> CapacitySuggestionResult: + del api_key + + if not enabled: + return 
_none_result("Capacity suggestion is disabled") + + clean_model_name = (model_name or "").strip() + if not clean_model_name: + raise ValueError("model_name is required") + + if len(clean_model_name) > 512: + raise ValueError("model_name is too long") + + if model_type and model_type.lower() not in SUPPORTED_SUGGESTION_MODEL_TYPES: + return _none_result(f"Capacity suggestion is not supported for model_type={model_type}") + + active_catalog = catalog if catalog is not None else _get_default_catalog() + + provider = pick_provider(provider_hint, base_url, clean_model_name, active_catalog) + if not provider: + return _none_result("No provider candidate could be inferred") + + exact_key = (provider, clean_model_name) + exact_profile = active_catalog.get(exact_key) + if exact_profile: + return _result_from_profile( + provider, + clean_model_name, + exact_profile, + CapacitySuggestionMatchKind.CATALOG_EXACT, + ) + + normalized_exact_key = None + for catalog_key in _provider_catalog(active_catalog, provider).keys(): + if _normalize_catalog_exact_name(catalog_key[1]) == _normalize_catalog_exact_name(clean_model_name): + normalized_exact_key = catalog_key + break + + if normalized_exact_key: + return _result_from_profile( + normalized_exact_key[0], + normalized_exact_key[1], + active_catalog[normalized_exact_key], + CapacitySuggestionMatchKind.CATALOG_EXACT, + ) + + fuzzy_match = _fuzzy_catalog_match(clean_model_name, active_catalog, provider) + if fuzzy_match: + fuzzy_key, profile = fuzzy_match + return _result_from_profile( + fuzzy_key[0], + fuzzy_key[1], + profile, + CapacitySuggestionMatchKind.CATALOG_FUZZY, + ) + + return _none_result(f"No approved catalog profile matched provider={provider}, model={clean_model_name}") diff --git a/backend/services/model_health_service.py b/backend/services/model_health_service.py index 2dc276aeb..35fff2a23 100644 --- a/backend/services/model_health_service.py +++ b/backend/services/model_health_service.py @@ -38,13 +38,17 @@ def 
_normalize_embedding_url(base_url: str) -> str: def _infer_model_factory(model_type: str, base_url: str, current_factory: Optional[str] = None) -> Optional[str]: """Infer model_factory from base_url if not already set or is generic. - Currently handles: - - multi_embedding with dashscope URL -> "dashscope" - - embedding with dashscope URL -> "dashscope" (uses OpenAI-compatible endpoint) + Uses the shared W11 host map so embedding and LLM/VLM inference do not drift. """ - base_url_lower = base_url.lower() - if "dashscope" in base_url_lower: - return DASHSCOPE_MODEL_FACTORY + try: + from services.model_capacity_suggestion_service import pick_provider_from_base_url + + inferred_provider = pick_provider_from_base_url(base_url) + except Exception: + inferred_provider = DASHSCOPE_MODEL_FACTORY if "dashscope" in base_url.lower() else None + + if inferred_provider: + return inferred_provider return current_factory diff --git a/backend/services/model_management_service.py b/backend/services/model_management_service.py index 1511a9301..d4d18a818 100644 --- a/backend/services/model_management_service.py +++ b/backend/services/model_management_service.py @@ -1,7 +1,13 @@ import logging from typing import List, Dict, Any, Optional -from consts.const import LOCALHOST_IP, LOCALHOST_NAME, DOCKER_INTERNAL_HOST +from consts.const import ( + CAPACITY_SUGGESTION_ENABLED, + CAPACITY_VISIBILITY_ENABLED, + LOCALHOST_IP, + LOCALHOST_NAME, + DOCKER_INTERNAL_HOST, +) from consts.model import ModelConnectStatusEnum from consts.provider import ( ProviderEnum, @@ -26,6 +32,7 @@ get_provider_models, ) from services.model_health_service import embedding_dimension_check, _infer_model_factory +from services.model_capacity_suggestion_service import CapacitySuggestionMatchKind, suggest_capacity from utils.model_name_utils import ( add_repo_to_name, split_repo_name, @@ -38,6 +45,49 @@ logger = logging.getLogger("model_management_service") INDEPENDENT_MULTIMODAL_MODEL_TYPES = {"vlm", "vlm2", "vlm3"} 
+CAPACITY_COVERAGE_MODEL_TYPES = {"llm", "vlm", "vlm2", "vlm3"} + + +# OpenTelemetry counter for silent catalog-matcher failures during the +# capacity-coverage scan. The matcher is called per row so we cannot raise -- +# but the silent fallback to suggestion_available=False would hide a corrupt +# catalog entry that turns every "available" hint into "false" across a whole +# tenant. The counter gives staging/CI a single number to watch. +# +# Guarded the same way as the SDK monitor module: if OpenTelemetry is not +# installed (some deployments run without it), the counter is None and the +# increment becomes a no-op. +try: + from opentelemetry import metrics as _otel_metrics + + _capacity_suggestion_meter = _otel_metrics.get_meter(__name__) + _capacity_suggestion_coverage_errors_total = _capacity_suggestion_meter.create_counter( + name="model_capacity_suggestion_coverage_errors_total", + description=( + "Count of catalog-matcher exceptions raised while computing the " + "per-row `suggestion_available` flag in /model/capacity-coverage. " + "Non-zero means catalog data or matcher logic is broken; " + "operators see every row as suggestion_available=False." 
+ ), + unit="errors", + ) +except Exception: # pragma: no cover - OTel is optional at runtime + _capacity_suggestion_coverage_errors_total = None + + +def _record_capacity_coverage_error(model_id: Optional[Any], exc: Exception) -> None: + if _capacity_suggestion_coverage_errors_total is None: + return + try: + _capacity_suggestion_coverage_errors_total.add( + 1, + { + "model_id": str(model_id) if model_id is not None else "unknown", + "error_type": type(exc).__name__, + }, + ) + except Exception: # pragma: no cover - never break coverage for telemetry + pass def _has_display_name_conflict(existing_models: List[Dict[str, Any]], model_type: Optional[str]) -> bool: @@ -55,6 +105,92 @@ def _has_display_name_conflict(existing_models: List[Dict[str, Any]], model_type return True +def _coerce_legacy_max_tokens_alias(model_data: Dict[str, Any]) -> None: + """Keep the deprecated `max_tokens` column in lockstep with `max_output_tokens`. + + W1 step 7 deprecates `max_tokens` as the LLM/VLM output-cap alias of + `max_output_tokens`. Legacy clients that still write `max_tokens` + independently let the two columns diverge in the DB; that divergence + later surfaces at the W2 dispatch boundary as + `CallerMaxTokensOverrideForbidden` because the SDK auto-fills + `max_tokens` from the model record while the W2 snapshot computes its + output cap from `max_output_tokens`. + + Defense in depth at the service layer: when a caller sends a non-None + `max_output_tokens`, force `max_tokens` to mirror it. Embedding rows are + exempt because they repurpose `max_tokens` as the vector dimension. 
+ """ + max_output = model_data.get("max_output_tokens") + if max_output is None: + return + if model_data.get("model_type") in ("embedding", "multi_embedding"): + return + model_data["max_tokens"] = max_output + + +def _is_bare_capacity_model(model: Dict[str, Any]) -> bool: + return model.get("context_window_tokens") is None or model.get("max_output_tokens") is None + + +def _capacity_suggestion_available(model: Dict[str, Any]) -> bool: + if not CAPACITY_SUGGESTION_ENABLED: + return False + + try: + model_name = add_repo_to_name(model.get("model_repo", ""), model.get("model_name", "")) + result = suggest_capacity( + model_name=model_name, + base_url=model.get("base_url"), + provider_hint=model.get("model_factory"), + model_type=model.get("model_type"), + enabled=CAPACITY_SUGGESTION_ENABLED, + ) + return result.match_kind != CapacitySuggestionMatchKind.NONE + except Exception as exc: + # A catalog-matcher exception must not break /capacity-coverage -- + # the endpoint scans every LLM/VLM row, and one bad row would make + # the whole tenant view explode. We fall back to False and emit a + # counter so a corrupt catalog is visible in metrics instead of + # silently turning every row into "no suggestion available". 
+ logger.debug("Capacity coverage suggestion check failed for model_id=%s: %s", model.get("model_id"), exc) + _record_capacity_coverage_error(model.get("model_id"), exc) + return False + + +def get_capacity_coverage(tenant_id: str) -> Dict[str, Any]: + """Return bare-capacity LLM/VLM coverage for one tenant.""" + if not CAPACITY_VISIBILITY_ENABLED: + return { + "total_llm_vlm": 0, + "bare_count": 0, + "bare_models": [], + } + + records = get_model_records(None, tenant_id) + scoped_records = [ + model for model in records + if model.get("model_type") in CAPACITY_COVERAGE_MODEL_TYPES + ] + bare_models = [ + { + "model_id": model["model_id"], + "model_name": add_repo_to_name(model.get("model_repo", ""), model.get("model_name", "")), + "model_factory": model.get("model_factory"), + "model_type": model.get("model_type"), + "max_tokens": model.get("max_tokens"), + "suggestion_available": _capacity_suggestion_available(model), + } + for model in scoped_records + if _is_bare_capacity_model(model) + ] + + return { + "total_llm_vlm": len(scoped_records), + "bare_count": len(bare_models), + "bare_models": bare_models, + } + + async def create_model_for_tenant(user_id: str, tenant_id: str, model_data: Dict[str, Any]): """Create a single model record for the given tenant. @@ -93,6 +229,8 @@ async def create_model_for_tenant(user_id: str, tenant_id: str, model_data: Dict model_name=model_data.get("model_name", "") ) + _coerce_legacy_max_tokens_alias(model_data) + # Use NOT_DETECTED status as default model_data["connect_status"] = model_data.get( "connect_status") or ModelConnectStatusEnum.NOT_DETECTED.value @@ -208,9 +346,24 @@ async def batch_create_models_for_tenant(user_id: str, tenant_id: str, batch_pay for model in existing_model_list } - # Delete existing models not present + # Delete existing models not present. 
+ # The membership key MUST match how existing_model_map (a few lines + # above) and the create-or-update branch (a few lines below) build + # their lookup key, otherwise the two halves disagree about what + # "the same model" means. Both of those use add_repo_to_name, which + # omits the slash when model_repo is empty. The naive + # `model_repo + "/" + model_name` here always prepends "/" for the + # empty-repo case (DashScope catalogs return bare names like + # "glm-4.7" and rows land with model_repo=""), so "/glm-4.7" never + # matched the catalog's "glm-4.7" entry -- every existing row was + # treated as "not in the incoming list" and silently soft-deleted on + # every batch_create. Use the same helper to keep both halves + # speaking the same language. for model in existing_model_list: - model_full_name = model["model_repo"] + "/" + model["model_name"] + model_full_name = add_repo_to_name( + model_repo=model["model_repo"], + model_name=model["model_name"], + ) if model_full_name not in model_list_ids: delete_model_record(model["model_id"], user_id, tenant_id) @@ -231,6 +384,31 @@ async def batch_create_models_for_tenant(user_id: str, tenant_id: str, batch_pay new_max_tokens = model.get("max_tokens") if new_max_tokens is not None and existing_max_tokens != new_max_tokens: update_data["max_tokens"] = new_max_tokens + # Same gap as prepare_model_dict had for the create branch: + # the batch refresh path only touched legacy max_tokens, so + # editing a row's capacity via batch-add (e.g. tweaking the + # top-level batch defaults and re-confirming) silently + # dropped the W1/W2 capacity updates. We mirror the + # operator-vs-candidate rule from prepare_model_dict here: + # only persist W1/W2 capacity when the payload is marked + # capacity_source="operator", so provider-discovered hints + # don't auto-overwrite an existing row on a refresh. 
+ if model.get("capacity_source") == "operator": + for field in ( + "context_window_tokens", + "max_input_tokens", + "max_output_tokens", + "default_output_reserve_tokens", + "tokenizer_family", + "capability_profile_version", + ): + new_value = model.get(field) + if new_value is None: + continue + if existing_model.get(field) != new_value: + update_data[field] = new_value + if existing_model.get("capacity_source") != "operator": + update_data["capacity_source"] = "operator" if update_data: update_model_record(existing_model["model_id"], update_data, user_id) continue @@ -315,6 +493,16 @@ async def update_single_model_for_tenant( else: model_data["ssl_verify"] = True + # Carry model_type from the existing record so the legacy-alias + # coercion can distinguish LLM/VLM updates from embedding updates + # even when the caller payload omits model_type. We don't store the + # injected model_type back on model_data because the update path + # explicitly strips it later. + existing_model_type = existing_models[0].get("model_type") if existing_models else None + if model_data.get("max_output_tokens") is not None and \ + existing_model_type not in ("embedding", "multi_embedding"): + model_data["max_tokens"] = model_data["max_output_tokens"] + if has_multi_embedding: # Update both embedding and multi_embedding records for model in existing_models: @@ -343,6 +531,7 @@ async def batch_update_models_for_tenant(user_id: str, tenant_id: str, model_lis """Batch update models for a tenant by model_id or model_name.""" try: for model in model_list: + _coerce_legacy_max_tokens_alias(model) # Build update data excluding id fields update_data = {k: v for k, v in model.items() if k not in ["model_id", "model_name"]} @@ -571,4 +760,3 @@ async def list_models_for_admin( except Exception as e: logging.error(f"Failed to retrieve admin model list: {str(e)}") raise Exception(f"Failed to retrieve admin model list: {str(e)}") - diff --git a/backend/services/model_provider_service.py 
b/backend/services/model_provider_service.py index 1aa89fa3b..31867bedc 100644 --- a/backend/services/model_provider_service.py +++ b/backend/services/model_provider_service.py @@ -108,6 +108,35 @@ async def prepare_model_dict(provider: str, model: dict, model_url: str, model_a "max_tokens", 0) if not is_embedding_type else 0 timeout_seconds_value = 120 if not is_embedding_type else None + # W1/W2 capacity fields. The frontend batch-add resolves these in + # buildBatchModelData (row override -> top-level batch default) and + # sends them per row tagged with capacity_source. Two cases: + # - capacity_source="operator": the operator explicitly saved these + # values (top-level batch default panel or per-row gear modal). + # Persist them. Without this branch the ModelRequest defaults kick + # in (all None) and every freshly batch-created row lands with + # context_window_tokens=NULL, max_output_tokens=NULL even though + # the user filled the panel -- the glm-5.1/glm-5.2 incident. + # - capacity_source="provider_candidate" (or anything else): per the + # W1 design these are advisory UI hints surfaced from the catalog + # by _extract_capacity_hints. They are shown to the user as + # suggestions but not auto-persisted; only operator acceptance + # should write them. 
+ is_operator_capacity = model.get("capacity_source") == "operator" + capacity_kwargs = ( + { + "context_window_tokens": model.get("context_window_tokens"), + "max_input_tokens": model.get("max_input_tokens"), + "max_output_tokens": model.get("max_output_tokens"), + "default_output_reserve_tokens": model.get("default_output_reserve_tokens"), + "tokenizer_family": model.get("tokenizer_family"), + "capacity_source": "operator", + "capability_profile_version": model.get("capability_profile_version"), + } + if is_operator_capacity + else {} + ) + model_obj = ModelRequest( model_factory=provider, model_name=model_name, @@ -118,7 +147,8 @@ async def prepare_model_dict(provider: str, model: dict, model_url: str, model_a expected_chunk_size=expected_chunk_size, maximum_chunk_size=maximum_chunk_size, chunk_batch=chunk_batch, - timeout_seconds=timeout_seconds_value + timeout_seconds=timeout_seconds_value, + **capacity_kwargs, ) model_dict = model_obj.model_dump() @@ -194,11 +224,20 @@ def merge_existing_model_attributes( if not model_list or not existing_model_list: return model_list - # Create a mapping table for existing models for quick lookup + # Create a mapping table for existing models for quick lookup. + # Use add_repo_to_name so the lookup key matches the format used by + # provider responses and downstream consumers. Naive `model_repo + "/" + + # model_name` prepends a leading slash when model_repo is empty + # (DashScope-style bare names like "glm-4.7" land with model_repo=""), + # so "/glm-4.7" never matches the catalog's "glm-4.7" entry and the + # merge silently no-ops -- the same wire-key bug fixed in + # batch_create_models_for_tenant's delete loop. 
existing_model_map = {} for existing_model in existing_model_list: - model_full_name = existing_model["model_repo"] + \ - "/" + existing_model["model_name"] + model_full_name = add_repo_to_name( + model_repo=existing_model["model_repo"], + model_name=existing_model["model_name"], + ) existing_model_map[model_full_name] = existing_model # Iterate through the model list, merge specified fields from existing models diff --git a/backend/services/providers/base.py b/backend/services/providers/base.py index 4756bf6ad..0b0576765 100644 --- a/backend/services/providers/base.py +++ b/backend/services/providers/base.py @@ -1,12 +1,95 @@ import logging from abc import ABC, abstractmethod -from typing import Dict, List +from typing import Any, Dict, Iterable, List import aiohttp logger = logging.getLogger("model_provider") +_CONTEXT_WINDOW_KEYS = ( + "context_window_tokens", + "context_window", + "context_length", + "max_context_length", + "max_context_tokens", + "max_sequence_length", +) +_MAX_INPUT_KEYS = ("max_input_tokens", "input_token_limit", "max_prompt_tokens") +_MAX_OUTPUT_KEYS = ( + "max_output_tokens", + "output_token_limit", + "max_completion_tokens", + "max_tokens", +) +_OUTPUT_RESERVE_KEYS = ( + "default_output_reserve_tokens", + "default_output_reserve", + "output_reserve_tokens", +) +_TOKENIZER_KEYS = ("tokenizer_family", "tokenizer", "tokenizer_type") + + +def _positive_int(value: Any) -> int | None: + if isinstance(value, bool) or value is None: + return None + try: + parsed = int(value) + except (TypeError, ValueError): + return None + return parsed if parsed > 0 else None + + +def _candidate_dicts(raw: Dict, nested_keys: Iterable[str]) -> List[Dict]: + candidates = [raw] + for key in nested_keys: + value = raw.get(key) + if isinstance(value, dict): + candidates.append(value) + return candidates + + +def _first_positive_int(candidates: List[Dict], keys: tuple[str, ...]) -> int | None: + for candidate in candidates: + for key in keys: + value = 
_positive_int(candidate.get(key)) + if value is not None: + return value + return None + + +def _first_non_empty_str(candidates: List[Dict], keys: tuple[str, ...]) -> str | None: + for candidate in candidates: + for key in keys: + value = candidate.get(key) + if isinstance(value, str) and value.strip(): + return value.strip() + return None + + +def _extract_capacity_hints_from_raw(raw: Dict, nested_keys: Iterable[str] = ()) -> Dict: + """Extract advisory provider-discovery capacity hints from one raw model row.""" + candidates = _candidate_dicts(raw, nested_keys) + hints = {} + for target_key, source_keys in ( + ("context_window_tokens", _CONTEXT_WINDOW_KEYS), + ("max_input_tokens", _MAX_INPUT_KEYS), + ("max_output_tokens", _MAX_OUTPUT_KEYS), + ("default_output_reserve_tokens", _OUTPUT_RESERVE_KEYS), + ): + value = _first_positive_int(candidates, source_keys) + if value is not None: + hints[target_key] = value + + tokenizer_family = _first_non_empty_str(candidates, _TOKENIZER_KEYS) + if tokenizer_family: + hints["tokenizer_family"] = tokenizer_family + + if hints: + hints["capacity_source"] = "provider_candidate" + return hints + + # ============================================================================= # Provider Error Handling Utilities # ============================================================================= diff --git a/backend/services/providers/dashscope_provider.py b/backend/services/providers/dashscope_provider.py index 497dcfe99..f78c57a3f 100644 --- a/backend/services/providers/dashscope_provider.py +++ b/backend/services/providers/dashscope_provider.py @@ -3,7 +3,11 @@ import asyncio from consts.const import DEFAULT_LLM_MAX_TOKENS from consts.provider import DASHSCOPE_GET_URL -from services.providers.base import AbstractModelProvider, _classify_provider_error +from services.providers.base import ( + AbstractModelProvider, + _classify_provider_error, + _extract_capacity_hints_from_raw, +) DASHSCOPE_IMAGE_GENERATION_KEYWORDS = ( @@ -33,6 
+37,10 @@ DASHSCOPE_VIDEO_UNDERSTANDING_KEYWORDS = ("omni", "video-understanding", "video-ocr") +def _extract_capacity_hints(raw: Dict) -> Dict: + return _extract_capacity_hints_from_raw(raw, nested_keys=("inference_metadata",)) + + def _modality_set(value) -> set: if not value: return set() @@ -155,6 +163,7 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: "model_type": "", "max_tokens": DEFAULT_LLM_MAX_TOKENS } + cleaned_model.update(_extract_capacity_hints(model_obj)) # 1. Embedding if 'embedding' in m_id.lower() or '向量' in desc: cleaned_model.update({"model_tag": "embedding", "model_type": "embedding"}) @@ -214,4 +223,3 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: return [] except (httpx.HTTPStatusError, httpx.ConnectTimeout, httpx.ConnectError, Exception) as e: return _classify_provider_error("DashScope", exception=e) - diff --git a/backend/services/providers/modelengine_provider.py b/backend/services/providers/modelengine_provider.py index 276f84378..5b0e2b555 100644 --- a/backend/services/providers/modelengine_provider.py +++ b/backend/services/providers/modelengine_provider.py @@ -4,13 +4,21 @@ import aiohttp from consts.const import DEFAULT_LLM_MAX_TOKENS -from services.providers.base import AbstractModelProvider, _classify_provider_error +from services.providers.base import ( + AbstractModelProvider, + _classify_provider_error, + _extract_capacity_hints_from_raw, +) logger = logging.getLogger("model_provider") MODEL_ENGINE_NORTH_PREFIX = "open/router/v1" +def _extract_capacity_hints(raw: Dict) -> Dict: + return _extract_capacity_hints_from_raw(raw) + + def get_model_engine_raw_url(model_engine_url: str) -> str: """ Extract the raw base URL from a ModelEngine URL by stripping any API paths. 
@@ -96,14 +104,16 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: continue if internal_type: - filtered_models.append({ + cleaned_model = { "id": model.get("id", ""), "model_type": internal_type, "model_tag": me_type, "max_tokens": DEFAULT_LLM_MAX_TOKENS if internal_type in ("llm", "vlm") else 0, "base_url": host, "api_key": api_key, - }) + } + cleaned_model.update(_extract_capacity_hints(model)) + filtered_models.append(cleaned_model) return filtered_models except Exception as e: diff --git a/backend/services/providers/silicon_provider.py b/backend/services/providers/silicon_provider.py index 1875b3949..e078f83a7 100644 --- a/backend/services/providers/silicon_provider.py +++ b/backend/services/providers/silicon_provider.py @@ -4,7 +4,11 @@ from consts.const import DEFAULT_LLM_MAX_TOKENS from consts.provider import SILICON_GET_URL -from services.providers.base import AbstractModelProvider, _classify_provider_error +from services.providers.base import ( + AbstractModelProvider, + _classify_provider_error, + _extract_capacity_hints_from_raw, +) SILICON_VLM_MODEL_KEYWORDS = ( @@ -33,6 +37,10 @@ SILICON_VLM_METADATA_KEYWORDS = ("image", "video", "vision", "visual") +def _extract_capacity_hints(raw: Dict) -> Dict: + return _extract_capacity_hints_from_raw(raw) + + def _contains_silicon_vlm_metadata(value) -> bool: if isinstance(value, str): lower_value = value.lower() @@ -107,6 +115,7 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: # Annotate models with canonical fields expected downstream if provider_model_type in ("llm", "vlm"): for item in model_list: + item.update(_extract_capacity_hints(item)) item["model_tag"] = "chat" item["model_type"] = model_type item["max_tokens"] = DEFAULT_LLM_MAX_TOKENS diff --git a/backend/services/providers/tokenpony_provider.py b/backend/services/providers/tokenpony_provider.py index be2bb9c71..16adf0008 100644 --- a/backend/services/providers/tokenpony_provider.py +++ 
b/backend/services/providers/tokenpony_provider.py @@ -6,7 +6,11 @@ from consts.const import DEFAULT_LLM_MAX_TOKENS from consts.provider import TOKENPONY_GET_URL -from services.providers.base import AbstractModelProvider, _classify_provider_error +from services.providers.base import ( + AbstractModelProvider, + _classify_provider_error, + _extract_capacity_hints_from_raw, +) TOKENPONY_IMAGE_UNDERSTANDING_KEYWORDS = ( @@ -41,6 +45,10 @@ TOKENPONY_VIDEO_UNDERSTANDING_KEYWORDS = ("omni", "video") +def _extract_capacity_hints(raw: Dict) -> Dict: + return _extract_capacity_hints_from_raw(raw) + + def _has_keyword(text: str, keywords: tuple) -> bool: return any(keyword in text for keyword in keywords) @@ -126,6 +134,7 @@ async def get_models(self, provider_config: Dict) -> List[Dict]: "model_type": "", "max_tokens": DEFAULT_LLM_MAX_TOKENS } + cleaned_model.update(_extract_capacity_hints(model_obj)) # 1. rerank if 'rerank' in m_id: cleaned_model.update({"model_tag": "rerank", "model_type": "rerank"}) diff --git a/backend/utils/config_utils.py b/backend/utils/config_utils.py index 3fe6f3621..2d1c5572b 100644 --- a/backend/utils/config_utils.py +++ b/backend/utils/config_utils.py @@ -2,6 +2,7 @@ import logging from typing import Dict, Any +from pydantic import ValidationError from sqlalchemy.sql import func from database.model_management_db import get_model_by_model_id @@ -16,6 +17,9 @@ logger = logging.getLogger("config_utils") +CONTEXT_SOFT_LIMIT_RATIO_KEY = "context.soft_limit_ratio" + + def safe_value(value): """Helper function for processing configuration values""" if value is None: @@ -112,6 +116,39 @@ def get_app_config(self, key: str, default="", tenant_id: str | None = None): return tenant_config[key] return default + def get_capacity_reserve_policy(self, tenant_id: str | None = None): + """Resolve W2 reserve policy from tenant config. + + Missing `context.soft_limit_ratio` uses the code default. 
Invalid + configured values fail closed so production requests do not silently use + a different compaction envelope than operators configured. + """ + from nexent.core.models.capacity_budget import ( + CapacityReservePolicy, + InvalidReservePolicy, + ) + + if tenant_id is None: + logger.warning("No tenant_id specified when getting capacity reserve policy") + return CapacityReservePolicy() + + tenant_config = self.load_config(tenant_id) + raw_ratio = tenant_config.get(CONTEXT_SOFT_LIMIT_RATIO_KEY) + if raw_ratio in (None, ""): + return CapacityReservePolicy() + + try: + ratio = float(str(raw_ratio).strip()) + return CapacityReservePolicy( + soft_limit_ratio=ratio, + soft_limit_ratio_source="tenant_config", + ) + except (TypeError, ValueError, ValidationError) as exc: + raise InvalidReservePolicy( + f"{CONTEXT_SOFT_LIMIT_RATIO_KEY} must be a decimal in (0, 1], " + f"got {raw_ratio!r}" + ) from exc + def set_single_config(self, user_id: str | None = None, tenant_id: str | None = None, key: str | None = None, value: str | None = None, ): """Set configuration value in database with caching""" diff --git a/docker/init.sql b/docker/init.sql index 5b0ff025b..ea89e5d10 100644 --- a/docker/init.sql +++ b/docker/init.sql @@ -179,6 +179,13 @@ CREATE TABLE IF NOT EXISTS "model_record_t" ( "access_token" varchar(100) COLLATE "pg_catalog"."default" DEFAULT '', "concurrency_limit" INTEGER DEFAULT NULL, "timeout_seconds" INTEGER DEFAULT 120, + "context_window_tokens" INTEGER DEFAULT NULL, + "max_input_tokens" INTEGER DEFAULT NULL, + "max_output_tokens" INTEGER DEFAULT NULL, + "default_output_reserve_tokens" INTEGER DEFAULT NULL, + "tokenizer_family" varchar(100) COLLATE "pg_catalog"."default" DEFAULT NULL, + "capacity_source" varchar(100) COLLATE "pg_catalog"."default" DEFAULT NULL, + "capability_profile_version" varchar(100) COLLATE "pg_catalog"."default" DEFAULT NULL, CONSTRAINT "nexent_models_t_pk" PRIMARY KEY ("model_id") ); ALTER TABLE "model_record_t" OWNER TO "root"; @@ 
-206,6 +213,13 @@ COMMENT ON COLUMN "model_record_t"."model_appid" IS 'Application ID for model au COMMENT ON COLUMN "model_record_t"."access_token" IS 'Access token for model authentication.'; COMMENT ON COLUMN "model_record_t"."concurrency_limit" IS 'Maximum concurrent requests for this model. Default is NULL (unlimited).'; COMMENT ON COLUMN "model_record_t"."timeout_seconds" IS 'Request timeout in seconds for this model. Default is 120 seconds.'; +COMMENT ON COLUMN "model_record_t"."context_window_tokens" IS 'Total combined input/output context window in tokens, when the provider uses a combined window. Nullable.'; +COMMENT ON COLUMN "model_record_t"."max_input_tokens" IS 'Provider hard input-token limit when distinct from the combined window. Nullable.'; +COMMENT ON COLUMN "model_record_t"."max_output_tokens" IS 'Provider-supported or operator-configured completion-output cap. Replaces the ambiguous LLM meaning of max_tokens. Nullable.'; +COMMENT ON COLUMN "model_record_t"."default_output_reserve_tokens" IS 'Default output allowance reserved per request before constructing input context. Nullable.'; +COMMENT ON COLUMN "model_record_t"."tokenizer_family" IS 'Token-counting strategy or provider/model tokenizer identifier mapped via tokenizer_registry. Nullable.'; +COMMENT ON COLUMN "model_record_t"."capacity_source" IS 'Source of the persisted capacity value. Optional values: operator, profile, provider_candidate, legacy, unknown.'; +COMMENT ON COLUMN "model_record_t"."capability_profile_version" IS 'Version of the approved provider/model capability profile used by the request, e.g. 
openai/gpt-4o@1.'; COMMENT ON TABLE "model_record_t" IS 'List of models defined by users in the configuration page'; INSERT INTO "nexent"."model_record_t" ("model_repo", "model_name", "model_factory", "model_type", "api_key", "base_url", "max_tokens", "used_token", "display_name", "connect_status") VALUES ('', 'volcano_tts', 'OpenAI-API-Compatible', 'tts', '', '', 0, 0, 'volcano_tts', 'unavailable'); @@ -339,6 +353,7 @@ CREATE TABLE IF NOT EXISTS nexent.ag_tenant_agent_t ( is_new BOOLEAN DEFAULT FALSE, provide_run_summary BOOLEAN DEFAULT FALSE, enable_context_manager BOOLEAN DEFAULT FALSE, + requested_output_tokens INTEGER NULL, verification_config JSONB, version_no INTEGER DEFAULT 0 NOT NULL, current_version_no INTEGER NULL, @@ -402,6 +417,7 @@ COMMENT ON COLUMN nexent.ag_tenant_agent_t.version_no IS 'Version number. 0 = dr COMMENT ON COLUMN nexent.ag_tenant_agent_t.current_version_no IS 'Current published version number. NULL means no version published yet'; COMMENT ON COLUMN nexent.ag_tenant_agent_t.ingroup_permission IS 'In-group permission: EDIT, READ_ONLY, PRIVATE'; COMMENT ON COLUMN nexent.ag_tenant_agent_t.enable_context_manager IS 'Whether to enable context management (compression) for this agent'; +COMMENT ON COLUMN nexent.ag_tenant_agent_t.requested_output_tokens IS 'Per-agent override for W2 requested_output_tokens. NULL means inherit the resolved model-level default. 
Must satisfy 0 < value <= max_output_tokens from the resolved W1 capacity at save time.'; COMMENT ON COLUMN nexent.ag_tenant_agent_t.verification_config IS 'Layered ReAct self-verification configuration'; COMMENT ON COLUMN nexent.ag_tenant_agent_t.greeting_message IS 'Agent greeting message displayed on chat initial screen'; COMMENT ON COLUMN nexent.ag_tenant_agent_t.example_questions IS 'List of example questions for starting a conversation with this agent'; @@ -1762,6 +1778,27 @@ CREATE TABLE IF NOT EXISTS nexent.model_monitoring_record_t ( input_tokens INT4, output_tokens INT4, total_tokens INT4, + context_window_tokens INT4, + default_output_reserve_tokens INT4, + capability_profile_version VARCHAR(100), + capacity_source VARCHAR(100), + requested_output_tokens INT4, + provider_input_limit_tokens INT4, + tokenizer_family VARCHAR(100), + counting_mode VARCHAR(20), + unknown_capabilities JSONB, + capacity_fingerprint VARCHAR(64), + budget_fingerprint VARCHAR(64), + budget_w1_fingerprint VARCHAR(64), + budget_requested_output_tokens INT4, + budget_output_reserve_source VARCHAR(32), + budget_provider_input_limit_tokens INT4, + budget_uncertainty_reserve_tokens INT4, + budget_uncertainty_reserve_basis VARCHAR(64), + budget_soft_limit_ratio FLOAT, + budget_soft_input_budget_tokens INT4, + budget_hard_input_budget_tokens INT4, + budget_warnings JSONB, generation_rate FLOAT, is_streaming BOOLEAN DEFAULT FALSE, is_success BOOLEAN DEFAULT TRUE, @@ -1792,6 +1829,27 @@ COMMENT ON COLUMN nexent.model_monitoring_record_t.ttft_ms IS 'Time to first tok COMMENT ON COLUMN nexent.model_monitoring_record_t.input_tokens IS 'Number of input prompt tokens'; COMMENT ON COLUMN nexent.model_monitoring_record_t.output_tokens IS 'Number of output completion tokens'; COMMENT ON COLUMN nexent.model_monitoring_record_t.total_tokens IS 'Total tokens (input + output)'; +COMMENT ON COLUMN nexent.model_monitoring_record_t.context_window_tokens IS 'Resolved total combined model context window for 
this request'; +COMMENT ON COLUMN nexent.model_monitoring_record_t.default_output_reserve_tokens IS 'Default output allowance reserved before input context construction'; +COMMENT ON COLUMN nexent.model_monitoring_record_t.capability_profile_version IS 'Version of the resolved capacity profile for this request'; +COMMENT ON COLUMN nexent.model_monitoring_record_t.capacity_source IS 'Dominant source of resolved capacity fields for this request'; +COMMENT ON COLUMN nexent.model_monitoring_record_t.requested_output_tokens IS 'Output tokens requested or reserved during capacity resolution'; +COMMENT ON COLUMN nexent.model_monitoring_record_t.provider_input_limit_tokens IS 'Resolved provider input-token limit used by context management'; +COMMENT ON COLUMN nexent.model_monitoring_record_t.tokenizer_family IS 'Tokenizer family used for request token counting'; +COMMENT ON COLUMN nexent.model_monitoring_record_t.counting_mode IS 'Token counting mode for the request: exact or estimated'; +COMMENT ON COLUMN nexent.model_monitoring_record_t.unknown_capabilities IS 'Structured list of capacity capabilities unknown at resolution time'; +COMMENT ON COLUMN nexent.model_monitoring_record_t.capacity_fingerprint IS 'Fingerprint of the resolved model capacity snapshot'; +COMMENT ON COLUMN nexent.model_monitoring_record_t.budget_fingerprint IS 'Fingerprint of the resolved W2 safe input budget snapshot'; +COMMENT ON COLUMN nexent.model_monitoring_record_t.budget_w1_fingerprint IS 'W1 capacity fingerprint consumed by the W2 budget snapshot'; +COMMENT ON COLUMN nexent.model_monitoring_record_t.budget_requested_output_tokens IS 'W2 trusted requested output tokens used at dispatch'; +COMMENT ON COLUMN nexent.model_monitoring_record_t.budget_output_reserve_source IS 'Source of the W2 requested output token reserve'; +COMMENT ON COLUMN nexent.model_monitoring_record_t.budget_provider_input_limit_tokens IS 'Provider input limit after applying the W2 output reserve'; +COMMENT ON COLUMN 
nexent.model_monitoring_record_t.budget_uncertainty_reserve_tokens IS 'Additional W2 uncertainty reserve deducted from input budget'; +COMMENT ON COLUMN nexent.model_monitoring_record_t.budget_uncertainty_reserve_basis IS 'Basis used for the W2 uncertainty reserve'; +COMMENT ON COLUMN nexent.model_monitoring_record_t.budget_soft_limit_ratio IS 'W2 soft input budget ratio'; +COMMENT ON COLUMN nexent.model_monitoring_record_t.budget_soft_input_budget_tokens IS 'W2 soft input budget where proactive compression begins'; +COMMENT ON COLUMN nexent.model_monitoring_record_t.budget_hard_input_budget_tokens IS 'W2 hard input budget consumed by W3 final fit'; +COMMENT ON COLUMN nexent.model_monitoring_record_t.budget_warnings IS 'Structured W2 budget warnings active for this request'; COMMENT ON COLUMN nexent.model_monitoring_record_t.generation_rate IS 'Token generation rate in tokens per second'; COMMENT ON COLUMN nexent.model_monitoring_record_t.is_streaming IS 'Whether the request used streaming response'; COMMENT ON COLUMN nexent.model_monitoring_record_t.is_success IS 'Whether the request completed successfully'; diff --git a/docker/sql/v2.2.0_0615_context_management_capacity_schema.sql b/docker/sql/v2.2.0_0615_context_management_capacity_schema.sql new file mode 100644 index 000000000..cc4194d96 --- /dev/null +++ b/docker/sql/v2.2.0_0615_context_management_capacity_schema.sql @@ -0,0 +1,144 @@ +-- Migration kind: REQUIRED_SCHEMA +-- Required for: all upgraded deployments before running W1/W2 context-management code. +-- Reason: new code reads/writes these model capacity, monitoring snapshot, and agent override columns. + +-- ============================================================ +-- W1: Add explicit model token-capacity fields to model_record_t +-- ============================================================ +-- All columns are nullable and additive; legacy max_tokens stays as a deprecated +-- output-cap alias until consumers migrate. 
+ +ALTER TABLE nexent.model_record_t +ADD COLUMN IF NOT EXISTS context_window_tokens INTEGER DEFAULT NULL; + +ALTER TABLE nexent.model_record_t +ADD COLUMN IF NOT EXISTS max_input_tokens INTEGER DEFAULT NULL; + +ALTER TABLE nexent.model_record_t +ADD COLUMN IF NOT EXISTS max_output_tokens INTEGER DEFAULT NULL; + +ALTER TABLE nexent.model_record_t +ADD COLUMN IF NOT EXISTS default_output_reserve_tokens INTEGER DEFAULT NULL; + +ALTER TABLE nexent.model_record_t +ADD COLUMN IF NOT EXISTS tokenizer_family VARCHAR(100) DEFAULT NULL; + +ALTER TABLE nexent.model_record_t +ADD COLUMN IF NOT EXISTS capacity_source VARCHAR(100) DEFAULT NULL; + +ALTER TABLE nexent.model_record_t +ADD COLUMN IF NOT EXISTS capability_profile_version VARCHAR(100) DEFAULT NULL; + +COMMENT ON COLUMN nexent.model_record_t.context_window_tokens IS 'Total combined input/output context window in tokens, when the provider uses a combined window. Nullable.'; +COMMENT ON COLUMN nexent.model_record_t.max_input_tokens IS 'Provider hard input-token limit when distinct from the combined window. Nullable.'; +COMMENT ON COLUMN nexent.model_record_t.max_output_tokens IS 'Provider-supported or operator-configured completion-output cap. Replaces the ambiguous LLM meaning of max_tokens. Nullable.'; +COMMENT ON COLUMN nexent.model_record_t.default_output_reserve_tokens IS 'Default output allowance reserved per request before constructing input context. Nullable.'; +COMMENT ON COLUMN nexent.model_record_t.tokenizer_family IS 'Token-counting strategy or provider/model tokenizer identifier mapped via tokenizer_registry. Nullable.'; +COMMENT ON COLUMN nexent.model_record_t.capacity_source IS 'Source of the persisted capacity value. Optional values: operator, profile, provider_candidate, legacy, unknown.'; +COMMENT ON COLUMN nexent.model_record_t.capability_profile_version IS 'Version of the approved provider/model capability profile used by the request, e.g. 
openai/gpt-4o@1.'; + +-- ============================================================ +-- W1: Persist resolved model capacity snapshot fields on monitoring records +-- ============================================================ + +ALTER TABLE nexent.model_monitoring_record_t +ADD COLUMN IF NOT EXISTS context_window_tokens INTEGER DEFAULT NULL; + +ALTER TABLE nexent.model_monitoring_record_t +ADD COLUMN IF NOT EXISTS default_output_reserve_tokens INTEGER DEFAULT NULL; + +ALTER TABLE nexent.model_monitoring_record_t +ADD COLUMN IF NOT EXISTS capability_profile_version VARCHAR(100) DEFAULT NULL; + +ALTER TABLE nexent.model_monitoring_record_t +ADD COLUMN IF NOT EXISTS capacity_source VARCHAR(100) DEFAULT NULL; + +ALTER TABLE nexent.model_monitoring_record_t +ADD COLUMN IF NOT EXISTS requested_output_tokens INTEGER DEFAULT NULL; + +ALTER TABLE nexent.model_monitoring_record_t +ADD COLUMN IF NOT EXISTS provider_input_limit_tokens INTEGER DEFAULT NULL; + +ALTER TABLE nexent.model_monitoring_record_t +ADD COLUMN IF NOT EXISTS tokenizer_family VARCHAR(100) DEFAULT NULL; + +ALTER TABLE nexent.model_monitoring_record_t +ADD COLUMN IF NOT EXISTS counting_mode VARCHAR(20) DEFAULT NULL; + +ALTER TABLE nexent.model_monitoring_record_t +ADD COLUMN IF NOT EXISTS unknown_capabilities JSONB DEFAULT NULL; + +ALTER TABLE nexent.model_monitoring_record_t +ADD COLUMN IF NOT EXISTS capacity_fingerprint VARCHAR(64) DEFAULT NULL; + +COMMENT ON COLUMN nexent.model_monitoring_record_t.context_window_tokens IS 'Resolved total combined model context window for this request'; +COMMENT ON COLUMN nexent.model_monitoring_record_t.default_output_reserve_tokens IS 'Default output allowance reserved before input context construction'; +COMMENT ON COLUMN nexent.model_monitoring_record_t.capability_profile_version IS 'Version of the resolved capacity profile for this request'; +COMMENT ON COLUMN nexent.model_monitoring_record_t.capacity_source IS 'Dominant source of resolved capacity fields for this 
request'; +COMMENT ON COLUMN nexent.model_monitoring_record_t.requested_output_tokens IS 'Output tokens requested or reserved during capacity resolution'; +COMMENT ON COLUMN nexent.model_monitoring_record_t.provider_input_limit_tokens IS 'Resolved provider input-token limit used by context management'; +COMMENT ON COLUMN nexent.model_monitoring_record_t.tokenizer_family IS 'Tokenizer family used for request token counting'; +COMMENT ON COLUMN nexent.model_monitoring_record_t.counting_mode IS 'Token counting mode for the request: exact or estimated'; +COMMENT ON COLUMN nexent.model_monitoring_record_t.unknown_capabilities IS 'Structured list of capacity capabilities unknown at resolution time'; +COMMENT ON COLUMN nexent.model_monitoring_record_t.capacity_fingerprint IS 'Fingerprint of the resolved model capacity snapshot'; + +-- ============================================================ +-- W2: Add per-agent requested_output_tokens override +-- ============================================================ + +ALTER TABLE nexent.ag_tenant_agent_t + ADD COLUMN IF NOT EXISTS requested_output_tokens INTEGER NULL; + +COMMENT ON COLUMN nexent.ag_tenant_agent_t.requested_output_tokens IS + 'Per-agent override for W2 requested_output_tokens. NULL means inherit ' + 'the resolved model-level default. 
Must satisfy 0 < value <= ' + 'max_output_tokens from the resolved W1 capacity at save time.'; + +-- ============================================================ +-- W2: Add safe input budget snapshot fields to model monitoring records +-- ============================================================ + +ALTER TABLE nexent.model_monitoring_record_t +ADD COLUMN IF NOT EXISTS budget_fingerprint VARCHAR(64) DEFAULT NULL; + +ALTER TABLE nexent.model_monitoring_record_t +ADD COLUMN IF NOT EXISTS budget_w1_fingerprint VARCHAR(64) DEFAULT NULL; + +ALTER TABLE nexent.model_monitoring_record_t +ADD COLUMN IF NOT EXISTS budget_requested_output_tokens INTEGER DEFAULT NULL; + +ALTER TABLE nexent.model_monitoring_record_t +ADD COLUMN IF NOT EXISTS budget_output_reserve_source VARCHAR(32) DEFAULT NULL; + +ALTER TABLE nexent.model_monitoring_record_t +ADD COLUMN IF NOT EXISTS budget_provider_input_limit_tokens INTEGER DEFAULT NULL; + +ALTER TABLE nexent.model_monitoring_record_t +ADD COLUMN IF NOT EXISTS budget_uncertainty_reserve_tokens INTEGER DEFAULT NULL; + +ALTER TABLE nexent.model_monitoring_record_t +ADD COLUMN IF NOT EXISTS budget_uncertainty_reserve_basis VARCHAR(64) DEFAULT NULL; + +ALTER TABLE nexent.model_monitoring_record_t +ADD COLUMN IF NOT EXISTS budget_soft_limit_ratio FLOAT DEFAULT NULL; + +ALTER TABLE nexent.model_monitoring_record_t +ADD COLUMN IF NOT EXISTS budget_soft_input_budget_tokens INTEGER DEFAULT NULL; + +ALTER TABLE nexent.model_monitoring_record_t +ADD COLUMN IF NOT EXISTS budget_hard_input_budget_tokens INTEGER DEFAULT NULL; + +ALTER TABLE nexent.model_monitoring_record_t +ADD COLUMN IF NOT EXISTS budget_warnings JSONB DEFAULT NULL; + +COMMENT ON COLUMN nexent.model_monitoring_record_t.budget_fingerprint IS 'Fingerprint of the resolved W2 safe input budget snapshot'; +COMMENT ON COLUMN nexent.model_monitoring_record_t.budget_w1_fingerprint IS 'W1 capacity fingerprint consumed by the W2 budget snapshot'; +COMMENT ON COLUMN 
nexent.model_monitoring_record_t.budget_requested_output_tokens IS 'W2 trusted requested output tokens used at dispatch'; +COMMENT ON COLUMN nexent.model_monitoring_record_t.budget_output_reserve_source IS 'Source of the W2 requested output token reserve'; +COMMENT ON COLUMN nexent.model_monitoring_record_t.budget_provider_input_limit_tokens IS 'Provider input limit after applying the W2 output reserve'; +COMMENT ON COLUMN nexent.model_monitoring_record_t.budget_uncertainty_reserve_tokens IS 'Additional W2 uncertainty reserve deducted from input budget'; +COMMENT ON COLUMN nexent.model_monitoring_record_t.budget_uncertainty_reserve_basis IS 'Basis used for the W2 uncertainty reserve'; +COMMENT ON COLUMN nexent.model_monitoring_record_t.budget_soft_limit_ratio IS 'W2 soft input budget ratio'; +COMMENT ON COLUMN nexent.model_monitoring_record_t.budget_soft_input_budget_tokens IS 'W2 soft input budget where proactive compression begins'; +COMMENT ON COLUMN nexent.model_monitoring_record_t.budget_hard_input_budget_tokens IS 'W2 hard input budget consumed by W3 final fit'; +COMMENT ON COLUMN nexent.model_monitoring_record_t.budget_warnings IS 'Structured W2 budget warnings active for this request'; diff --git a/docker/sql/v2.2.0_0617_context_management_capacity_data_fix.sql b/docker/sql/v2.2.0_0617_context_management_capacity_data_fix.sql new file mode 100644 index 000000000..21a794e18 --- /dev/null +++ b/docker/sql/v2.2.0_0617_context_management_capacity_data_fix.sql @@ -0,0 +1,138 @@ +-- Migration kind: RECOMMENDED_DATA_FIX +-- Required for: upgraded deployments with existing model_record_t rows. +-- Safe to skip when: fresh deployment, or operators will manually fill capacity fields. +-- Reason: improves legacy model capacity completeness and reconciles the temporary max_tokens alias. 
+ +-- ============================================================ +-- Backfill capacity columns on legacy model_record_t rows +-- ============================================================ +-- Matches (model_factory, model_name) against W1 day-one catalog entries. +-- Idempotent: only writes when context_window_tokens IS NULL, so re-running on +-- already-backfilled rows is a no-op. +-- +-- Catalog source of truth: backend/consts/capability_profiles.py (W1 ADR +-- Decision 1). If the catalog is bumped, mirror the change here in a new +-- migration; do not edit this file in place after it has been released. +-- +-- Coverage caveat: rows whose model_factory does not match a catalog provider +-- key (commonly the manual-add default 'OpenAI-API-Compatible' per CM-031) +-- will not be backfilled by this migration. Operators must either update +-- model_factory directly, re-save the model through the W1-aware UI, or wait +-- for W17. Startup logs surface the residual count. + +DO $$ +DECLARE + v_updated INTEGER := 0; + v_total INTEGER := 0; +BEGIN + -- openai/gpt-4o + UPDATE nexent.model_record_t + SET context_window_tokens = 128000, + max_output_tokens = 16384, + default_output_reserve_tokens = 4096 + WHERE LOWER(model_factory) = 'openai' + AND model_name = 'gpt-4o' + AND delete_flag = 'N' + AND context_window_tokens IS NULL; + GET DIAGNOSTICS v_updated = ROW_COUNT; + v_total := v_total + v_updated; + + -- openai/gpt-4.1 + UPDATE nexent.model_record_t + SET context_window_tokens = 1000000, + max_output_tokens = 32768, + default_output_reserve_tokens = 8192 + WHERE LOWER(model_factory) = 'openai' + AND model_name = 'gpt-4.1' + AND delete_flag = 'N' + AND context_window_tokens IS NULL; + GET DIAGNOSTICS v_updated = ROW_COUNT; + v_total := v_total + v_updated; + + -- dashscope/qwen-plus + UPDATE nexent.model_record_t + SET context_window_tokens = 131072, + max_output_tokens = 16384, + default_output_reserve_tokens = 4096 + WHERE LOWER(model_factory) = 'dashscope' + AND 
model_name = 'qwen-plus' + AND delete_flag = 'N' + AND context_window_tokens IS NULL; + GET DIAGNOSTICS v_updated = ROW_COUNT; + v_total := v_total + v_updated; + + -- dashscope/qwen-turbo + UPDATE nexent.model_record_t + SET context_window_tokens = 1000000, + max_output_tokens = 16384, + default_output_reserve_tokens = 4096 + WHERE LOWER(model_factory) = 'dashscope' + AND model_name = 'qwen-turbo' + AND delete_flag = 'N' + AND context_window_tokens IS NULL; + GET DIAGNOSTICS v_updated = ROW_COUNT; + v_total := v_total + v_updated; + + -- dashscope/glm-5.1 + UPDATE nexent.model_record_t + SET context_window_tokens = 200000, + max_output_tokens = 131072, + default_output_reserve_tokens = 8192 + WHERE LOWER(model_factory) = 'dashscope' + AND model_name = 'glm-5.1' + AND delete_flag = 'N' + AND context_window_tokens IS NULL; + GET DIAGNOSTICS v_updated = ROW_COUNT; + v_total := v_total + v_updated; + + -- silicon/Qwen/Qwen3.6-27B + UPDATE nexent.model_record_t + SET context_window_tokens = 262144, + max_output_tokens = 65536, + default_output_reserve_tokens = 8192 + WHERE LOWER(model_factory) = 'silicon' + AND model_name = 'Qwen/Qwen3.6-27B' + AND delete_flag = 'N' + AND context_window_tokens IS NULL; + GET DIAGNOSTICS v_updated = ROW_COUNT; + v_total := v_total + v_updated; + + -- silicon/Pro/moonshotai/Kimi-K2.6 + UPDATE nexent.model_record_t + SET context_window_tokens = 262144, + max_output_tokens = 131072, + default_output_reserve_tokens = 8192 + WHERE LOWER(model_factory) = 'silicon' + AND model_name = 'Pro/moonshotai/Kimi-K2.6' + AND delete_flag = 'N' + AND context_window_tokens IS NULL; + GET DIAGNOSTICS v_updated = ROW_COUNT; + v_total := v_total + v_updated; + + RAISE NOTICE 'W2 catalog backfill: % row(s) updated', v_total; +END $$; + +-- ============================================================ +-- Reconcile the legacy max_tokens column with max_output_tokens +-- ============================================================ +-- Runs after the catalog 
backfill above because the backfill writes +-- max_output_tokens. Scope and safety: +-- * Only touches rows where max_output_tokens IS NOT NULL. +-- * Skips embedding rows because they reuse max_tokens as the vector dimension. +-- * Only updates rows where the two columns actually disagree. +-- * delete_flag = 'N' so soft-deleted rows are left alone. + +DO $$ +DECLARE + v_updated INTEGER := 0; +BEGIN + UPDATE nexent.model_record_t + SET max_tokens = max_output_tokens + WHERE delete_flag = 'N' + AND max_output_tokens IS NOT NULL + AND COALESCE(max_tokens, -1) <> max_output_tokens + AND COALESCE(model_type, '') NOT IN ('embedding', 'multi_embedding'); + + GET DIAGNOSTICS v_updated = ROW_COUNT; + RAISE NOTICE 'max_tokens alias reconcile: % row(s) updated', v_updated; +END $$; diff --git a/docker/sql/v2.2.2_0622_update_left_nav_menu.sql b/docker/sql/v2.2.2_0622_update_left_nav_menu.sql index 2de41f987..a2d841ab1 100644 --- a/docker/sql/v2.2.2_0622_update_left_nav_menu.sql +++ b/docker/sql/v2.2.2_0622_update_left_nav_menu.sql @@ -7,7 +7,7 @@ DELETE FROM nexent.role_permission_t WHERE permission_category = 'VISIBILITY' AND permission_type = 'LEFT_NAV_MENU'; -ALTER TABLE role_permission_t +ALTER TABLE nexent.role_permission_t ADD COLUMN IF NOT EXISTS parent_key VARCHAR(50); -- ============================================================ -- New Menu Structure: @@ -98,4 +98,4 @@ INSERT INTO nexent.role_permission_t (role_permission_id, user_role, permission_ INSERT INTO nexent.role_permission_t (role_permission_id, user_role, permission_category, permission_type, permission_subtype, parent_key) VALUES (1509, 'ASSET_OWNER', 'VISIBILITY', 'LEFT_NAV_MENU', '/agent-space', '/resource-space'), (1510, 'ASSET_OWNER', 'VISIBILITY', 'LEFT_NAV_MENU', '/mcp-space', '/resource-space'), -(1511, 'ASSET_OWNER', 'VISIBILITY', 'LEFT_NAV_MENU', '/skill-space', '/resource-space'); \ No newline at end of file +(1511, 'ASSET_OWNER', 'VISIBILITY', 'LEFT_NAV_MENU', '/skill-space', 
'/resource-space'); diff --git a/frontend/app/[locale]/agents/components/AgentSelectorHeader.tsx b/frontend/app/[locale]/agents/components/AgentSelectorHeader.tsx index 7f23f6ddc..2973578b8 100644 --- a/frontend/app/[locale]/agents/components/AgentSelectorHeader.tsx +++ b/frontend/app/[locale]/agents/components/AgentSelectorHeader.tsx @@ -271,6 +271,7 @@ export default function AgentSelectorHeader({ model_name: detail.model, model_id: detail.model_id ?? undefined, max_steps: detail.max_step, + requested_output_tokens: detail.requested_output_tokens ?? null, provide_run_summary: detail.provide_run_summary, enabled: detail.enabled, business_description: detail.business_description, diff --git a/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx b/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx index cd46d2aa3..e07204cab 100644 --- a/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx +++ b/frontend/app/[locale]/agents/components/agentInfo/AgentGenerateDetail.tsx @@ -154,6 +154,15 @@ export default function AgentGenerateDetail({}) { })); }, [filteredGroups]); + const selectedMainAgentModel = useMemo(() => { + return availableLlmModels.find( + (model) => + model.id === editedAgent.model_id || + model.displayName === editedAgent.model || + model.name === editedAgent.model + ); + }, [availableLlmModels, editedAgent.model, editedAgent.model_id]); + // Initialize form values when currentAgentId changes or forceRefreshKey updates // Cached generation data is already merged into editedAgent by setCurrentAgent useEffect(() => { @@ -164,6 +173,7 @@ export default function AgentGenerateDetail({}) { mainAgentModel: editedAgent.model, mainAgentModelId: editedAgent.model_id, mainAgentMaxStep: editedAgent.max_step || 15, + requestedOutputTokens: editedAgent.requested_output_tokens ?? 
null, agentDescription: editedAgent.description || "", group_ids: normalizeNumberArray(editedAgent.group_ids || []), ingroup_permission: editedAgent.ingroup_permission || "READ_ONLY", @@ -182,6 +192,15 @@ export default function AgentGenerateDetail({}) { }, [form, currentAgentId, editedAgent, isCreatingMode, defaultLlmModel, accessibleGroupIds, forceRefreshKey]); + // Re-validate requested output tokens when the selected model's max changes, + // so switching to a model with a lower cap surfaces the violation immediately + // instead of waiting until save. + useEffect(() => { + if (form.getFieldValue("requestedOutputTokens") != null) { + form.validateFields(["requestedOutputTokens"]).catch(() => {}); + } + }, [form, selectedMainAgentModel?.maxOutputTokens]); + // Handle business description change const handleBusinessDescriptionChange = (value: string) => { @@ -954,6 +973,53 @@ export default function AgentGenerateDetail({}) { + + + + { + updateAgentConfig({ + requested_output_tokens: + typeof value === "number" ? 
value : null, + }); + }} + /> + + + + @@ -271,6 +294,14 @@ export const ModelAddDialog = ({ const [form, setForm] = useState(DEFAULT_FORM_STATE); const [loading, setLoading] = useState(false); const [verifyingConnectivity, setVerifyingConnectivity] = useState(false); + const [checkingCapacitySuggestion, setCheckingCapacitySuggestion] = + useState(false); + const [capacitySuggestionEnabled, setCapacitySuggestionEnabled] = + useState(true); + const [capacitySuggestion, setCapacitySuggestion] = + useState(null); + const [acceptedCapacitySuggestion, setAcceptedCapacitySuggestion] = + useState(null); const [connectivityStatus, setConnectivityStatus] = useState<{ status: ConnectivityStatusType; message: string; @@ -299,6 +330,11 @@ export const ModelAddDialog = ({ const [selectedModelForSettings, setSelectedModelForSettings] = useState(null); const [modelMaxTokens, setModelMaxTokens] = useState(""); + // Per-row capacity overrides edited via the gear icon in batch mode. Mirrors + // the top-level form's capacity fields so the same ModelCapacityFields panel + // can be rendered against this row-scoped state. 
+ const [modelCapacity, setModelCapacity] = + useState(emptyCapacityForm); // Use the silicon model list hook const siliconHook = useSiliconModelList({ @@ -340,6 +376,9 @@ export const ModelAddDialog = ({ const resetForm = useCallback(() => { setForm(DEFAULT_FORM_STATE); setConnectivityStatus({ status: null, message: "" }); + setCapacitySuggestionEnabled(true); + setCapacitySuggestion(null); + setAcceptedCapacitySuggestion(null); setModelList([]); setModelSearchTerm(""); setSelectedModelIds(new Set()); @@ -437,12 +476,22 @@ export const ModelAddDialog = ({ })); // If the key configuration item changes, clear the verification status if ( - ["type", "url", "apiKey", "maxTokens", "vectorDimension"].includes( - field - ) || + [ + "type", + "name", + "url", + "apiKey", + "maxTokens", + "vectorDimension", + "provider", + ].includes(field) || field === "provider" ) { setConnectivityStatus({ status: null, message: "" }); + if (["type", "name", "url", "apiKey", "provider"].includes(field)) { + setCapacitySuggestion(null); + setAcceptedCapacitySuggestion(null); + } } // Clear model search term when model type changes if (field === "type") { @@ -455,6 +504,60 @@ export const ModelAddDialog = ({ } }; + const canSuggestCapacity = () => + supportsCapacityFields && + !form.isBatchImport && + form.name.trim() !== "" && + (form.url.trim() !== "" || form.provider.trim() !== ""); + + const applyCapacitySuggestion = (suggestion: CapacitySuggestion | null) => { + const next = capacityFormFromSuggestion(suggestion); + if (!next || Object.keys(next).length === 0) return; + setForm((prev) => ({ + ...prev, + ...next, + name: suggestion?.canonicalModelName || prev.name, + // Do NOT overwrite `provider` from the catalog suggestion. The catalog's + // `suggested_provider` namespace (deepseek, openai, jina, ...) 
is a + // superset of the frontend dropdown's allowed values + // (modelengine / silicon / dashscope / tokenpony / custom); writing an + // unknown one back into `model_factory` makes the model disappear from + // the active list and the edit dropdown. + })); + setAcceptedCapacitySuggestion(suggestion); + }; + + const handleSuggestCapacity = async () => { + if (!canSuggestCapacity()) { + message.warning(t("model.dialog.capacity.suggestion.missingInput")); + return; + } + setCheckingCapacitySuggestion(true); + try { + const suggestion = await modelService.suggestCapacity({ + modelName: form.name.trim(), + baseUrl: form.url.trim(), + // Only send providerHint when the user actually picked it (batch mode + // exposes the dropdown). In single-add mode the form keeps a hidden + // default ("modelengine") that the user never sees, so forwarding it + // would falsely pin catalog lookup to that provider. + ...(form.isBatchImport ? { providerHint: form.provider } : {}), + apiKey: form.apiKey.trim() || undefined, + modelType: resolveConnectivityModelType(form.type), + }); + setCapacitySuggestion(suggestion); + if (!suggestion.suggestions) { + setAcceptedCapacitySuggestion(null); + } + } catch (error) { + setCapacitySuggestion(null); + setAcceptedCapacitySuggestion(null); + message.error(t("model.dialog.capacity.suggestion.failed")); + } finally { + setCheckingCapacitySuggestion(false); + } + }; + // Verify if the vector dimension is valid const isValidVectorDimension = (value: string): boolean => { const dimension = Number.parseInt(value, 10); @@ -463,7 +566,19 @@ export const ModelAddDialog = ({ // Check if the form is valid const isFormValid = () => { + if ( + supportsCapacityFields && + // context_window/max_output are no longer required; only the data-shape + // checks (positive int / cross-field relationships) gate the Add button. + validateCapacityForm(form, []) + ) { + return false; + } + + // Capacity panel replaces the legacy max_tokens field for LLM/VLM types. 
+ // Only voice and rerank-style types still rely on the standalone max_tokens. const needsMaxTokens = + !supportsCapacityFields && form.type !== MODEL_TYPES.EMBEDDING && form.type !== MODEL_TYPES.MULTI_EMBEDDING && form.type !== MODEL_TYPES.STT; @@ -472,6 +587,34 @@ export const ModelAddDialog = ({ if (needsMaxTokens && !isValidMaxTokens(form.maxTokens)) { return false; } + // Per-row capacity gate for LLM/VLM batch import. After moving + // context_window/max_output to optional-with-defaults, the batch top + // defaults are guaranteed to be populated (capacityFormToSnakePayload + // substitutes DEFAULT_* on empty), so `effectiveContextWindow` and + // `effectiveMaxOutput` cannot be falsy in normal flow. Keeping the + // gate as defense-in-depth for future row sources (e.g., a catalog + // entry that pre-fills both row columns NULL and somehow bypasses + // the substitute) -- cheap to keep, costly to discover missing. + // + // We deliberately do NOT fall back to model.max_tokens here. Per the + // W1/W2 production plan the legacy column is unconditionally seeded + // with DEFAULT_LLM_MAX_TOKENS (4096) by the provider adapters, so + // treating it as a stand-in for max_output_tokens would mask missing + // W2 metadata and let any row pass validation. + if (supportsCapacityFields) { + const batchDefaults = capacityFormToSnakePayload(form); + for (const model of modelList) { + if (!selectedModelIds.has(model.id)) continue; + if (!rowSupportsCapacityFields(model)) continue; + const effectiveContextWindow = + model.context_window_tokens ?? batchDefaults.context_window_tokens; + const effectiveMaxOutput = + model.max_output_tokens ?? batchDefaults.max_output_tokens; + if (!effectiveContextWindow || !effectiveMaxOutput) { + return false; + } + } + } // If provider is ModelEngine, require the ModelEngine URL as well. 
if (form.provider === "modelengine") { return ( @@ -519,11 +662,9 @@ export const ModelAddDialog = ({ return form.apiKey.trim() !== "" && form.name.trim() !== ""; } } - return ( - form.name.trim() !== "" && - form.url.trim() !== "" && - isValidMaxTokens(form.maxTokens) - ); + // LLM/VLM final case: capacity validation already enforced above; no + // standalone max_tokens to check. + return form.name.trim() !== "" && form.url.trim() !== ""; }; // Verify model connectivity @@ -596,15 +737,24 @@ export const ModelAddDialog = ({ connectivity = result.connectivity; } else { // For other model types (LLM, Embedding, VLM, Rerank, etc.) + // For LLM/VLM the legacy form.maxTokens field is gone; use the new + // capacity panel's maxOutputTokens value as the connectivity-probe + // budget. Do NOT fall back to form.maxTokens for capacity types -- + // the W1/W2 plan deprecates that field for LLM/VLM, and isFormValid + // already guarantees form.maxOutputTokens is filled before this + // probe runs. + const resolvedMaxTokens = + form.type === MODEL_TYPES.EMBEDDING + ? Number.parseInt(form.vectorDimension, 10) + : supportsCapacityFields + ? Number.parseInt(form.maxOutputTokens || "0", 10) + : parseMaxTokens(form.maxTokens); const config = { modelName: form.name, modelType: modelType, baseUrl: form.url, apiKey: form.apiKey.trim() || "sk-no-api-key", - maxTokens: - form.type === MODEL_TYPES.EMBEDDING - ? Number.parseInt(form.vectorDimension, 10) - : parseMaxTokens(form.maxTokens), + maxTokens: resolvedMaxTokens, embeddingDim: form.type === MODEL_TYPES.EMBEDDING ? 
Number.parseInt(form.vectorDimension, 10) @@ -613,6 +763,13 @@ export const ModelAddDialog = ({ const result = await modelService.verifyModelConfigConnectivity(config); connectivity = result.connectivity; + if ( + capacitySuggestionEnabled && + supportsCapacityFields && + result.capacitySuggestion + ) { + setCapacitySuggestion(result.capacitySuggestion); + } } // Set connectivity status @@ -672,6 +829,50 @@ export const ModelAddDialog = ({ }; }; + // Translate the top-level ModelCapacityFormState (camelCase, string) into the + // snake_case fields the batch-add backend expects. Used as the per-row + // fallback in batch mode when the row itself has no capacity overrides AND + // as the single-add wire payload. + // + // `applyDefaults` controls whether empty context_window/max_output get the + // shared UI defaults substituted. Defaults true for write-time paths + // (single-add, batch fallback for missing rows, per-row gear). The Settings + // Modal's "no-op edit" path passes false so that opening the gear and + // saving without touching anything does not clobber an existing + // `context_window_tokens=128000` (from catalog) with the 32K default. + const capacityFormToSnakePayload = ( + capacity: ModelCapacityFormState, + options?: { applyDefaults?: boolean } + ) => { + const applyDefaults = options?.applyDefaults !== false; + const toInt = (raw: string) => { + const trimmed = raw.trim(); + if (!/^[1-9]\d*$/.test(trimmed)) return undefined; + return Number.parseInt(trimmed, 10); + }; + const tokenizer = capacity.tokenizerFamily.trim(); + const contextWindow = + toInt(capacity.contextWindowTokens) ?? + (applyDefaults ? DEFAULT_CONTEXT_WINDOW_TOKENS : undefined); + const maxOutput = + toInt(capacity.maxOutputTokens) ?? + (applyDefaults ? 
DEFAULT_MAX_OUTPUT_TOKENS : undefined); + const hasAny = capacityFieldKeys.some( + (k) => capacity[k].trim() !== "" + ); + return { + context_window_tokens: contextWindow, + max_input_tokens: toInt(capacity.maxInputTokens), + max_output_tokens: maxOutput, + default_output_reserve_tokens: toInt(capacity.defaultOutputReserveTokens), + tokenizer_family: tokenizer || undefined, + // When defaults substituted, the row carries a deterministic operator + // value. When not (Settings Modal no-op preserve mode), only mark + // operator-sourced if the operator actually typed something. + capacity_source: applyDefaults || hasAny ? "operator" : undefined, + }; + }; + const buildBatchModelData = (model: any, modelType: ModelType) => { const isEmbeddingType = modelType === MODEL_TYPES.EMBEDDING || @@ -687,9 +888,41 @@ export const ModelAddDialog = ({ return modelWithoutMaxTokens; } + // Rerank and other legacy-only types: keep the pre-W2 path that relies on + // form.maxTokens as the batch default. + if (!rowSupportsCapacityFields(model)) { + return { + ...model, + max_tokens: model.max_tokens ?? parseMaxTokens(form.maxTokens), + }; + } + + // LLM/VLM: row-scoped capacity overrides win; otherwise fall back to the + // top-level capacity panel acting as the batch default. snake_case here + // because that's what the backend create-batch endpoint expects. + const fallback = capacityFormToSnakePayload(form); + + const resolved = { + context_window_tokens: + model.context_window_tokens ?? fallback.context_window_tokens, + max_input_tokens: model.max_input_tokens ?? fallback.max_input_tokens, + max_output_tokens: model.max_output_tokens ?? fallback.max_output_tokens, + default_output_reserve_tokens: + model.default_output_reserve_tokens ?? + fallback.default_output_reserve_tokens, + tokenizer_family: model.tokenizer_family ?? fallback.tokenizer_family, + capacity_source: model.capacity_source ?? fallback.capacity_source, + }; + return { ...model, - max_tokens: model.max_tokens ?? 
parseMaxTokens(form.maxTokens), + ...resolved, + // Mirror max_output_tokens into legacy max_tokens. Backend has a coercion + // helper but mirroring here keeps the wire payload self-consistent. + max_tokens: + resolved.max_output_tokens ?? + model.max_tokens ?? + parseMaxTokens(form.maxTokens), }; }; @@ -783,20 +1016,119 @@ export const ModelAddDialog = ({ } }; + // Resolve whether a fetched batch row uses the capacity panel. The row's own + // model_type wins (a row may be rerank even when form.type is LLM during + // mixed-type fetches), falling back to the form-level decision. + const rowSupportsCapacityFields = (model: any): boolean => { + const rowType = model?.model_type; + if ( + rowType === MODEL_TYPES.EMBEDDING || + rowType === MODEL_TYPES.MULTI_EMBEDDING + ) + return false; + if (rowType === MODEL_TYPES.STT || rowType === MODEL_TYPES.TTS) + return false; + if (rowType === MODEL_TYPES.RERANK) return false; + if (rowType) return true; + return supportsCapacityFields; + }; + // Handle settings button click const handleSettingsClick = (model: any) => { setSelectedModelForSettings(model); setModelMaxTokens(model.max_tokens?.toString() || ""); + if (rowSupportsCapacityFields(model)) { + // Merge order: row's W2 capacity values (from provider catalog hints) + // win, falling back to the top-level batch defaults typed into the + // capacity panel. The gear modal must reflect exactly what the row + // will end up using if the user clicks save without further edits. + // + // Crucially we do NOT pass model.max_tokens into capacityFormFromModel. + // Per the W1/W2 production plan, max_tokens is a deprecated legacy + // alias and "never used as a context window after migration". 
On + // batch-fetched rows the backend providers (Dashscope, Silicon, + // ModelEngine, TokenPony) unconditionally inject the legacy column + // with DEFAULT_LLM_MAX_TOKENS=4096 to keep the NOT-NULL contract; + // promoting that sentinel into max_output_tokens here makes the gear + // modal show 4096 every time the upstream catalog omits real W2 + // metadata, shadowing the user's batch defaults. + const rowMapped = capacityFormFromModel({ + contextWindowTokens: model.context_window_tokens, + maxInputTokens: model.max_input_tokens, + maxOutputTokens: model.max_output_tokens, + defaultOutputReserveTokens: model.default_output_reserve_tokens, + tokenizerFamily: model.tokenizer_family, + }); + setModelCapacity({ + contextWindowTokens: + rowMapped.contextWindowTokens || form.contextWindowTokens, + maxInputTokens: rowMapped.maxInputTokens || form.maxInputTokens, + maxOutputTokens: rowMapped.maxOutputTokens || form.maxOutputTokens, + defaultOutputReserveTokens: + rowMapped.defaultOutputReserveTokens || + form.defaultOutputReserveTokens, + tokenizerFamily: rowMapped.tokenizerFamily || form.tokenizerFamily, + }); + } else { + setModelCapacity(emptyCapacityForm); + } setSettingsModalVisible(true); }; // Handle settings save const handleSettingsSave = () => { - const nextMaxTokens = parseMaxTokens(modelMaxTokens); - if (!nextMaxTokens) return; + if (!selectedModelForSettings) { + setSettingsModalVisible(false); + return; + } - if (selectedModelForSettings) { - // Update the model in the list with new max_tokens + const useCapacity = rowSupportsCapacityFields(selectedModelForSettings); + + if (useCapacity) { + // Persist capacity fields onto the row in their snake_case API shape so + // buildBatchModelData can forward them without further translation. + // Defaults always apply at save: the gear modal preloads modelCapacity + // from the row's existing values (or batch defaults), so "no-op save" + // already carries non-empty inputs and goes through toInt unchanged. 
+ // Only the row-NULL + empty-batch-default case lands DEFAULT_*, which + // is the desired "empty input means default" semantic. + const payload = capacityFormToSnakePayload(modelCapacity); + const hasAny = capacityFieldKeys.some( + (k) => modelCapacity[k].trim() !== "" + ); + setModelList((prev) => + prev.map((model) => + model.id === selectedModelForSettings.id + ? { + ...model, + context_window_tokens: + payload.context_window_tokens ?? + (hasAny ? null : model.context_window_tokens), + max_input_tokens: + payload.max_input_tokens ?? + (hasAny ? null : model.max_input_tokens), + max_output_tokens: + payload.max_output_tokens ?? + (hasAny ? null : model.max_output_tokens), + default_output_reserve_tokens: + payload.default_output_reserve_tokens ?? + (hasAny ? null : model.default_output_reserve_tokens), + tokenizer_family: + payload.tokenizer_family ?? + (hasAny ? null : model.tokenizer_family), + capacity_source: hasAny + ? payload.capacity_source + : model.capacity_source, + // Mirror max_output_tokens into legacy max_tokens so the + // backend coercion path stays consistent for rows that bypass it. + max_tokens: payload.max_output_tokens ?? model.max_tokens, + } + : model + ) + ); + } else { + const nextMaxTokens = parseMaxTokens(modelMaxTokens); + if (!nextMaxTokens) return; setModelList((prev) => prev.map((model) => model.id === selectedModelForSettings.id @@ -805,6 +1137,7 @@ export const ModelAddDialog = ({ ) ); } + setSettingsModalVisible(false); setSelectedModelForSettings(null); }; @@ -828,9 +1161,21 @@ export const ModelAddDialog = ({ form.type === MODEL_TYPES.EMBEDDING && form.isMultimodal ? (MODEL_TYPES.MULTI_EMBEDDING as ModelType) : form.type; - - // Determine the maximum tokens value - let maxTokensValue = parseMaxTokens(form.maxTokens) || 0; + const acceptedModelName = + acceptedCapacitySuggestion?.canonicalModelName || form.name; + // `acceptedCapacitySuggestion?.suggestedProvider` is intentionally NOT + // used here. 
See applyCapacitySuggestion above for the rationale. + + // Determine the maximum tokens value. + // For LLM/VLM (supportsCapacityFields), the legacy form.maxTokens + // input is hidden and must not be read here per the W1/W2 plan + // ("Never use legacy max_tokens"). Seed the legacy column with 0; + // buildCapacityPayload(form) spreads max_tokens := max_output_tokens + // a few lines below, keeping the deprecated NOT NULL column aligned + // with the W2 source of truth. + let maxTokensValue = supportsCapacityFields + ? 0 + : parseMaxTokens(form.maxTokens) || 0; if ( form.type === MODEL_TYPES.EMBEDDING || form.type === MODEL_TYPES.MULTI_EMBEDDING @@ -843,12 +1188,14 @@ export const ModelAddDialog = ({ if (tenantId) { const modelParams: any = { tenantId, - name: form.name, + name: acceptedModelName, type: modelType, url: form.url, apiKey: form.apiKey.trim() === "" ? "sk-no-api-key" : form.apiKey, maxTokens: maxTokensValue, displayName: form.displayName || form.name, + modelFactory: form.provider, + ...(supportsCapacityFields ? buildCapacityPayload(form) : {}), }; // Add STT specific fields @@ -883,12 +1230,14 @@ export const ModelAddDialog = ({ await modelService.createManageTenantModel(modelParams); } else { const modelParams: any = { - name: form.name, + name: acceptedModelName, type: modelType, url: form.url, apiKey: form.apiKey.trim() === "" ? "sk-no-api-key" : form.apiKey, maxTokens: maxTokensValue, displayName: form.displayName || form.name, + modelFactory: form.provider, + ...(supportsCapacityFields ? 
buildCapacityPayload(form) : {}), }; // Add STT specific fields @@ -927,12 +1276,13 @@ export const ModelAddDialog = ({ // Note: id is set to 0 as placeholder; backend assigns the actual id when saving let modelConfig: SingleModelConfig | STTModelConfig | TTSModelConfig = { id: 0, - modelName: form.name, + modelName: acceptedModelName, displayName: form.displayName || form.name, apiConfig: { apiKey: form.apiKey, modelUrl: form.url, }, + ...(supportsCapacityFields ? buildCapacityPayload(form) : {}), }; // Add STT specific fields to config @@ -1036,6 +1386,18 @@ export const ModelAddDialog = ({ const isEmbeddingModel = form.type === MODEL_TYPES.EMBEDDING; const isSTTModel = form.type === MODEL_TYPES.STT; const isTTSModel = form.type === MODEL_TYPES.TTS; + // Capacity fields apply to LLM/VLM types in both single-add and batch-add + // paths. In batch mode the top-level capacity panel becomes a per-batch + // default (mirrors how form.maxTokens worked pre-W2), with each row's gear + // dialog free to override individual values. + const supportsCapacityFields = + !isEmbeddingModel && + !isSTTModel && + !isTTSModel && + form.type !== MODEL_TYPES.RERANK; + const capacityValidationError = supportsCapacityFields + ? validateCapacityForm(form, []) + : null; return ( )} - {/* Max Tokens */} - {!isEmbeddingModel && !isSTTModel && ( + {supportsCapacityFields && ( +
+ {form.isBatchImport && ( + + )} + {!form.isBatchImport && ( +
+
+
+ {t("model.dialog.capacity.suggestion.title")} +
+
+ {t("model.dialog.capacity.suggestion.hint")} +
+
+
+ + +
+
+ )} + handleFormChange(field, value)} + validationError={capacityValidationError} + formMode="add" + // context_window/max_output are no longer required; an empty + // input lands the shared DEFAULT_* values at save time + // (see capacityFormToSnakePayload). + suggestion={ + capacitySuggestionEnabled && !form.isBatchImport + ? capacitySuggestion + : null + } + suggestionLoading={checkingCapacitySuggestion} + onUseSuggestion={() => + applyCapacitySuggestion(capacitySuggestion) + } + /> +
+ )} + + {/* Max Tokens (legacy; only for non-LLM types still using the standalone field) */} + {!isEmbeddingModel && !isSTTModel && !supportsCapacityFields && (
+ ); + + // Both add and edit modes render as a flat panel. Required-field + // asterisks (context_window, max_output_tokens) must be unmissable, and + // hiding the controls behind a Collapse hides those asterisks. + return
{content}
; +}; diff --git a/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx b/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx index c820cd5aa..48d54086c 100644 --- a/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx +++ b/frontend/app/[locale]/models/components/model/ModelDeleteDialog.tsx @@ -8,7 +8,12 @@ import { ExclamationCircleFilled } from "@ant-design/icons"; import { MODEL_TYPES, MODEL_SOURCES } from "@/const/modelConfig"; import { useConfig } from "@/hooks/useConfig"; import { modelService } from "@/services/modelService"; -import { ModelOption, ModelType, ModelSource } from "@/types/modelConfig"; +import { + CapacityCoverage, + ModelOption, + ModelType, + ModelSource, +} from "@/types/modelConfig"; import log from "@/lib/logger"; import { ModelEditDialog, ProviderConfigEditDialog } from "./ModelEditDialog"; @@ -23,6 +28,7 @@ interface ModelDeleteDialogProps { onClose: () => void; onSuccess: () => Promise; models: ModelOption[]; + capacityCoverage?: CapacityCoverage | null; } export const ModelDeleteDialog = ({ @@ -30,6 +36,7 @@ export const ModelDeleteDialog = ({ onClose, onSuccess, models, + capacityCoverage, }: ModelDeleteDialogProps) => { const { t } = useTranslation(); const { message } = App.useApp(); @@ -53,7 +60,8 @@ export const ModelDeleteDialog = ({ const [maxTokens, setMaxTokens] = useState(0); // Single model settings modal state - const [isSingleModelSettingsOpen, setIsSingleModelSettingsOpen] = useState(false); + const [isSingleModelSettingsOpen, setIsSingleModelSettingsOpen] = + useState(false); const [selectedSingleModel, setSelectedSingleModel] = useState(null); const [providerModelSearchTerm, setProviderModelSearchTerm] = useState(""); @@ -68,6 +76,22 @@ export const ModelDeleteDialog = ({ ]); const [chunkingBatchSize, setChunkingBatchSize] = useState("10"); const [savingEmbeddingConfig, setSavingEmbeddingConfig] = useState(false); + const bareCapacityModelIds = useMemo( + () => + new Set( 
+ (capacityCoverage?.bareModels || []).map((model) => model.modelId) + ), + [capacityCoverage] + ); + const suggestionAvailableModelIds = useMemo( + () => + new Set( + (capacityCoverage?.bareModels || []) + .filter((model) => model.suggestionAvailable) + .map((model) => model.modelId) + ), + [capacityCoverage] + ); // Get model color scheme const getModelColorScheme = ( @@ -284,13 +308,9 @@ export const ModelDeleteDialog = ({ ); case MODEL_SOURCES.DASHSCOPE: - return ( - DashScope - ); + return DashScope; case MODEL_SOURCES.TOKENPONY: - return ( - TokenPony - ); + return TokenPony; case MODEL_SOURCES.VOLCENGINE: return ( VolcEngine @@ -326,7 +346,8 @@ export const ModelDeleteDialog = ({ if (bySilicon?.apiKey) return bySilicon.apiKey; const byModelEngine = models.find( - (m) => m.source === MODEL_SOURCES.MODELENGINE && m.type === type && m.apiKey + (m) => + m.source === MODEL_SOURCES.MODELENGINE && m.type === type && m.apiKey ); if (byModelEngine?.apiKey) return byModelEngine.apiKey; @@ -346,11 +367,14 @@ export const ModelDeleteDialog = ({ }; // Get provider base URL by model type (prefer ModelEngine entries) - const getProviderBaseUrlByType = (type: ModelType | null): string | undefined => { + const getProviderBaseUrlByType = ( + type: ModelType | null + ): string | undefined => { if (!type) return undefined; // Prefer provider entries (ModelEngine) first, then explicit modelConfig, then any model const engineModel = models.find( - (m) => m.source === MODEL_SOURCES.MODELENGINE && m.type === type && m.apiUrl + (m) => + m.source === MODEL_SOURCES.MODELENGINE && m.type === type && m.apiUrl ); if (engineModel?.apiUrl) return engineModel.apiUrl; @@ -477,7 +501,10 @@ export const ModelDeleteDialog = ({ }; // Handle model deletion - const handleDeleteModel = async (displayName: string, provider?: ModelSource) => { + const handleDeleteModel = async ( + displayName: string, + provider?: ModelSource + ) => { setDeletingModels((prev) => new Set(prev).add(displayName)); try { 
// Prefer explicit provider passed in, fall back to selectedSource @@ -622,17 +649,66 @@ export const ModelDeleteDialog = ({ }); }, [providerModels, providerModelSearchTerm]); - // Handle provider config save + // Per-row required capacity gate for the provider-management batch confirm. + // Unlike ModelAddDialog this dialog has no top-level "batch default capacity" + // panel, so each enabled row must itself carry positive context_window_tokens + // and max_output_tokens (set via the per-row gear modal). Without this gate + // the user could batch-confirm an LLM/VLM row whose catalog supplied no W2 + // metadata, persisting context_window_tokens=NULL, max_output_tokens=NULL, + // and only the backend's DEFAULT_LLM_MAX_TOKENS=4096 legacy sentinel -- the + // exact glm-5.2 production incident we just root-caused. + // + // We deliberately don't fall back to model.max_tokens here: per the W1/W2 + // plan the legacy column is unconditionally seeded by the provider + // adapters, so treating it as a stand-in would mask every missing W2 row. + const requiresW2Capacity = (modelType?: ModelType): boolean => { + if (!modelType) return false; + if ( + modelType === MODEL_TYPES.EMBEDDING || + modelType === MODEL_TYPES.MULTI_EMBEDDING + ) + return false; + if (modelType === MODEL_TYPES.STT || modelType === MODEL_TYPES.TTS) + return false; + if (modelType === MODEL_TYPES.RERANK) return false; + return true; + }; + const hasUnconfiguredSelectedRow = useMemo(() => { + if (!requiresW2Capacity(deletingModelType as ModelType)) return false; + return providerModels.some((m: any) => { + if (!pendingSelectedProviderIds.has(m.id)) return false; + return !m.context_window_tokens || !m.max_output_tokens; + }); + }, [providerModels, pendingSelectedProviderIds, deletingModelType]); + + // Handle provider config save. 
In addition to the shared API key / + // timeoutSeconds / concurrencyLimit, the "modify config" dialog now also + // exposes a top-level capacity panel (Tokenizer hidden) as a per-provider + // bulk-apply default, mirroring the batch-add UX. Any filled capacity + // field is forwarded to every model under (provider, model_type) so the + // user can fix glm-5.x style rows with NULL W2 columns from one place + // instead of opening N gear modals. const handleProviderConfigSave = async ({ apiKey, maxTokens, timeoutSeconds, concurrencyLimit, + contextWindowTokens, + maxInputTokens, + maxOutputTokens, + defaultOutputReserveTokens, + capacitySource, }: { apiKey?: string; maxTokens: number; timeoutSeconds?: number; concurrencyLimit?: number; + contextWindowTokens?: number; + maxInputTokens?: number; + maxOutputTokens?: number; + defaultOutputReserveTokens?: number; + tokenizerFamily?: string; + capacitySource?: string; }) => { setMaxTokens(maxTokens); if ( @@ -667,6 +743,17 @@ export const ModelDeleteDialog = ({ maxTokens: maxTokens || m.maxTokens, ...(timeoutSeconds !== undefined ? { timeoutSeconds } : {}), ...(concurrencyLimit !== undefined ? { concurrencyLimit } : {}), + // Only forward capacity fields the user actually filled in the + // bulk panel; omitted fields keep each model's existing value. + ...(contextWindowTokens !== undefined + ? { contextWindowTokens } + : {}), + ...(maxInputTokens !== undefined ? { maxInputTokens } : {}), + ...(maxOutputTokens !== undefined ? { maxOutputTokens } : {}), + ...(defaultOutputReserveTokens !== undefined + ? { defaultOutputReserveTokens } + : {}), + ...(capacitySource !== undefined ? 
{ capacitySource } : {}), })); await modelService.updateBatchModel( @@ -677,13 +764,32 @@ export const ModelDeleteDialog = ({ // Show success message since no exception was thrown message.success(t("model.dialog.success.updateSuccess")); - // Synchronize providerModels state with the updated maxTokens + // Synchronize providerModels state with the bulk values that landed, + // so the row gear modals show the new defaults next time they open. setProviderModels((prev) => prev.map((model) => ({ ...model, max_tokens: maxTokens || model.max_tokens, timeout_seconds: timeoutSeconds || model.timeout_seconds, - concurrency_limit: concurrencyLimit !== undefined ? concurrencyLimit : model.concurrency_limit, + concurrency_limit: + concurrencyLimit !== undefined + ? concurrencyLimit + : model.concurrency_limit, + ...(contextWindowTokens !== undefined + ? { context_window_tokens: contextWindowTokens } + : {}), + ...(maxInputTokens !== undefined + ? { max_input_tokens: maxInputTokens } + : {}), + ...(maxOutputTokens !== undefined + ? { max_output_tokens: maxOutputTokens } + : {}), + ...(defaultOutputReserveTokens !== undefined + ? { default_output_reserve_tokens: defaultOutputReserveTokens } + : {}), + ...(capacitySource !== undefined + ? 
{ capacity_source: capacitySource } + : {}), })) ); } catch (e) { @@ -770,7 +876,9 @@ export const ModelDeleteDialog = ({ selectedEmbeddingModel.apiKey || getApiKeyByType( deletingModelType, - (selectedEmbeddingModel?.source as ModelSource) || selectedSource || undefined + (selectedEmbeddingModel?.source as ModelSource) || + selectedSource || + undefined ); await modelService.updateSingleModel({ @@ -816,227 +924,274 @@ export const ModelDeleteDialog = ({ selectedSource && selectedSource !== MODEL_SOURCES.OPENAI_API_COMPATIBLE && deletingModelType && ( - + }} + > + {t("common.confirm")} + + ), ]} width={520} @@ -1319,6 +1474,12 @@ export const ModelDeleteDialog = ({ m.source === selectedSource ); const canEditEmbedding = isEmbeddingModel && existingModel; + const isBareCapacity = existingModel + ? bareCapacityModelIds.has(existingModel.id) + : false; + const hasSuggestion = existingModel + ? suggestionAvailableModelIds.has(existingModel.id) + : false; return (
)} + {isBareCapacity && ( + + + {t("model.dialog.capacityCoverage.tag")} + + + )}
{deletingModelType !== MODEL_TYPES.EMBEDDING && @@ -1357,7 +1533,43 @@ export const ModelDeleteDialog = ({ size="small" onClick={(e) => { e.stopPropagation(); // Prevent switch toggle - handleSingleModelSettingsClick(providerModel); + // The provider catalog entry carries snake_case + // ids and (sometimes) a default max_tokens, but + // never the user's saved capacity columns. When + // the model has already been added, overlay the + // saved ModelOption (camelCase) onto the catalog + // row in snake_case so the edit dialog + // pre-fills context_window_tokens etc. instead + // of showing empty fields. + const settingsTarget = existingModel + ? { + ...providerModel, + max_tokens: + existingModel.maxTokens ?? + providerModel.max_tokens, + timeout_seconds: + existingModel.timeoutSeconds ?? + providerModel.timeout_seconds, + concurrency_limit: + existingModel.concurrencyLimit ?? + providerModel.concurrency_limit, + context_window_tokens: + existingModel.contextWindowTokens, + max_input_tokens: + existingModel.maxInputTokens, + max_output_tokens: + existingModel.maxOutputTokens, + default_output_reserve_tokens: + existingModel.defaultOutputReserveTokens, + tokenizer_family: + existingModel.tokenizerFamily, + capacity_source: + existingModel.capacitySource, + capability_profile_version: + existingModel.capabilityProfileVersion, + } + : providerModel; + handleSingleModelSettingsClick(settingsTarget); }} /> @@ -1410,6 +1622,10 @@ export const ModelDeleteDialog = ({ selectedSource === MODEL_SOURCES.OPENAI_API_COMPATIBLE; const isClickable = isBatchImportedEmbedding || isCustomModelClickable; + const isBareCapacity = bareCapacityModelIds.has(model.id); + const hasSuggestion = suggestionAvailableModelIds.has( + model.id + ); return (
{model.displayName || model.name} ({model.name})
+ {isBareCapacity && ( + + + {t("model.dialog.capacityCoverage.tag")} + + + )}
+ + + handleFormChange(field, value)} + validationError={capacityValidationError} + capacitySource={model.capacitySource} + capabilityProfileVersion={model.capabilityProfileVersion} + // context_window/max_output no longer required; empty input + // lands DEFAULT_* via buildCapacityPayload at save time. + suggestion={capacitySuggestionEnabled ? capacitySuggestion : null} + suggestionLoading={checkingCapacitySuggestion} + onUseSuggestion={() => + applyCapacitySuggestion(capacitySuggestion) + } + // Legacy max_tokens is now surfaced via the actionable + // legacyMaxTokensCandidate prompt (no more silent promote in + // capacityFormFromModel). Keep the plain deprecation banner + // fallback for the rare case where the record has neither + // column populated, so users still see the migration nudge. + showDeprecatedMaxTokensWarning={ + Boolean(model.maxTokens) && + !model.maxOutputTokens && + !form.maxOutputTokens + } + legacyMaxTokensCandidate={ + model.maxOutputTokens ? undefined : model.maxTokens + } + /> + + )} + + {/* maxTokens (legacy; only kept for types not covered by the capacity panel) */} + {!isEmbeddingModel && !isRerankModel && !supportsCapacityFields && (
)} @@ -470,7 +697,9 @@ export const ModelEditDialog = ({ type="number" min="1" value={form.concurrencyLimit} - onChange={(e) => handleFormChange("concurrencyLimit", e.target.value)} + onChange={(e) => + handleFormChange("concurrencyLimit", e.target.value) + } placeholder={t("model.dialog.placeholder.concurrencyLimit")} />
@@ -577,72 +806,199 @@ export const ModelEditDialog = ({ }; // New: provider config edit dialog (only apiKey and maxTokens) +interface ProviderConfigInitialCapacity { + contextWindowTokens?: number; + maxInputTokens?: number; + maxOutputTokens?: number; + /** Legacy alias passed through so capacityFormFromModel can auto-migrate it. */ + maxTokens?: number; + defaultOutputReserveTokens?: number; + tokenizerFamily?: string; + capacitySource?: string; + capabilityProfileVersion?: string; +} + interface ProviderConfigEditDialogProps { - isOpen: boolean - initialApiKey?: string - initialMaxTokens?: string - initialTimeoutSeconds?: string - initialConcurrencyLimit?: string - modelType?: ModelType - showApiKeyField?: boolean // Whether to show API Key field (default: true) - onClose: () => void - onSave: (config: { apiKey?: string; maxTokens: number; timeoutSeconds?: number; concurrencyLimit?: number }) => Promise | void + isOpen: boolean; + initialApiKey?: string; + initialMaxTokens?: string; + initialTimeoutSeconds?: string; + initialConcurrencyLimit?: string; + initialCapacity?: ProviderConfigInitialCapacity; + hideCapacityFields?: boolean; // Suppress capacity controls when caller is a provider-level batch (not per-model) + modelType?: ModelType; + showApiKeyField?: boolean; // Whether to show API Key field (default: true) + onClose: () => void; + onSave: (config: { + apiKey?: string; + maxTokens: number; + timeoutSeconds?: number; + concurrencyLimit?: number; + contextWindowTokens?: number; + maxInputTokens?: number; + maxOutputTokens?: number; + defaultOutputReserveTokens?: number; + tokenizerFamily?: string; + capacitySource?: string; + }) => Promise | void; } export const ProviderConfigEditDialog = ({ isOpen, - initialApiKey = '', - initialMaxTokens = '', - initialTimeoutSeconds = '120', - initialConcurrencyLimit = '', + initialApiKey = "", + initialMaxTokens = "", + initialTimeoutSeconds = "120", + initialConcurrencyLimit = "", + initialCapacity, + 
hideCapacityFields = false, modelType, showApiKeyField = true, onClose, onSave, }: ProviderConfigEditDialogProps) => { - const { t } = useTranslation() - const [apiKey, setApiKey] = useState(initialApiKey) - const [maxTokens, setMaxTokens] = useState(initialMaxTokens) - const [timeoutSeconds, setTimeoutSeconds] = useState(initialTimeoutSeconds) - const [concurrencyLimit, setConcurrencyLimit] = useState(initialConcurrencyLimit) - const [saving, setSaving] = useState(false) + const { t } = useTranslation(); + const [apiKey, setApiKey] = useState(initialApiKey); + const [maxTokens, setMaxTokens] = useState(initialMaxTokens); + const [timeoutSeconds, setTimeoutSeconds] = useState( + initialTimeoutSeconds + ); + const [concurrencyLimit, setConcurrencyLimit] = useState( + initialConcurrencyLimit + ); + const [capacityForm, setCapacityForm] = useState( + initialCapacity ? capacityFormFromModel(initialCapacity) : emptyCapacityForm + ); + const [saving, setSaving] = useState(false); useEffect(() => { - setApiKey(initialApiKey) - setMaxTokens(initialMaxTokens) - setTimeoutSeconds(initialTimeoutSeconds) - setConcurrencyLimit(initialConcurrencyLimit) - }, [initialApiKey, initialMaxTokens, initialTimeoutSeconds, initialConcurrencyLimit]) + setApiKey(initialApiKey); + setMaxTokens(initialMaxTokens); + setTimeoutSeconds(initialTimeoutSeconds); + setConcurrencyLimit(initialConcurrencyLimit); + setCapacityForm( + initialCapacity + ? 
capacityFormFromModel(initialCapacity) + : emptyCapacityForm + ); + }, [ + initialApiKey, + initialMaxTokens, + initialTimeoutSeconds, + initialConcurrencyLimit, + initialCapacity, + ]); + + const isEmbeddingModel = + modelType === MODEL_TYPES.EMBEDDING || + modelType === MODEL_TYPES.MULTI_EMBEDDING; + const isRerankModel = modelType === MODEL_TYPES.RERANK; + const isVoiceModel = + modelType === MODEL_TYPES.STT || modelType === MODEL_TYPES.TTS; + const isLlmOrVlm = !isEmbeddingModel && !isRerankModel && !isVoiceModel; + // Per-model capacity panel: shown when the dialog is editing a single + // model's W2 capacity (gear icon next to a row). + const supportsCapacityFields = !hideCapacityFields && isLlmOrVlm; + // Provider-level "bulk apply" capacity panel: shown when the dialog is + // editing shared provider settings (the "修改配置" button). Renders the + // same ModelCapacityFields panel; context_window / max_output / etc. are + // reasonable defaults to broadcast across N models. + const supportsBulkCapacity = hideCapacityFields && isLlmOrVlm; + // Only rerank and voice models legitimately need the deprecated max_tokens + // input. Per the W1/W2 plan, never surface legacy max_tokens for LLM/VLM + // regardless of the hideCapacityFields flag. + const needsLegacyMaxTokens = isRerankModel || isVoiceModel; + // Neither mode marks any field required: + // - per-row mode (supportsCapacityFields): context_window/max_output are + // optional and get DEFAULT_* substituted at save by buildCapacityPayload + // - bulk-apply mode (supportsBulkCapacity): optional broadcast -- "fill + // to override; leave empty to keep each row's current value" + const capacityRequiredFields: Array = []; + const capacityValidationError = + supportsCapacityFields || supportsBulkCapacity + ? 
validateCapacityForm(capacityForm, capacityRequiredFields) + : null; + + const handleCapacityChange = ( + field: keyof typeof capacityForm, + value: string + ) => { + setCapacityForm((prev) => ({ ...prev, [field]: value })); + }; const valid = () => { - const isEmbeddingModel = modelType === MODEL_TYPES.EMBEDDING || modelType === MODEL_TYPES.MULTI_EMBEDDING - return isEmbeddingModel || isValidMaxTokens(maxTokens) - } + if (supportsCapacityFields) { + // Per-model capacity edit: required fields enforced by + // validateCapacityForm. + return !capacityValidationError; + } + if (supportsBulkCapacity) { + // Provider-level bulk apply: capacity fields are optional ("fill to + // override; leave empty to keep current per-model value"). Only fail + // when a typed value is not a positive integer. + return !capacityValidationError; + } + if (needsLegacyMaxTokens) { + return isValidMaxTokens(maxTokens); + } + // Embedding shared config: the dialog only owns + // apiKey/timeoutSeconds/concurrencyLimit, so always valid. + return true; + }; const handleSave = async () => { - if (!valid()) return + if (!valid()) return; try { - setSaving(true) - const isEmbeddingModel = modelType === MODEL_TYPES.EMBEDDING || modelType === MODEL_TYPES.MULTI_EMBEDDING - const isRerankModel = modelType === MODEL_TYPES.RERANK + setSaving(true); + // Only rerank/voice models legitimately surface the legacy maxTokens + // input. In every other case the maxTokens state still carries the + // backend's DEFAULT_LLM_MAX_TOKENS sentinel from the row prefill, so + // reading it would either be a no-op (LLM/VLM with capacity panel: + // buildCapacityPayload's max_output_tokens mirror overrides) or + // actively wrong (LLM/VLM provider-level config: would force the + // 4096 sentinel onto every existing row). Sending 0 here makes + // handleProviderConfigSave's `maxTokens || m.maxTokens` fall back to + // each row's current value, preserving it. + const legacyMaxTokens = needsLegacyMaxTokens + ? 
parseMaxTokens(maxTokens) || 0 + : 0; await onSave({ - ...(showApiKeyField ? { apiKey: apiKey.trim() === '' ? 'sk-no-api-key' : apiKey } : {}), - maxTokens: parseMaxTokens(maxTokens) || 0, - ...(!isEmbeddingModel && !isRerankModel ? { timeoutSeconds: parseInt(timeoutSeconds) || 120 } : {}), - ...(!isEmbeddingModel && !isRerankModel ? { concurrencyLimit: concurrencyLimit ? parseInt(concurrencyLimit) : undefined } : {}), - }) - onClose() + ...(showApiKeyField + ? { apiKey: apiKey.trim() === "" ? "sk-no-api-key" : apiKey } + : {}), + maxTokens: legacyMaxTokens, + ...(!isEmbeddingModel && !isRerankModel + ? { timeoutSeconds: parseInt(timeoutSeconds) || 120 } + : {}), + ...(!isEmbeddingModel && !isRerankModel + ? { + concurrencyLimit: concurrencyLimit + ? parseInt(concurrencyLimit) + : undefined, + } + : {}), + // Both per-model and bulk-apply modes write capacity via + // buildCapacityPayload. Per-model (supportsCapacityFields) opts + // into default substitution: empty context_window/max_output land + // DEFAULT_CONTEXT_WINDOW_TOKENS / DEFAULT_MAX_OUTPUT_TOKENS at the + // wire. Bulk-apply (supportsBulkCapacity) passes applyDefaults=false + // so empty fields stay omitted ("don't broadcast this value"), and + // an apiKey-only bulk edit doesn't accidentally null out per-row + // capacity by writing 32K/4K across N rows. + ...(supportsCapacityFields + ? buildCapacityPayload(capacityForm) + : supportsBulkCapacity + ? buildCapacityPayload(capacityForm, { applyDefaults: false }) + : {}), + }); + onClose(); } finally { - setSaving(false) + setSaving(false); } - } - - const isEmbeddingModel = modelType === MODEL_TYPES.EMBEDDING || modelType === MODEL_TYPES.MULTI_EMBEDDING - const isRerankModel = modelType === MODEL_TYPES.RERANK + }; return ( - setApiKey(e.target.value)} visibilityToggle={false} /> + setApiKey(e.target.value)} + visibilityToggle={false} + /> +
+ )} + {supportsCapacityFields && ( + + )} + {supportsBulkCapacity && ( +
+ +
)} - {!isEmbeddingModel && ( + {/* Legacy max_tokens input — only rendered for model types that + legitimately still own this field (rerank, STT/TTS). LLM/VLM use + the capacity panel; if hideCapacityFields=true is set (provider- + level config edit) the dialog deliberately drops both the + capacity panel and the legacy input -- per the W1/W2 plan + ("Never use legacy max_tokens") capacity is set per-model from + the gear icon, not via a provider-level shared value. */} + {needsLegacyMaxTokens && (
)}
- - +
- ) -} + ); +}; diff --git a/frontend/app/[locale]/models/components/modelConfig.tsx b/frontend/app/[locale]/models/components/modelConfig.tsx index e2787aaa8..1ddaa9deb 100644 --- a/frontend/app/[locale]/models/components/modelConfig.tsx +++ b/frontend/app/[locale]/models/components/modelConfig.tsx @@ -8,7 +8,7 @@ import { } from "react"; import { useTranslation } from "react-i18next"; -import { Button, Card, Col, Row, Space, App } from "antd"; +import { Alert, Button, Card, Col, Row, Space, App } from "antd"; import { Plus, ShieldCheck, RefreshCw, PenLine } from "lucide-react"; import { @@ -19,7 +19,7 @@ import { } from "@/const/modelConfig"; import { useConfig } from "@/hooks/useConfig"; import { modelService } from "@/services/modelService"; -import { ModelOption, ModelType } from "@/types/modelConfig"; +import { CapacityCoverage, ModelOption, ModelType } from "@/types/modelConfig"; import log from "@/lib/logger"; import { ModelListCard } from "./model/ModelListCard"; @@ -57,9 +57,18 @@ const getModelData = (t: any) => ({ multimodal: { title: t("modelConfig.category.multimodal"), options: [ - { id: MODEL_TYPES.VLM, name: t("modelConfig.option.imageUnderstandingModel") }, - { id: MODEL_TYPES.VLM2, name: t("modelConfig.option.imageGenerationModel") }, - { id: MODEL_TYPES.VLM3, name: t("modelConfig.option.videoUnderstandingModel") }, + { + id: MODEL_TYPES.VLM, + name: t("modelConfig.option.imageUnderstandingModel"), + }, + { + id: MODEL_TYPES.VLM2, + name: t("modelConfig.option.imageGenerationModel"), + }, + { + id: MODEL_TYPES.VLM3, + name: t("modelConfig.option.videoUnderstandingModel"), + }, ], }, voice: { @@ -112,6 +121,8 @@ export const ModelConfigSection = forwardRef< useState(false); const [isDeleteModalOpen, setIsDeleteModalOpen] = useState(false); const [isVerifying, setIsVerifying] = useState(false); + const [capacityCoverage, setCapacityCoverage] = + useState(null); // Error state management const [errorFields, setErrorFields] = useState<{ [key: 
string]: boolean }>({ @@ -250,10 +261,14 @@ export const ModelConfigSection = forwardRef< if (!modelConfig) return; try { - const allModels = await modelService.getAllModels(); + const [allModels, coverage] = await Promise.all([ + modelService.getAllModels(), + modelService.getCapacityCoverage(), + ]); // Update state with all models setModels(allModels); + setCapacityCoverage(coverage); // Load selected models from configuration and check if models still exist const llmMain = modelConfig.llm.displayName; @@ -475,7 +490,14 @@ export const ModelConfigSection = forwardRef< const hasStt = !!modelConfig.stt.modelName; hasSelectedModels = - hasLlmMain || hasEmbedding || hasReranker || hasVlm || hasVlm2 || hasVlm3 || hasTts || hasStt; + hasLlmMain || + hasEmbedding || + hasReranker || + hasVlm || + hasVlm2 || + hasVlm3 || + hasTts || + hasStt; if (hasSelectedModels) { currentSelectedModels.llm.main = modelConfig.llm.modelName; @@ -485,8 +507,10 @@ export const ModelConfigSection = forwardRef< modelConfig.multiEmbedding.modelName || ""; currentSelectedModels.reranker.reranker = modelConfig.rerank.modelName; currentSelectedModels.multimodal.vlm = modelConfig.vlm.modelName; - currentSelectedModels.multimodal.vlm2 = modelConfig.vlm2?.modelName || ""; - currentSelectedModels.multimodal.vlm3 = modelConfig.vlm3?.modelName || ""; + currentSelectedModels.multimodal.vlm2 = + modelConfig.vlm2?.modelName || ""; + currentSelectedModels.multimodal.vlm3 = + modelConfig.vlm3?.modelName || ""; currentSelectedModels.voice.tts = modelConfig.tts.modelName; currentSelectedModels.voice.stt = modelConfig.stt.modelName; } else { @@ -636,7 +660,10 @@ export const ModelConfigSection = forwardRef< throttleTimerRef.current = setTimeout(async () => { try { // Use modelService to verify model - const isConnected = await modelService.verifyCustomModel(displayName, modelType); + const isConnected = await modelService.verifyCustomModel( + displayName, + modelType + ); // Update model status 
updateModelStatus( @@ -954,6 +981,27 @@ export const ModelConfigSection = forwardRef< + {capacityCoverage && capacityCoverage.bareCount > 0 && ( + model.suggestionAvailable + ).length, + })} + action={ + + } + /> + )} +
diff --git a/frontend/components/common/tokenUsageIndicator.tsx b/frontend/components/common/tokenUsageIndicator.tsx index adde20fbf..b4a644ead 100644 --- a/frontend/components/common/tokenUsageIndicator.tsx +++ b/frontend/components/common/tokenUsageIndicator.tsx @@ -14,7 +14,10 @@ function formatNumber(n: number): string { } export function TokenUsageIndicator({ latestMetrics }: TokenUsageIndicatorProps) { - const DEFAULT_THRESHOLD = 32000; + // Matches backend _TOKEN_THRESHOLD_LEGACY_FALLBACK; shown only when the + // backend stream does not carry a real token_threshold (rare once W2 ships). + // Sized for the typical 32K-context band shared by most production LLMs. + const DEFAULT_THRESHOLD = 32768; const estimated_context_tokens = latestMetrics?.estimated_context_tokens ?? null; const token_threshold = latestMetrics?.token_threshold ?? null; diff --git a/frontend/hooks/agent/useSaveGuard.ts b/frontend/hooks/agent/useSaveGuard.ts index 2f644e0bc..5f748023f 100644 --- a/frontend/hooks/agent/useSaveGuard.ts +++ b/frontend/hooks/agent/useSaveGuard.ts @@ -134,6 +134,7 @@ export const useSaveGuard = () => { model_name: currentEditedAgent.model, model_id: currentEditedAgent.model_id ?? undefined, max_steps: currentEditedAgent.max_step, + requested_output_tokens: currentEditedAgent.requested_output_tokens ?? 
null, provide_run_summary: currentEditedAgent.provide_run_summary, verification_config: currentEditedAgent.verification_config, enabled: true, diff --git a/frontend/public/locales/en/common.json b/frontend/public/locales/en/common.json index 9487c5f33..e5c3e006e 100644 --- a/frontend/public/locales/en/common.json +++ b/frontend/public/locales/en/common.json @@ -344,6 +344,10 @@ "agent.author.hint": "Default: {{email}}", "agent.provideRunSummary": "Provide Run Summary", "agent.provideRunSummary.error": "Please select whether to provide run summary", + "agent.requestedOutputTokens": "Output Reserve", + "agent.requestedOutputTokens.error": "Output reserve must be a positive integer", + "agent.requestedOutputTokens.maxError": "Output reserve cannot exceed this model's max output tokens ({{max}})", + "agent.requestedOutputTokens.tooltip": "Maximum tokens the model can produce in one reply. The value is reserved from the model's context window for this response; the remainder is the input budget for the system prompt and conversation history. Larger value → longer replies but smaller input budget (context compression triggers earlier). Smaller value → more history preserved but replies may be truncated. Leave blank to use the model's default output reserve.", "agent.verification": "Self Verification", "agent.verification.error": "Please select whether to enable self verification", "agent.description": "Agent Description", @@ -830,6 +834,55 @@ "model.dialog.placeholder.maxTokens": "Enter maximum tokens", "model.dialog.settings.title": "Model Settings", "model.dialog.settings.label.maxTokens": "Max Tokens", + "model.dialog.capacity.title": "Optional Capacity Settings", + "model.dialog.capacity.description": "Override or confirm model capacity. 
Leaving this empty will not block adding the model.", + "model.dialog.capacity.emptySummary": "The provider did not return capacity candidates; you can leave this empty.", + "model.dialog.capacity.emptyHint": "The provider model list did not include capacity information for this model. You can add it now and fill these fields later if precise context control is needed.", + "model.dialog.capacity.contextWindowTokens": "Context Window", + "model.dialog.capacity.contextWindowTokens.tooltip": "Total combined input and output context window.", + "model.dialog.capacity.maxInputTokens": "Max Input Tokens", + "model.dialog.capacity.maxInputTokens.tooltip": "Hard input limit when it is distinct from the total context window.", + "model.dialog.capacity.maxOutputTokens": "Max Output Tokens", + "model.dialog.capacity.maxOutputTokens.tooltip": "Provider-supported completion output cap.", + "model.dialog.capacity.defaultOutputReserveTokens": "Output Reserve", + "model.dialog.capacity.defaultOutputReserveTokens.tooltip": "Default output allowance reserved before constructing request input.", + "model.dialog.capacity.error.positiveInteger": "Capacity numeric fields must be positive integers or empty.", + "model.dialog.capacity.error.outputExceedsWindow": "Max output tokens cannot exceed the context window.", + "model.dialog.capacity.error.inputExceedsWindow": "Max input tokens cannot exceed the context window (any excess is silently clipped, so please adjust the value directly).", + "model.dialog.capacity.error.reserveExceedsOutput": "Output reserve cannot exceed max output tokens.", + "model.dialog.capacity.error.requiredMissing": "Context window and max input tokens are required.", + "model.dialog.capacity.deprecatedMaxTokens": "max_tokens is deprecated; use max_output_tokens.", + "model.dialog.capacity.legacyMaxTokensDetected": "Detected legacy max_tokens = {{value}}. 
Apply it as max_output_tokens?", + "model.dialog.capacity.legacyMaxTokens.apply": "Apply", + "model.dialog.capacity.source.operator": "Operator", + "model.dialog.capacity.source.profile": "Profile", + "model.dialog.capacity.source.provider_candidate": "Provider Candidate", + "model.dialog.capacity.source.legacy": "Legacy", + "model.dialog.capacity.source.unknown": "Unknown", + "model.dialog.capacity.suggestion.title": "Capacity suggestion", + "model.dialog.capacity.suggestion.hint": "Check the approved catalog and apply the result only when you choose to use it.", + "model.dialog.capacity.suggestion.check": "Check", + "model.dialog.capacity.suggestion.use": "Use suggestion", + "model.dialog.capacity.suggestion.found": "Capacity suggestion found", + "model.dialog.capacity.suggestion.notFound": "No capacity suggestion found", + "model.dialog.capacity.suggestion.noExplanation": "No additional details.", + "model.dialog.capacity.suggestion.missingInput": "Enter a model name and URL before checking capacity suggestions.", + "model.dialog.capacity.suggestion.failed": "Failed to check capacity suggestions.", + "model.dialog.capacity.suggestion.match.catalog_exact": "Catalog exact", + "model.dialog.capacity.suggestion.match.catalog_fuzzy": "Catalog fuzzy", + "model.dialog.capacity.suggestion.match.provider_discovery": "Provider discovery", + "model.dialog.capacity.suggestion.match.none": "No match", + "model.dialog.capacity.suggestion.confidence.high": "High confidence", + "model.dialog.capacity.suggestion.confidence.medium": "Medium confidence", + "model.dialog.capacity.suggestion.confidence.low": "Low confidence", + "model.dialog.capacityCoverage.tag": "Missing capacity", + "model.dialog.capacityCoverage.warning": "This model is missing context window or max output tokens. Open edit settings to fill capacity.", + "model.dialog.capacityCoverage.warningWithSuggestion": "This model is missing capacity. 
A catalog suggestion may be available in the edit dialog.", + "model.dialog.capacity.batchDefault.title": "Batch default capacity", + "model.dialog.capacity.batchDefault.hint": "Values entered here apply as the default capacity for every LLM/VLM model in this batch import. Click the gear icon on a row to override a specific model.", + "model.dialog.batch.requireRowCapacity": "Some enabled rows are missing context window or max output tokens. Open the gear icon to fill them in before confirming.", + "model.dialog.capacity.bulkApply.title": "Bulk apply capacity (optional)", + "model.dialog.capacity.bulkApply.hint": "Values entered here are bulk-applied to every model of this type under the current provider as part of this Modify Config. Empty fields are skipped and keep each model's existing value. Tokenizer is intentionally omitted because it should not be uniform across models -- set it from the per-row gear icon instead.", "model.dialog.modelList.tooltip.settings": "Model Settings", "model.dialog.hint.multimodalEnabled": "Multimodal vector model can process both images and text", "model.dialog.hint.multimodalDisabled": "Text vector model only processes text", @@ -976,6 +1029,9 @@ "modelConfig.button.addCustomModel": "Add Model", "modelConfig.button.editCustomModel": "Edit or Delete Model", "modelConfig.button.checkConnectivity": "Check Model Connectivity", + "modelConfig.capacityCoverage.warning": "{{bareCount}} of {{total}} LLM/VLM models are missing capacity fields.", + "modelConfig.capacityCoverage.description": "{{suggestionCount}} model(s) may have catalog suggestions. 
Open Manage Models, then edit a marked model to repair it.", + "modelConfig.capacityCoverage.manage": "Manage", "modelConfig.button.sync": "Sync", "modelConfig.button.add": "Add", "modelConfig.button.edit": "Edit", diff --git a/frontend/public/locales/zh/common.json b/frontend/public/locales/zh/common.json index 4735f22c5..5ff929a67 100644 --- a/frontend/public/locales/zh/common.json +++ b/frontend/public/locales/zh/common.json @@ -346,6 +346,10 @@ "agent.author.hint": "默认:{{email}}", "agent.provideRunSummary": "提供运行摘要", "agent.provideRunSummary.error": "请选择是否提供运行摘要", + "agent.requestedOutputTokens": "输出预留", + "agent.requestedOutputTokens.error": "输出预留必须为正整数", + "agent.requestedOutputTokens.maxError": "输出预留不能超过该模型的最大输出 tokens({{max}})", + "agent.requestedOutputTokens.tooltip": "每次回复模型最多可输出的 token 数。该值从模型的上下文窗口中预留,作为本轮回答空间;剩余空间分配给输入(系统提示词 + 历史对话)。设大→回答更长但输入预算变小,更早触发上下文压缩;设小→历史保留更多但回答可能被截断。留空表示使用模型的默认输出预留值。", "agent.verification": "自验证", "agent.verification.error": "请选择是否启用自验证", "agent.description": "智能体描述", @@ -801,6 +805,55 @@ "model.dialog.placeholder.maxTokens": "请输入最大Token数", "model.dialog.settings.title": "模型设置", "model.dialog.settings.label.maxTokens": "最大Token数", + "model.dialog.capacity.title": "可选容量配置", + "model.dialog.capacity.description": "用于覆盖或确认模型容量;不填不会影响添加模型。", + "model.dialog.capacity.emptySummary": "供应商未返回容量候选值,可留空直接添加。", + "model.dialog.capacity.emptyHint": "当前供应商列表没有返回这个模型的容量信息。可以留空直接添加,后续需要精确上下文控制时再编辑补充。", + "model.dialog.capacity.contextWindowTokens": "上下文窗口", + "model.dialog.capacity.contextWindowTokens.tooltip": "输入和输出合计的上下文窗口上限。", + "model.dialog.capacity.maxInputTokens": "最大输入Token数", + "model.dialog.capacity.maxInputTokens.tooltip": "当输入上限不同于总窗口时填写。", + "model.dialog.capacity.maxOutputTokens": "最大输出Token数", + "model.dialog.capacity.maxOutputTokens.tooltip": "模型或供应商支持的输出上限。", + "model.dialog.capacity.defaultOutputReserveTokens": "输出预留Token数", + "model.dialog.capacity.defaultOutputReserveTokens.tooltip": "构造请求输入前默认预留的输出额度。", + 
"model.dialog.capacity.error.positiveInteger": "容量数字字段必须为空或正整数。", + "model.dialog.capacity.error.outputExceedsWindow": "最大输出Token数不能超过上下文窗口。", + "model.dialog.capacity.error.inputExceedsWindow": "最大输入Token数不能超过上下文窗口(超出部分会被自动忽略,请直接调整数值)。", + "model.dialog.capacity.error.reserveExceedsOutput": "输出预留Token数不能超过最大输出Token数。", + "model.dialog.capacity.error.requiredMissing": "上下文窗口和最大输入Token数为必填项。", + "model.dialog.capacity.deprecatedMaxTokens": "max_tokens 已废弃,请使用 max_output_tokens。", + "model.dialog.capacity.legacyMaxTokensDetected": "检测到旧的「最大Tokens数」为 {{value}},是否填入最大输出Token数?", + "model.dialog.capacity.legacyMaxTokens.apply": "应用", + "model.dialog.capacity.source.operator": "人工配置", + "model.dialog.capacity.source.profile": "能力档案", + "model.dialog.capacity.source.provider_candidate": "供应商候选", + "model.dialog.capacity.source.legacy": "旧字段", + "model.dialog.capacity.source.unknown": "未知", + "model.dialog.capacity.suggestion.title": "容量建议", + "model.dialog.capacity.suggestion.hint": "从已审核目录检查容量;只有点击使用后才会写入表单。", + "model.dialog.capacity.suggestion.check": "检查", + "model.dialog.capacity.suggestion.use": "使用建议", + "model.dialog.capacity.suggestion.found": "已找到容量建议", + "model.dialog.capacity.suggestion.notFound": "未找到容量建议", + "model.dialog.capacity.suggestion.noExplanation": "暂无更多说明。", + "model.dialog.capacity.suggestion.missingInput": "请先填写模型名称和 URL,再检查容量建议。", + "model.dialog.capacity.suggestion.failed": "检查容量建议失败。", + "model.dialog.capacity.suggestion.match.catalog_exact": "目录精确匹配", + "model.dialog.capacity.suggestion.match.catalog_fuzzy": "目录模糊匹配", + "model.dialog.capacity.suggestion.match.provider_discovery": "供应商发现", + "model.dialog.capacity.suggestion.match.none": "未匹配", + "model.dialog.capacity.suggestion.confidence.high": "高置信度", + "model.dialog.capacity.suggestion.confidence.medium": "中置信度", + "model.dialog.capacity.suggestion.confidence.low": "低置信度", + "model.dialog.capacityCoverage.tag": "缺容量", + "model.dialog.capacityCoverage.warning": 
"此模型缺少上下文窗口或最大输出Token数。请打开编辑配置补全容量。", + "model.dialog.capacityCoverage.warningWithSuggestion": "此模型缺少容量。编辑弹窗中可能有目录建议可用。", + "model.dialog.capacity.batchDefault.title": "批量默认容量", + "model.dialog.capacity.batchDefault.hint": "此处填写的数值将作为本次批量导入所有 LLM/VLM 模型的默认容量。如需为某个模型单独设置,请点击对应行的⚙图标覆盖。", + "model.dialog.batch.requireRowCapacity": "存在已打开开关的模型缺少上下文窗口或最大输出Token数,请点击对应行的⚙图标补全后再确认。", + "model.dialog.capacity.bulkApply.title": "批量应用容量(可选)", + "model.dialog.capacity.bulkApply.hint": "此处填写的数值将作为本次「修改配置」的批量默认值,应用到当前 provider 下所有该类型模型。留空的字段不会覆盖已有的逐行配置。Tokenizer 因不宜全局统一,需通过单行⚙图标设置。", "model.dialog.modelList.tooltip.settings": "模型设置", "model.dialog.hint.multimodalEnabled": "多模态向量模型可处理图像和文本", "model.dialog.hint.multimodalDisabled": "文本向量模型仅处理文本", @@ -947,6 +1000,9 @@ "modelConfig.button.addCustomModel": "添加模型", "modelConfig.button.editCustomModel": "修改或删除模型", "modelConfig.button.checkConnectivity": "检查模型连通性", + "modelConfig.capacityCoverage.warning": "{{total}} 个 LLM/VLM 模型中有 {{bareCount}} 个缺少容量字段。", + "modelConfig.capacityCoverage.description": "其中 {{suggestionCount}} 个可能有目录建议。打开修改或删除模型,编辑带标记的模型即可修复。", + "modelConfig.capacityCoverage.manage": "管理", "modelConfig.button.sync": "同步", "modelConfig.button.add": "添加", "modelConfig.button.edit": "修改", diff --git a/frontend/services/agentConfigService.ts b/frontend/services/agentConfigService.ts index a955aa410..f1078726b 100644 --- a/frontend/services/agentConfigService.ts +++ b/frontend/services/agentConfigService.ts @@ -248,6 +248,7 @@ export const getCreatingSubAgentId = async () => { modelName: data.model_name, model_id: data.model_id, maxSteps: data.max_steps, + requestedOutputTokens: data.requested_output_tokens ?? 
null, businessDescription: data.business_description, dutyPrompt: data.duty_prompt, constraintPrompt: data.constraint_prompt, @@ -407,6 +408,7 @@ export interface UpdateAgentInfoPayload { model_name?: string; model_id?: number; max_steps?: number; + requested_output_tokens?: number | null; provide_run_summary?: boolean; enable_context_manager?: boolean; verification_config?: Record; @@ -765,6 +767,7 @@ export const searchAgentInfo = async ( model: data.model_name, model_id: data.model_id, max_step: data.max_steps, + requested_output_tokens: data.requested_output_tokens ?? null, duty_prompt: data.duty_prompt, constraint_prompt: data.constraint_prompt, few_shots_prompt: data.few_shots_prompt, diff --git a/frontend/services/api.ts b/frontend/services/api.ts index e5b4ed025..5779d6ee5 100644 --- a/frontend/services/api.ts +++ b/frontend/services/api.ts @@ -28,7 +28,8 @@ export const API_ENDPOINTS = { pending: `${API_BASE_URL}/user/oauth/pending`, complete: `${API_BASE_URL}/user/oauth/complete`, accounts: `${API_BASE_URL}/user/oauth/accounts`, - unlink: (provider: string) => `${API_BASE_URL}/user/oauth/accounts/${provider}`, + unlink: (provider: string) => + `${API_BASE_URL}/user/oauth/accounts/${provider}`, }, cas: { config: `${API_BASE_URL}/user/cas/config`, @@ -63,18 +64,27 @@ export const API_ENDPOINTS = { regenerateNameBatch: `${API_BASE_URL}/agent/regenerate_name`, searchInfo: `${API_BASE_URL}/agent/search_info`, callRelationship: `${API_BASE_URL}/agent/call_relationship`, - byName: (agentName: string) => `${API_BASE_URL}/agent/by-name/${encodeURIComponent(agentName)}`, - clearNew: (agentId: string | number) => `${API_BASE_URL}/agent/clear_new/${agentId}`, + byName: (agentName: string) => + `${API_BASE_URL}/agent/by-name/${encodeURIComponent(agentName)}`, + clearNew: (agentId: string | number) => + `${API_BASE_URL}/agent/clear_new/${agentId}`, publish: (agentId: number) => `${API_BASE_URL}/agent/${agentId}/publish`, versions: { - version: (agentId: number, 
versionNo: number) => `${API_BASE_URL}/agent/${agentId}/versions/${versionNo}`, - detail: (agentId: number, versionNo: number) => `${API_BASE_URL}/agent/${agentId}/versions/${versionNo}/detail`, + version: (agentId: number, versionNo: number) => + `${API_BASE_URL}/agent/${agentId}/versions/${versionNo}`, + detail: (agentId: number, versionNo: number) => + `${API_BASE_URL}/agent/${agentId}/versions/${versionNo}/detail`, list: (agentId: number) => `${API_BASE_URL}/agent/${agentId}/versions`, - current: (agentId: number) => `${API_BASE_URL}/agent/${agentId}/current_version`, - rollback: (agentId: number, versionNo: number) => `${API_BASE_URL}/agent/${agentId}/versions/${versionNo}/rollback`, - compare: (agentId: number) => `${API_BASE_URL}/agent/${agentId}/versions/compare`, - delete: (agentId: number, versionNo: number) => `${API_BASE_URL}/agent/${agentId}/versions/${versionNo}`, - update: (agentId: number, versionNo: number) => `${API_BASE_URL}/agent/${agentId}/versions/${versionNo}`, + current: (agentId: number) => + `${API_BASE_URL}/agent/${agentId}/current_version`, + rollback: (agentId: number, versionNo: number) => + `${API_BASE_URL}/agent/${agentId}/versions/${versionNo}/rollback`, + compare: (agentId: number) => + `${API_BASE_URL}/agent/${agentId}/versions/compare`, + delete: (agentId: number, versionNo: number) => + `${API_BASE_URL}/agent/${agentId}/versions/${versionNo}`, + update: (agentId: number, versionNo: number) => + `${API_BASE_URL}/agent/${agentId}/versions/${versionNo}`, }, }, tool: { @@ -97,10 +107,13 @@ export const API_ENDPOINTS = { }, promptTemplates: { list: `${API_BASE_URL}/prompt_templates`, - detail: (templateId: number) => `${API_BASE_URL}/prompt_templates/${templateId}`, + detail: (templateId: number) => + `${API_BASE_URL}/prompt_templates/${templateId}`, create: `${API_BASE_URL}/prompt_templates`, - update: (templateId: number) => `${API_BASE_URL}/prompt_templates/${templateId}`, - delete: (templateId: number) => 
`${API_BASE_URL}/prompt_templates/${templateId}`, + update: (templateId: number) => + `${API_BASE_URL}/prompt_templates/${templateId}`, + delete: (templateId: number) => + `${API_BASE_URL}/prompt_templates/${templateId}`, }, stt: { ws: `/api/voice/stt/ws`, @@ -170,6 +183,8 @@ export const API_ENDPOINTS = { displayName )}&model_type=${encodeURIComponent(modelType)}`, verifyModelConfig: `${API_BASE_URL}/model/temporary_healthcheck`, + suggestCapacity: `${API_BASE_URL}/model/suggest-capacity`, + capacityCoverage: `${API_BASE_URL}/model/capacity-coverage`, updateSingleModel: (displayName: string) => `${API_BASE_URL}/model/update?display_name=${encodeURIComponent(displayName)}`, updateBatchModel: `${API_BASE_URL}/model/batch_update`, @@ -284,25 +299,35 @@ export const API_ENDPOINTS = { // External agent management agents: `${API_BASE_URL}/a2a/client/agents`, agent: (agentId: string) => `${API_BASE_URL}/a2a/client/agents/${agentId}`, - agentRefresh: (agentId: string) => `${API_BASE_URL}/a2a/client/agents/${agentId}/refresh`, - agentProtocol: (agentId: string) => `${API_BASE_URL}/a2a/client/agents/${agentId}/protocol`, + agentRefresh: (agentId: string) => + `${API_BASE_URL}/a2a/client/agents/${agentId}/refresh`, + agentProtocol: (agentId: string) => + `${API_BASE_URL}/a2a/client/agents/${agentId}/protocol`, // External agent relations relations: `${API_BASE_URL}/a2a/client/relations`, relation: (localAgentId: number, externalAgentId: number) => `${API_BASE_URL}/a2a/client/relations?local_agent_id=${localAgentId}&external_agent_id=${externalAgentId}`, - subAgents: (localAgentId: number) => `${API_BASE_URL}/a2a/client/sub-agents/${localAgentId}`, - externalRelations: (localAgentId: number) => `${API_BASE_URL}/a2a/client/relations/${localAgentId}`, + subAgents: (localAgentId: number) => + `${API_BASE_URL}/a2a/client/sub-agents/${localAgentId}`, + externalRelations: (localAgentId: number) => + `${API_BASE_URL}/a2a/client/relations/${localAgentId}`, // Nacos config management 
nacosConfigs: `${API_BASE_URL}/a2a/client/nacos-configs`, - nacosConfig: (configId: string) => `${API_BASE_URL}/a2a/client/nacos-configs/${configId}`, + nacosConfig: (configId: string) => + `${API_BASE_URL}/a2a/client/nacos-configs/${configId}`, nacosTestConnection: `${API_BASE_URL}/a2a/client/nacos-configs/test-connection`, // A2A Server management serverAgents: `${API_BASE_URL}/a2a/management/agents`, - serverAgent: (agentId: number) => `${API_BASE_URL}/a2a/management/agents/${agentId}`, - serverAgentEnable: (agentId: number) => `${API_BASE_URL}/a2a/management/agents/${agentId}/enable`, - serverAgentDisable: (agentId: number) => `${API_BASE_URL}/a2a/management/agents/${agentId}/disable`, - serverAgentSettings: (agentId: number) => `${API_BASE_URL}/a2a/management/agents/${agentId}/settings`, - agentChat: (agentId: string) => `${API_BASE_URL}/a2a/client/agents/${agentId}/chat`, + serverAgent: (agentId: number) => + `${API_BASE_URL}/a2a/management/agents/${agentId}`, + serverAgentEnable: (agentId: number) => + `${API_BASE_URL}/a2a/management/agents/${agentId}/enable`, + serverAgentDisable: (agentId: number) => + `${API_BASE_URL}/a2a/management/agents/${agentId}/disable`, + serverAgentSettings: (agentId: number) => + `${API_BASE_URL}/a2a/management/agents/${agentId}/settings`, + agentChat: (agentId: string) => + `${API_BASE_URL}/a2a/client/agents/${agentId}/chat`, }, skills: { list: `${API_BASE_URL}/skills`, @@ -310,9 +335,11 @@ export const API_ENDPOINTS = { upload: `${API_BASE_URL}/skills/upload`, get: (skillName: string) => `${API_BASE_URL}/skills/${skillName}`, update: (skillName: string) => `${API_BASE_URL}/skills/${skillName}`, - updateUpload: (skillName: string) => `${API_BASE_URL}/skills/${skillName}/upload`, + updateUpload: (skillName: string) => + `${API_BASE_URL}/skills/${skillName}/upload`, delete: (skillName: string) => `${API_BASE_URL}/skills/${skillName}`, - deleteFile: (skillName: string, filePath: string) => 
`${API_BASE_URL}/skills/${skillName}/files/${filePath}`, + deleteFile: (skillName: string, filePath: string) => + `${API_BASE_URL}/skills/${skillName}/files/${filePath}`, files: (skillName: string) => `${API_BASE_URL}/skills/${skillName}/files`, fileContent: (skillName: string, filePath: string) => `${API_BASE_URL}/skills/${skillName}/files/${filePath}`, @@ -540,7 +567,6 @@ export const fetchWithErrorHandling = async ( } }; - // Add global interface extensions for TypeScript declare global { interface Window { diff --git a/frontend/services/modelService.ts b/frontend/services/modelService.ts index 6f82fc2de..d054a9274 100644 --- a/frontend/services/modelService.ts +++ b/frontend/services/modelService.ts @@ -8,6 +8,8 @@ import { ModelConnectStatus, ModelValidationResponse, ModelSource, + CapacitySuggestion, + CapacityCoverage, } from "@/types/modelConfig"; import { getAuthHeaders } from "@/lib/auth"; @@ -24,9 +26,88 @@ import { } from "@/const/modelConfig"; import log from "@/lib/logger"; +const mapCapacityFieldsFromApi = (model: any) => ({ + contextWindowTokens: model.context_window_tokens, + maxInputTokens: model.max_input_tokens, + maxOutputTokens: model.max_output_tokens, + defaultOutputReserveTokens: model.default_output_reserve_tokens, + tokenizerFamily: model.tokenizer_family, + capacitySource: model.capacity_source, + capabilityProfileVersion: model.capability_profile_version, +}); + +const buildCapacityRequestBody = (model: { + contextWindowTokens?: number; + maxInputTokens?: number; + maxOutputTokens?: number; + defaultOutputReserveTokens?: number; + tokenizerFamily?: string; + capacitySource?: string; +}) => ({ + ...(model.contextWindowTokens !== undefined + ? { context_window_tokens: model.contextWindowTokens } + : {}), + ...(model.maxInputTokens !== undefined + ? { max_input_tokens: model.maxInputTokens } + : {}), + ...(model.maxOutputTokens !== undefined + ? 
{ max_output_tokens: model.maxOutputTokens } + : {}), + ...(model.defaultOutputReserveTokens !== undefined + ? { default_output_reserve_tokens: model.defaultOutputReserveTokens } + : {}), + ...(model.tokenizerFamily !== undefined + ? { tokenizer_family: model.tokenizerFamily } + : {}), + ...(model.capacitySource !== undefined + ? { capacity_source: model.capacitySource } + : {}), +}); + +const mapCapacitySuggestionFromApi = ( + suggestion: any +): CapacitySuggestion | null => { + if (!suggestion) return null; + return { + suggestions: suggestion.suggestions + ? { + contextWindowTokens: suggestion.suggestions.context_window_tokens, + maxInputTokens: suggestion.suggestions.max_input_tokens, + maxOutputTokens: suggestion.suggestions.max_output_tokens, + defaultOutputReserveTokens: + suggestion.suggestions.default_output_reserve_tokens, + tokenizerFamily: suggestion.suggestions.tokenizer_family, + } + : null, + matchKind: suggestion.match_kind, + matchConfidence: suggestion.match_confidence, + matchExplanation: suggestion.match_explanation || "", + suggestedProvider: suggestion.suggested_provider, + canonicalModelName: suggestion.canonical_model_name, + capabilityProfileVersion: suggestion.capability_profile_version, + capacitySourceOnAccept: suggestion.capacity_source_on_accept, + }; +}; + +const mapCapacityCoverageFromApi = (coverage: any): CapacityCoverage => ({ + totalLlmVlm: coverage?.total_llm_vlm || 0, + bareCount: coverage?.bare_count || 0, + bareModels: (coverage?.bare_models || []).map((model: any) => ({ + modelId: model.model_id, + modelName: model.model_name, + modelFactory: model.model_factory, + modelType: model.model_type, + maxTokens: model.max_tokens, + suggestionAvailable: Boolean(model.suggestion_available), + })), +}); + // Error class export class ModelError extends Error { - constructor(message: string, public code?: number) { + constructor( + message: string, + public code?: number + ) { super(message); this.name = "ModelError"; // Override the 
stack property to only return the message @@ -68,6 +149,7 @@ export const modelService = { expectedChunkSize: model.expected_chunk_size, maximumChunkSize: model.maximum_chunk_size, chunkingBatchSize: model.chunk_batch, + ...mapCapacityFieldsFromApi(model), // STT specific fields modelAppid: model.model_appid, accessToken: model.access_token, @@ -110,6 +192,12 @@ export const modelService = { accessToken?: string; timeoutSeconds?: number; concurrencyLimit?: number; + contextWindowTokens?: number; + maxInputTokens?: number; + maxOutputTokens?: number; + defaultOutputReserveTokens?: number; + tokenizerFamily?: string; + capacitySource?: string; }): Promise => { try { const requestBody: any = { @@ -125,6 +213,7 @@ export const modelService = { chunk_batch: model.chunkingBatchSize, timeout_seconds: model.timeoutSeconds, concurrency_limit: model.concurrencyLimit, + ...buildCapacityRequestBody(model), }; // Add STT specific fields @@ -294,7 +383,9 @@ export const modelService = { log.log("getManageProviderModelList result", result); if (response.status !== 200) { throw new ModelError( - result.detail || result.message || "Failed to get provider model list", + result.detail || + result.message || + "Failed to get provider model list", response.status ); } @@ -308,6 +399,7 @@ export const modelService = { updateSingleModel: async (model: { currentDisplayName: string; + name?: string; displayName?: string; url: string; apiKey: string; @@ -322,6 +414,12 @@ export const modelService = { accessToken?: string; timeoutSeconds?: number; concurrencyLimit?: number; + contextWindowTokens?: number; + maxInputTokens?: number; + maxOutputTokens?: number; + defaultOutputReserveTokens?: number; + tokenizerFamily?: string; + capacitySource?: string; }): Promise => { try { const response = await fetch( @@ -333,6 +431,7 @@ export const modelService = { ...(model.displayName !== undefined ? { display_name: model.displayName } : {}), + ...(model.name !== undefined ? 
{ model_name: model.name } : {}), base_url: model.url, api_key: model.apiKey, ...(model.maxTokens !== undefined @@ -362,14 +461,17 @@ export const modelService = { : {}), ...(model.concurrencyLimit !== undefined ? { concurrency_limit: model.concurrencyLimit } - : {}) + : {}), + ...buildCapacityRequestBody(model), }), } ); const result = await response.json(); if (response.status !== 200) { throw new ModelError( - result.detail || result.message || "Failed to update the custom model", + result.detail || + result.message || + "Failed to update the custom model", response.status ); } @@ -386,6 +488,12 @@ export const modelService = { maxTokens?: number; timeoutSeconds?: number; concurrencyLimit?: number; + contextWindowTokens?: number; + maxInputTokens?: number; + maxOutputTokens?: number; + defaultOutputReserveTokens?: number; + tokenizerFamily?: string; + capacitySource?: string; }[], provider?: string ): Promise => { @@ -398,8 +506,30 @@ export const modelService = { model_id: m.model_id, api_key: m.apiKey, ...(m.maxTokens !== undefined ? { max_tokens: m.maxTokens } : {}), - ...(m.timeoutSeconds !== undefined ? { timeout_seconds: m.timeoutSeconds } : {}), - ...(m.concurrencyLimit !== undefined ? { concurrency_limit: m.concurrencyLimit } : {}), + ...(m.timeoutSeconds !== undefined + ? { timeout_seconds: m.timeoutSeconds } + : {}), + ...(m.concurrencyLimit !== undefined + ? { concurrency_limit: m.concurrencyLimit } + : {}), + ...(m.contextWindowTokens !== undefined + ? { context_window_tokens: m.contextWindowTokens } + : {}), + ...(m.maxInputTokens !== undefined + ? { max_input_tokens: m.maxInputTokens } + : {}), + ...(m.maxOutputTokens !== undefined + ? { max_output_tokens: m.maxOutputTokens } + : {}), + ...(m.defaultOutputReserveTokens !== undefined + ? { default_output_reserve_tokens: m.defaultOutputReserveTokens } + : {}), + ...(m.tokenizerFamily !== undefined + ? { tokenizer_family: m.tokenizerFamily } + : {}), + ...(m.capacitySource !== undefined + ? 
{ capacity_source: m.capacitySource } + : {}), ...(provider ? { model_factory: provider } : {}), })) ), @@ -407,7 +537,9 @@ export const modelService = { const result = await response.json(); if (response.status !== 200) { throw new ModelError( - result.detail || result.message || "Failed to update the custom model", + result.detail || + result.message || + "Failed to update the custom model", response.status ); } @@ -494,7 +626,7 @@ export const modelService = { body: JSON.stringify({ tenant_id: tenantId, display_name: displayName, - model_type: modelType + model_type: modelType, }), signal, }); @@ -535,7 +667,9 @@ export const modelService = { model_type: config.modelType, api_key: config.apiKey || "sk-no-api-key", base_url: config.baseUrl || "", - ...(config.maxTokens !== undefined ? { max_tokens: config.maxTokens } : {}), + ...(config.maxTokens !== undefined + ? { max_tokens: config.maxTokens } + : {}), embedding_dim: config.embeddingDim || 1024, }; @@ -563,14 +697,21 @@ export const modelService = { return { connectivity: result.data.connectivity, model_name: result.data.model_name || "UNKNOWN_MODEL", - error: result.data.connectivity ? undefined : result.data.error || result.detail || result.message, + error: result.data.connectivity + ? undefined + : result.data.error || result.detail || result.message, + capacitySuggestion: mapCapacitySuggestionFromApi( + result.data.capacity_suggestion + ), }; } return { connectivity: false, model_name: result.data?.model_name || "UNKNOWN_MODEL", - error: result.detail || result.message || "Connection verification failed", + error: + result.detail || result.message || "Connection verification failed", + capacitySuggestion: null, }; } catch (error) { if (error instanceof Error && error.name === "AbortError") { @@ -582,10 +723,71 @@ export const modelService = { connectivity: false, model_name: "UNKNOWN_MODEL", error: error instanceof Error ? 
error.message : String(error), + capacitySuggestion: null, }; } }, + suggestCapacity: async (params: { + modelName: string; + baseUrl?: string; + providerHint?: string; + apiKey?: string; + modelType?: ModelType; + }): Promise => { + try { + const response = await fetch(API_ENDPOINTS.model.suggestCapacity, { + method: "POST", + headers: getAuthHeaders(), + body: JSON.stringify({ + model_name: params.modelName, + ...(params.baseUrl ? { base_url: params.baseUrl } : {}), + ...(params.providerHint + ? { provider_hint: params.providerHint } + : {}), + ...(params.apiKey ? { api_key: params.apiKey } : {}), + ...(params.modelType ? { model_type: params.modelType } : {}), + }), + }); + + const result = await response.json(); + if (response.status !== STATUS_CODES.SUCCESS || !result.data) { + throw new ModelError( + result.detail || result.message || "Failed to suggest model capacity", + response.status + ); + } + const mapped = mapCapacitySuggestionFromApi(result.data); + if (!mapped) { + throw new ModelError( + "Failed to suggest model capacity", + response.status + ); + } + return mapped; + } catch (error) { + if (error instanceof ModelError) throw error; + log.warn("Failed to suggest model capacity:", error); + throw new ModelError("Failed to suggest model capacity", 500); + } + }, + + getCapacityCoverage: async (): Promise => { + try { + const response = await fetch(API_ENDPOINTS.model.capacityCoverage, { + headers: getAuthHeaders(), + }); + const result = await response.json(); + if (response.status !== STATUS_CODES.SUCCESS || !result.data) { + return { totalLlmVlm: 0, bareCount: 0, bareModels: [] }; + } + return mapCapacityCoverageFromApi(result.data); + } catch (error) { + log.warn("Failed to load model capacity coverage:", error); + return { totalLlmVlm: 0, bareCount: 0, bareModels: [] }; + } + }, + // Get LLM model list for generation getLLMModels: async (): Promise => { try { @@ -661,6 +863,7 @@ export const modelService = { expectedChunkSize: 
model.expected_chunk_size, maximumChunkSize: model.maximum_chunk_size, chunkingBatchSize: model.chunk_batch, + ...mapCapacityFieldsFromApi(model), // STT specific fields modelAppid: model.model_appid, accessToken: model.access_token, @@ -714,6 +917,12 @@ export const modelService = { accessToken?: string; timeoutSeconds?: number; concurrencyLimit?: number; + contextWindowTokens?: number; + maxInputTokens?: number; + maxOutputTokens?: number; + defaultOutputReserveTokens?: number; + tokenizerFamily?: string; + capacitySource?: string; }): Promise => { try { const requestBody: any = { @@ -723,7 +932,9 @@ export const modelService = { model_type: params.type, base_url: params.url, api_key: params.apiKey, - ...(params.maxTokens !== undefined ? { max_tokens: params.maxTokens } : {}), + ...(params.maxTokens !== undefined + ? { max_tokens: params.maxTokens } + : {}), display_name: params.displayName || params.name, model_factory: params.modelFactory || "OpenAI-API-Compatible", expected_chunk_size: params.expectedChunkSize, @@ -731,6 +942,7 @@ export const modelService = { chunk_batch: params.chunkingBatchSize, timeout_seconds: params.timeoutSeconds, concurrency_limit: params.concurrencyLimit, + ...buildCapacityRequestBody(params), }; // Add STT specific fields @@ -756,7 +968,9 @@ export const modelService = { const result = await response.json(); if (response.status !== STATUS_CODES.SUCCESS) { throw new ModelError( - result.detail || result.message || "Failed to create model for tenant", + result.detail || + result.message || + "Failed to create model for tenant", response.status ); } @@ -771,6 +985,7 @@ export const modelService = { updateManageTenantModel: async (params: { tenantId: string; currentDisplayName: string; + name?: string; displayName?: string; url: string; apiKey: string; @@ -784,6 +999,12 @@ export const modelService = { accessToken?: string; timeoutSeconds?: number; concurrencyLimit?: number; + contextWindowTokens?: number; + maxInputTokens?: number; + 
maxOutputTokens?: number; + defaultOutputReserveTokens?: number; + tokenizerFamily?: string; + capacitySource?: string; }): Promise => { try { const response = await fetch( @@ -797,18 +1018,40 @@ export const modelService = { body: JSON.stringify({ tenant_id: params.tenantId, current_display_name: params.currentDisplayName, - ...(params.displayName !== undefined ? { display_name: params.displayName } : {}), + ...(params.name !== undefined ? { model_name: params.name } : {}), + ...(params.displayName !== undefined + ? { display_name: params.displayName } + : {}), base_url: params.url, api_key: params.apiKey, - ...(params.maxTokens !== undefined ? { max_tokens: params.maxTokens } : {}), - ...(params.expectedChunkSize !== undefined ? { expected_chunk_size: params.expectedChunkSize } : {}), - ...(params.maximumChunkSize !== undefined ? { maximum_chunk_size: params.maximumChunkSize } : {}), - ...(params.chunkingBatchSize !== undefined ? { chunk_batch: params.chunkingBatchSize } : {}), - ...(params.modelFactory !== undefined ? { model_factory: params.modelFactory } : {}), - ...(params.modelAppid !== undefined ? { model_appid: params.modelAppid } : {}), - ...(params.accessToken !== undefined ? { access_token: params.accessToken } : {}), - ...(params.timeoutSeconds !== undefined ? { timeout_seconds: params.timeoutSeconds } : {}), - ...(params.concurrencyLimit !== undefined ? { concurrency_limit: params.concurrencyLimit } : {}), + ...(params.maxTokens !== undefined + ? { max_tokens: params.maxTokens } + : {}), + ...(params.expectedChunkSize !== undefined + ? { expected_chunk_size: params.expectedChunkSize } + : {}), + ...(params.maximumChunkSize !== undefined + ? { maximum_chunk_size: params.maximumChunkSize } + : {}), + ...(params.chunkingBatchSize !== undefined + ? { chunk_batch: params.chunkingBatchSize } + : {}), + ...(params.modelFactory !== undefined + ? { model_factory: params.modelFactory } + : {}), + ...(params.modelAppid !== undefined + ? 
{ model_appid: params.modelAppid } + : {}), + ...(params.accessToken !== undefined + ? { access_token: params.accessToken } + : {}), + ...(params.timeoutSeconds !== undefined + ? { timeout_seconds: params.timeoutSeconds } + : {}), + ...(params.concurrencyLimit !== undefined + ? { concurrency_limit: params.concurrencyLimit } + : {}), + ...buildCapacityRequestBody(params), }), } ); @@ -816,7 +1059,9 @@ export const modelService = { const result = await response.json(); if (response.status !== STATUS_CODES.SUCCESS) { throw new ModelError( - result.detail || result.message || "Failed to update model for tenant", + result.detail || + result.message || + "Failed to update model for tenant", response.status ); } @@ -851,7 +1096,9 @@ export const modelService = { const result = await response.json(); if (response.status !== STATUS_CODES.SUCCESS) { throw new ModelError( - result.detail || result.message || "Failed to delete model for tenant", + result.detail || + result.message || + "Failed to delete model for tenant", response.status ); } @@ -875,7 +1122,12 @@ export const modelService = { owned_by?: string; max_tokens?: number; }>; - }): Promise<{ tenantId: string; provider: string; type: string; modelsCount: number }> => { + }): Promise<{ + tenantId: string; + provider: string; + type: string; + modelsCount: number; + }> => { try { const response = await fetch(API_ENDPOINTS.model.manageModelBatchCreate, { method: "POST", @@ -895,7 +1147,9 @@ export const modelService = { const result = await response.json(); if (response.status !== STATUS_CODES.SUCCESS) { throw new ModelError( - result.detail || result.message || "Failed to batch create models for tenant", + result.detail || + result.message || + "Failed to batch create models for tenant", response.status ); } @@ -921,24 +1175,32 @@ export const modelService = { baseUrl?: string; }): Promise => { try { - const response = await fetch(API_ENDPOINTS.model.manageProviderModelCreate, { - method: "POST", - headers: { - 
...getAuthHeaders(), - "Content-Type": "application/json", - }, - body: JSON.stringify({ - tenant_id: params.tenantId, - provider: params.provider, - model_type: params.type, - api_key: params.apiKey, - ...(params.baseUrl ? { base_url: params.baseUrl } : {}), - }), - }); + const response = await fetch( + API_ENDPOINTS.model.manageProviderModelCreate, + { + method: "POST", + headers: { + ...getAuthHeaders(), + "Content-Type": "application/json", + }, + body: JSON.stringify({ + tenant_id: params.tenantId, + provider: params.provider, + model_type: params.type, + api_key: params.apiKey, + ...(params.baseUrl ? { base_url: params.baseUrl } : {}), + }), + } + ); const result = await response.json(); if (response.status !== STATUS_CODES.SUCCESS) { - throw new ModelError(result.detail || result.message || "Failed to create provider models for tenant", response.status); + throw new ModelError( + result.detail || + result.message || + "Failed to create provider models for tenant", + response.status + ); } return result.data || []; } catch (error) { @@ -955,28 +1217,39 @@ export const modelService = { type: ModelType; }): Promise => { try { - const response = await fetch(API_ENDPOINTS.model.manageProviderModelList, { - method: "POST", - headers: { - ...getAuthHeaders(), - "Content-Type": "application/json", - }, - body: JSON.stringify({ - tenant_id: params.tenantId, - provider: params.provider, - model_type: params.type, - }), - }); + const response = await fetch( + API_ENDPOINTS.model.manageProviderModelList, + { + method: "POST", + headers: { + ...getAuthHeaders(), + "Content-Type": "application/json", + }, + body: JSON.stringify({ + tenant_id: params.tenantId, + provider: params.provider, + model_type: params.type, + }), + } + ); const result = await response.json(); if (response.status !== STATUS_CODES.SUCCESS) { - throw new ModelError(result.detail || result.message || "Failed to get provider selected list for tenant", response.status); + throw new ModelError( + 
result.detail || + result.message || + "Failed to get provider selected list for tenant", + response.status + ); } return result.data || []; } catch (error) { if (error instanceof ModelError) throw error; log.warn("Failed to get manage provider selected list:", error); - throw new ModelError("Failed to get provider selected list for tenant", 500); + throw new ModelError( + "Failed to get provider selected list for tenant", + 500 + ); } }, }; diff --git a/frontend/stores/agentConfigStore.ts b/frontend/stores/agentConfigStore.ts index e1a1b9545..e82832650 100644 --- a/frontend/stores/agentConfigStore.ts +++ b/frontend/stores/agentConfigStore.ts @@ -34,6 +34,7 @@ export type EditableAgent = Pick< | "model" | "model_id" | "max_step" + | "requested_output_tokens" | "provide_run_summary" | "tools" | "duty_prompt" @@ -166,6 +167,7 @@ function createEmptyEditableAgent(llmConfig?: { id: number | null; name: string; model: llmConfig?.name || "", model_id: llmConfig?.id || 0, max_step: 15, + requested_output_tokens: null, provide_run_summary: false, tools: [], skills: [], @@ -198,6 +200,7 @@ const toEditable = (agent: Agent | null): EditableAgent => model: agent.model, model_id: agent.model_id || 0, max_step: agent.max_step, + requested_output_tokens: agent.requested_output_tokens ?? null, provide_run_summary: agent.provide_run_summary, tools: [...(agent.tools || [])], skills: [...(agent.skills || [])], @@ -318,6 +321,7 @@ const isDirty = ( editedAgent.model !== "" || editedAgent.model_id !== 0 || editedAgent.max_step !== 0 || + editedAgent.requested_output_tokens != null || editedAgent.provide_run_summary !== false || editedAgent.duty_prompt !== "" || editedAgent.constraint_prompt !== "" || @@ -348,6 +352,8 @@ const isDirty = ( baselineAgent.model !== editedAgent.model || baselineAgent.model_id !== editedAgent.model_id || baselineAgent.max_step !== editedAgent.max_step || + (baselineAgent.requested_output_tokens ?? null) !== + (editedAgent.requested_output_tokens ?? 
null) || baselineAgent.provide_run_summary !== editedAgent.provide_run_summary || baselineAgent.duty_prompt !== editedAgent.duty_prompt || baselineAgent.constraint_prompt !== editedAgent.constraint_prompt || diff --git a/frontend/types/agentConfig.ts b/frontend/types/agentConfig.ts index 6b825b28c..9bbf4806d 100644 --- a/frontend/types/agentConfig.ts +++ b/frontend/types/agentConfig.ts @@ -14,6 +14,7 @@ export type AgentConfigUpdate = Partial List[ActionStep]: return [prev_action, last_action] return [last_action] - # ============================================================ - # Mainly Entry Point - # ============================================================ - - def compress_if_needed( - self, model, memory, original_messages: List[ChatMessage], current_run_start_idx, - ) -> List[ChatMessage]: - # G1 - if not self.config.enabled: - return original_messages - - if self._estimate_tokens(memory) <= self.config.token_threshold: - # No compression needed; record that compressed == uncompressed - # so benchmark token_reduction reads as zero rather than stale. 
- self._last_uncompressed_token_count = self._msg_token_count(original_messages) - self._last_compressed_token_count = self._last_uncompressed_token_count + # ============================================================ + # Mainly Entry Point + # ============================================================ + + def _soft_input_budget_tokens(self) -> int: + return self.config.soft_input_budget_tokens or self.config.token_threshold + + def _hard_input_budget_tokens(self) -> int: + return self.config.hard_input_budget_tokens or int(self.config.token_threshold * 1.1) + + def compress_if_needed( + self, model, memory, original_messages: List[ChatMessage], current_run_start_idx, + ) -> List[ChatMessage]: + # G1 + if not self.config.enabled: + return original_messages + + soft_input_budget_tokens = self._soft_input_budget_tokens() + hard_input_budget_tokens = self._hard_input_budget_tokens() + + if self._estimate_tokens(memory) <= soft_input_budget_tokens: + # No compression needed; record that compressed == uncompressed + # so benchmark token_reduction reads as zero rather than stale. + self._last_uncompressed_token_count = self._msg_token_count(original_messages) + self._last_compressed_token_count = self._last_uncompressed_token_count return original_messages with self._lock: @@ -471,13 +480,13 @@ def compress_if_needed( self._current_summary_cache = None self._last_run_start_idx = current_run_start_idx - # Note: The memory here always consists of the unmodified, summary-task-step-free - # original previous_run + current_run. - # - previous_run: [(TaskStep, ActionStep), ...] - # - current_run: [TaskStep, ActionStep, ActionStep, ...] - if self._effective_tokens(memory, current_run_start_idx) <= self.config.token_threshold: - # Stable-phase bypass: No LLM call; construct compressed messages directly from existing cache. 
- self._step_local_log.clear() + # Note: The memory here always consists of the unmodified, summary-task-step-free + # original previous_run + current_run. + # - previous_run: [(TaskStep, ActionStep), ...] + # - current_run: [TaskStep, ActionStep, ActionStep, ...] + if self._effective_tokens(memory, current_run_start_idx) <= soft_input_budget_tokens: + # Stable-phase bypass: No LLM call; construct compressed messages directly from existing cache. + self._step_local_log.clear() prev_steps = memory.steps[:current_run_start_idx] curr_steps = memory.steps[current_run_start_idx:] @@ -529,20 +538,21 @@ def compress_if_needed( prev_steps = memory.steps[:current_run_start_idx] curr_steps = memory.steps[current_run_start_idx:] - prev_tokens = self._effective_prev_tokens(prev_steps) - curr_tokens = self._effective_curr_tokens(curr_steps) - - compress_prev = prev_tokens > self.config.token_threshold * 0.6 - compress_curr = curr_tokens > self.config.token_threshold * 0.4 - - total_effective_tokens = prev_tokens + curr_tokens - if compress_prev or compress_curr: - logger.info( - f"Context compression triggered: total_tokens={total_effective_tokens}, " - f"threshold={self.config.token_threshold}, " - f"prev_tokens={prev_tokens} (compress={compress_prev}), " - f"curr_tokens={curr_tokens} (compress={compress_curr})" - ) + prev_tokens = self._effective_prev_tokens(prev_steps) + curr_tokens = self._effective_curr_tokens(curr_steps) + + compress_prev = prev_tokens > soft_input_budget_tokens * 0.6 + compress_curr = curr_tokens > soft_input_budget_tokens * 0.4 + + total_effective_tokens = prev_tokens + curr_tokens + if compress_prev or compress_curr: + logger.info( + f"Context compression triggered: total_tokens={total_effective_tokens}, " + f"soft_budget={soft_input_budget_tokens}, " + f"hard_budget={hard_input_budget_tokens}, " + f"prev_tokens={prev_tokens} (compress={compress_prev}), " + f"curr_tokens={curr_tokens} (compress={compress_curr})" + ) # --------------- Previous phase 
--------------- prev_summary_step: Optional[SummaryTaskStep] = None @@ -622,15 +632,15 @@ def compress_if_needed( final_messages = self._build_messages( memory, prev_summary_step, prev_tail_steps, curr_kept_steps ) - final_tokens = self._msg_token_count(final_messages) - self._last_compressed_token_count = final_tokens - # This situation is unlikely to occur unless the threshold itself is set unreasonably small - if final_tokens > int(self.config.token_threshold * 1.1): - logger.warning( - f"Still exceeds threshold after compression: {final_tokens} > {self.config.token_threshold}. " - f"Consider reducing keep_recent_pairs ({self.config.keep_recent_pairs}) " - f"or keep_recent_steps({self.config.keep_recent_steps})" - ) + final_tokens = self._msg_token_count(final_messages) + self._last_compressed_token_count = final_tokens + # This situation is unlikely to occur unless the threshold itself is set unreasonably small + if final_tokens > hard_input_budget_tokens: + logger.warning( + f"Still exceeds hard input budget after compression: {final_tokens} > {hard_input_budget_tokens}. 
" + f"Consider reducing keep_recent_pairs ({self.config.keep_recent_pairs}) " + f"or keep_recent_steps({self.config.keep_recent_steps})" + ) return final_messages # ============================================================ @@ -1406,4 +1416,4 @@ def _message_already_present(self, messages: List, new_msg: dict) -> bool: for existing in messages: if existing.get("role") == new_msg.get("role") and existing.get("content") == new_msg.get("content"): return True - return False \ No newline at end of file + return False diff --git a/sdk/nexent/core/agents/agent_model.py b/sdk/nexent/core/agents/agent_model.py index 62e75cb59..cad66256d 100644 --- a/sdk/nexent/core/agents/agent_model.py +++ b/sdk/nexent/core/agents/agent_model.py @@ -12,7 +12,7 @@ PROTOCOL_HTTP_JSON = "HTTP+JSON" PROTOCOL_GRPC = "GRPC" -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator from ..utils.observer import MessageObserver @@ -44,16 +44,49 @@ class ModelConfig(BaseModel): ), default=None, ) - max_tokens: Optional[int] = Field( + max_output_tokens: Optional[int] = Field( description=( "Per-call completion output cap forwarded to chat.completions.create. " - "Defaults to None so production keeps the provider's own default " - "(typically the model's max output). Benchmarks set this explicitly " - "(e.g. 4096) to bound pathological generation loops where a model " - "regurgitates context." + "Preferred name over the deprecated max_tokens. Defaults to None so " + "production keeps the provider's own default (typically the model's " + "max output). Benchmarks set this explicitly (e.g. 4096) to bound " + "pathological generation loops where a model regurgitates context." + ), + default=None, + ) + max_tokens: Optional[int] = Field( + description=( + "DEPRECATED W1 alias for max_output_tokens. Retained so existing " + "callers and persisted ModelRecord rows keep working during the " + "migration window. 
If only max_tokens is set, the validator copies " + "it into max_output_tokens; if both are set, max_output_tokens wins." ), default=None, ) + context_window_tokens: Optional[int] = Field( + description="Total combined input/output context window in tokens, when the provider uses a combined window. Resolved by ModelCapacityResolver per W1 ADR.", + default=None, + ) + max_input_tokens: Optional[int] = Field( + description="Provider hard input-token limit when distinct from the combined window. Resolved by ModelCapacityResolver per W1 ADR.", + default=None, + ) + default_output_reserve_tokens: Optional[int] = Field( + description="Default output allowance reserved per request before constructing input context. Resolved by ModelCapacityResolver per W1 ADR.", + default=None, + ) + tokenizer_family: Optional[str] = Field( + description="Tokenizer-family identifier resolved via tokenizer_registry. None forces estimated counting mode.", + default=None, + ) + capacity_source: Optional[str] = Field( + description="Source of the persisted capacity value: operator | profile | provider_candidate | legacy | unknown.", + default=None, + ) + capability_profile_version: Optional[str] = Field( + description="Version of the approved provider/model capability profile selected by the resolver, e.g. 'openai/gpt-4o@1'.", + default=None, + ) timeout_seconds: Optional[float] = Field( description="Request timeout in seconds. If None, uses provider default.", default=None @@ -63,6 +96,15 @@ class ModelConfig(BaseModel): default=None, ) + @model_validator(mode="after") + def _backfill_max_output_from_legacy_max_tokens(self) -> "ModelConfig": + if self.max_output_tokens is None and self.max_tokens is not None: + self.max_output_tokens = self.max_tokens + elif self.max_output_tokens is not None and self.max_tokens is None: + # Keep legacy attribute populated so callers reading it keep working. 
+ self.max_tokens = self.max_output_tokens + return self + class ToolConfig(BaseModel): class_name: str = Field(description="Tool class name") @@ -142,6 +184,14 @@ class AgentConfig(BaseModel): prompt_templates: Optional[Dict[str, Any]] = Field(description="Prompt templates", default=None) tools: List[ToolConfig] = Field(description="List of tool information") max_steps: int = Field(description="Maximum number of steps for current Agent", default=15, ge=1, le=30) + requested_output_tokens: Optional[int] = Field( + description=( + "Per-agent W2 output reserve override. None means inherit the " + "resolved model-level default." + ), + default=None, + ge=1, + ) model_name: str = Field(description="Model alias from ModelConfig") provide_run_summary: Optional[bool] = Field(description="Whether to provide run summary to upper-level Agent", default=False) instructions: Optional[str] = Field(description="Additional instructions to prepend to system prompt", default=None) @@ -161,6 +211,14 @@ class AgentConfig(BaseModel): description="Pre-built context components for system prompt assembly", default=None ) + capacity_snapshot: Optional[Dict[str, Any]] = Field( + description="Resolved model capacity snapshot fields for request monitoring", + default=None, + ) + safe_input_budget_snapshot: Optional[Dict[str, Any]] = Field( + description="Resolved W2 safe input budget snapshot for request execution", + default=None, + ) verification_config: AgentVerificationConfig = Field( description="Layered ReAct self-verification configuration", default_factory=AgentVerificationConfig, @@ -192,6 +250,14 @@ class AgentRunInfo(BaseModel): "If provided, it will be attached to the CoreAgent instead of creating a new one.", default=None ) + capacity_snapshot: Optional[Dict[str, Any]] = Field( + description="Resolved model capacity snapshot fields for request monitoring", + default=None, + ) + safe_input_budget_snapshot: Optional[Dict[str, Any]] = Field( + description="Resolved W2 safe input 
budget snapshot for request execution", + default=None, + ) class Config: arbitrary_types_allowed = True diff --git a/sdk/nexent/core/agents/nexent_agent.py b/sdk/nexent/core/agents/nexent_agent.py index ed43b6691..c3b70793c 100644 --- a/sdk/nexent/core/agents/nexent_agent.py +++ b/sdk/nexent/core/agents/nexent_agent.py @@ -183,7 +183,7 @@ def create_model(self, model_cite_name: str): model_factory=model_config.model_factory, display_name=model_config.cite_name, extra_body=model_config.extra_body, - max_tokens=model_config.max_tokens, + max_output_tokens=model_config.max_output_tokens, timeout_seconds=model_config.timeout_seconds, ) model.stop_event = self.stop_event @@ -387,6 +387,16 @@ def create_single_agent(self, agent_config: AgentConfig): try: model = self.create_model(agent_config.model_name) + model.safe_input_budget_snapshot = getattr( + agent_config, + "safe_input_budget_snapshot", + None, + ) + model.capacity_snapshot = getattr( + agent_config, + "capacity_snapshot", + None, + ) prompt_templates = agent_config.prompt_templates try: diff --git a/sdk/nexent/core/agents/run_agent.py b/sdk/nexent/core/agents/run_agent.py index 243ca099e..1d050f066 100644 --- a/sdk/nexent/core/agents/run_agent.py +++ b/sdk/nexent/core/agents/run_agent.py @@ -1,4 +1,5 @@ import asyncio +import json import logging from contextvars import copy_context from threading import Thread @@ -6,6 +7,10 @@ from smolagents import ToolCollection +from ...monitor import ( + set_monitoring_capacity_snapshot, + set_monitoring_safe_input_budget_snapshot, +) from .agent_model import AgentRunInfo from .nexent_agent import NexentAgent, ProcessType @@ -13,6 +18,43 @@ logger.setLevel(logging.DEBUG) +def _emit_uncertainty_reserve_warning(agent_run_info: AgentRunInfo) -> None: + snapshot = getattr(agent_run_info, "safe_input_budget_snapshot", None) + if not isinstance(snapshot, dict): + return + warnings = snapshot.get("warnings") or [] + if "uncertainty_reserve_active" not in warnings: + return + + 
payload = { + "code": "uncertainty_reserve_active", + "message": ( + "W2 applied the unified 10% uncertainty reserve because selected " + "model capability behavior is not fully verified." + ), + "budget_fingerprint": snapshot.get("fingerprint"), + "w1_fingerprint": snapshot.get("w1_fingerprint"), + "uncertainty_reserve_tokens": snapshot.get("uncertainty_reserve_tokens"), + "hard_input_budget_tokens": snapshot.get("hard_input_budget_tokens"), + } + logger.warning( + "W2 uncertainty reserve active: budget_fingerprint=%s w1_fingerprint=%s " + "uncertainty_reserve_tokens=%s hard_input_budget_tokens=%s", + payload["budget_fingerprint"], + payload["w1_fingerprint"], + payload["uncertainty_reserve_tokens"], + payload["hard_input_budget_tokens"], + ) + try: + agent_run_info.observer.add_message( + "", + ProcessType.OTHER, + json.dumps(payload, ensure_ascii=False), + ) + except Exception: + logger.debug("Failed to emit W2 uncertainty reserve observer warning", exc_info=True) + + def _detect_transport(url: str) -> str: """ Auto-detect MCP transport type based on URL format. 
@@ -76,6 +118,13 @@ def _normalize_mcp_config(mcp_host_item: Union[str, Dict[str, Any]]) -> Dict[str def agent_run_thread(agent_run_info: AgentRunInfo): try: + set_monitoring_capacity_snapshot( + getattr(agent_run_info, "capacity_snapshot", None) + ) + set_monitoring_safe_input_budget_snapshot( + getattr(agent_run_info, "safe_input_budget_snapshot", None) + ) + _emit_uncertainty_reserve_warning(agent_run_info) mcp_host = agent_run_info.mcp_host if mcp_host is None or len(mcp_host) == 0: nexent = NexentAgent( diff --git a/sdk/nexent/core/agents/summary_config.py b/sdk/nexent/core/agents/summary_config.py index e271ddd34..294bc9eaf 100644 --- a/sdk/nexent/core/agents/summary_config.py +++ b/sdk/nexent/core/agents/summary_config.py @@ -19,6 +19,8 @@ class ContextManagerConfig: # === Compression Settings (existing) === enabled: bool = False token_threshold: int = 10000 + soft_input_budget_tokens: int = 0 + hard_input_budget_tokens: int = 0 keep_recent_steps: int = 4 keep_recent_pairs: int = 2 max_chunk_count: int = 0 @@ -118,4 +120,4 @@ class ContextManagerConfig: # === NEW: Buffered Strategy Settings === buffer_size_per_component: int = 10 - """Number of items to keep per component type for 'buffered' strategy.""" \ No newline at end of file + """Number of items to keep per component type for 'buffered' strategy.""" diff --git a/sdk/nexent/core/models/__init__.py b/sdk/nexent/core/models/__init__.py index 9d8217358..a3d265fba 100644 --- a/sdk/nexent/core/models/__init__.py +++ b/sdk/nexent/core/models/__init__.py @@ -7,6 +7,28 @@ from .tts_model import BaseTTSModel from .ali_tts_model import AliTTSModel, AliTTSConfig from .volc_tts_model import VolcTTSModel, VolcTTSConfig +from .capacity_resolver import ( + CapabilityProfile, + ModelCapacitySnapshot, + ProfileKey, + ResolverError, + RESOLVER_VERSION, + compute_fingerprint, + resolve_capacity, +) +from .capacity_budget import ( + BudgetResolverError, + CallerMaxTokensOverrideForbidden, + CapacityReservePolicy, + 
RequestBudgetOverrides, + SafeInputBudgetCalculator, + SafeInputBudgetCapacityMismatch, + SafeInputBudgetFingerprintMismatch, + SafeInputBudgetSnapshot, + W2_RESOLVER_VERSION, + compute_w2_fingerprint, +) +from . import tokenizer_registry __all__ = [ "OpenAIModel", @@ -22,4 +44,22 @@ "AliTTSConfig", "VolcTTSModel", "VolcTTSConfig", + "CapabilityProfile", + "ModelCapacitySnapshot", + "ProfileKey", + "ResolverError", + "RESOLVER_VERSION", + "compute_fingerprint", + "resolve_capacity", + "BudgetResolverError", + "CallerMaxTokensOverrideForbidden", + "CapacityReservePolicy", + "RequestBudgetOverrides", + "SafeInputBudgetCalculator", + "SafeInputBudgetCapacityMismatch", + "SafeInputBudgetFingerprintMismatch", + "SafeInputBudgetSnapshot", + "W2_RESOLVER_VERSION", + "compute_w2_fingerprint", + "tokenizer_registry", ] diff --git a/sdk/nexent/core/models/capacity_budget.py b/sdk/nexent/core/models/capacity_budget.py new file mode 100644 index 000000000..5eb1a0d02 --- /dev/null +++ b/sdk/nexent/core/models/capacity_budget.py @@ -0,0 +1,385 @@ +from __future__ import annotations + +import hashlib +import json +import math +from typing import Any, Literal, Mapping, Optional, Sequence + +from pydantic import BaseModel, ConfigDict, Field + +from .capacity_resolver import ModelCapacitySnapshot + + +W2_RESOLVER_VERSION = "1.0.0" +W2_FINGERPRINT_SCHEMA_VERSION = 1 + + +OutputReserveSource = Literal["model_default", "agent", "request"] +UncertaintyReserveBasis = Literal[ + "context_window_10pct", "approved_profile", "none" +] +SoftLimitRatioSource = Literal["code_default", "tenant_config"] +BudgetFieldSource = Literal[ + "model_default", + "agent", + "request", + "code_default", + "tenant_config", + "approved_profile", + "derived", +] + + +class BudgetResolverError(Exception): + """Base class for W2 safe-input-budget resolution failures.""" + + +class InvalidReservePolicy(BudgetResolverError): + pass + + +class RequestedOutputExceedsCapacity(BudgetResolverError): + pass + + +class 
UncertaintyReserveBasisUnknown(BudgetResolverError): + pass + + +class ReserveExceedsCapacity(BudgetResolverError): + pass + + +class NoSafeInputCapacity(BudgetResolverError): + pass + + +class SafeInputBudgetFingerprintMismatch(BudgetResolverError): + """Raised when a W2 snapshot fingerprint does not match its payload.""" + + def __init__(self, *, expected: str, actual: str) -> None: + self.expected = expected + self.actual = actual + super().__init__( + "safe_input_budget_fingerprint_mismatch: " + f"expected={expected} actual={actual}" + ) + + +class CallerMaxTokensOverrideForbidden(BudgetResolverError): + """Raised when a caller tries to override W2's trusted output cap.""" + + def __init__(self, *, snapshot_value: int, caller_value: int) -> None: + self.snapshot_value = snapshot_value + self.caller_value = caller_value + super().__init__( + "caller_max_tokens_override_forbidden: " + f"caller max_tokens={caller_value} does not match " + f"requested_output_tokens={snapshot_value}" + ) + + +class SafeInputBudgetCapacityMismatch(BudgetResolverError): + """Raised when a W2 snapshot's W1 identity disagrees with the active W1. + + Catches the case where a W2 snapshot computed from one model's W1 + capacity is dispatched against a different model (stale cache, mid-flight + swap, cross-tenant leak). Verified at the trusted dispatch boundary as + defense-in-depth per CM-013. 
+ """ + + def __init__(self, *, field: str, expected: str, actual: str) -> None: + self.field = field + self.expected = expected + self.actual = actual + super().__init__( + "safe_input_budget_capacity_mismatch: " + f"field={field} expected={expected} actual={actual}" + ) + + +class CapacityReservePolicy(BaseModel): + """Immutable W2 reserve policy resolved before budget calculation.""" + + model_config = ConfigDict(frozen=True) + + soft_limit_ratio: float = Field( + default=0.8, + gt=0, + le=1, + description="Ratio of hard safe input budget where proactive compaction begins.", + ) + soft_limit_ratio_source: SoftLimitRatioSource = "code_default" + approved_profile_reserve_tokens: Optional[int] = Field( + default=None, + ge=0, + description=( + "Verified reserve from the selected capability profile. When present, " + "it may replace the unified 10 percent uncertainty reserve." + ), + ) + + +class RequestBudgetOverrides(BaseModel): + """Per-request W2 budget overrides accepted from trusted backend resolution.""" + + model_config = ConfigDict(frozen=True) + + requested_output_tokens: Optional[int] = Field(default=None, gt=0) + + +class SafeInputBudgetSnapshot(BaseModel): + """Immutable W2 budget contract consumed by W3 and trusted dispatch.""" + + model_config = ConfigDict(frozen=True) + + w1_fingerprint: str + provider: str + model_name: str + + requested_output_tokens: int + output_reserve_source: OutputReserveSource + + provider_input_limit_tokens: int + uncertainty_reserve_tokens: int + uncertainty_reserve_basis: UncertaintyReserveBasis + approved_profile_reserve_tokens: Optional[int] = None + + soft_limit_ratio: float = Field(gt=0, le=1) + soft_limit_ratio_source: SoftLimitRatioSource + soft_input_budget_tokens: int + hard_input_budget_tokens: int + + field_sources: Mapping[str, str] = Field(default_factory=dict) + warnings: Sequence[str] = Field(default_factory=list) + resolver_version: str = W2_RESOLVER_VERSION + fingerprint: str + + +def 
compute_w2_fingerprint( + *, + w2_resolver_version: str, + w1_fingerprint: str, + provider: str, + model_name: str, + requested_output_tokens: int, + output_reserve_source: str, + uncertainty_reserve_tokens: int, + uncertainty_reserve_basis: str, + approved_profile_reserve_tokens: Optional[int], + soft_limit_ratio: float, + soft_limit_ratio_source: str, + soft_input_budget_tokens: int, + hard_input_budget_tokens: int, + field_sources: Mapping[str, str], + warnings: Sequence[str] = (), +) -> str: + """Compute the W2 ADR Decision 1 fingerprint. + + `warnings` is accepted to keep the signature aligned with the ADR, but is + intentionally excluded from the canonical payload. + """ + _ = warnings + payload: dict[str, Any] = { + "v": W2_FINGERPRINT_SCHEMA_VERSION, + "w2_resolver_version": w2_resolver_version, + "w1_fingerprint": w1_fingerprint, + "provider": provider, + "model_name": model_name, + "requested_output_tokens": requested_output_tokens, + "output_reserve_source": output_reserve_source, + "uncertainty_reserve_tokens": uncertainty_reserve_tokens, + "uncertainty_reserve_basis": uncertainty_reserve_basis, + "approved_profile_reserve_tokens": approved_profile_reserve_tokens, + "soft_limit_ratio": soft_limit_ratio, + "soft_limit_ratio_source": soft_limit_ratio_source, + "soft_input_budget_tokens": soft_input_budget_tokens, + "hard_input_budget_tokens": hard_input_budget_tokens, + "field_sources": dict(sorted(field_sources.items())), + } + encoded = json.dumps( + payload, + sort_keys=True, + separators=(",", ":"), + ensure_ascii=True, + allow_nan=False, + ).encode("utf-8") + return hashlib.sha256(encoded).hexdigest()[:32] + + +class SafeInputBudgetCalculator: + """Pure W2 calculator over an immutable W1 capacity snapshot.""" + + _UNKNOWN_CAPABILITIES_REQUIRING_RESERVE = frozenset( + { + "capability_profile_missing", + "tokenizer", + "reasoning_window_behavior", + "provider_overhead_behavior", + } + ) + + def calculate_safe_input_budget( + self, + *, + 
capacity_snapshot: ModelCapacitySnapshot, + reserve_policy: CapacityReservePolicy, + request_overrides: Optional[RequestBudgetOverrides] = None, + requested_output_tokens: Optional[int] = None, + output_reserve_source: OutputReserveSource = "model_default", + ) -> SafeInputBudgetSnapshot: + effective_output_tokens = ( + requested_output_tokens + if requested_output_tokens is not None + else capacity_snapshot.requested_output_tokens + ) + effective_output_source: OutputReserveSource = output_reserve_source + if requested_output_tokens is None: + effective_output_source = "model_default" + + if effective_output_tokens <= 0: + raise InvalidReservePolicy( + "requested_output_tokens must be a positive integer" + ) + + if request_overrides and request_overrides.requested_output_tokens is not None: + if request_overrides.requested_output_tokens < effective_output_tokens: + raise InvalidReservePolicy( + "per-request requested_output_tokens may not lower the " + "resolved model or agent output reserve" + ) + effective_output_tokens = request_overrides.requested_output_tokens + effective_output_source = "request" + + if ( + capacity_snapshot.max_output_tokens is not None + and effective_output_tokens > capacity_snapshot.max_output_tokens + ): + raise RequestedOutputExceedsCapacity( + "requested_output_tokens " + f"({effective_output_tokens}) exceeds max_output_tokens " + f"({capacity_snapshot.max_output_tokens})" + ) + + provider_input_limit = self._provider_input_limit( + capacity_snapshot=capacity_snapshot, + requested_output_tokens=effective_output_tokens, + ) + + uncertainty_reserve_tokens, uncertainty_reserve_basis, warnings = ( + self._uncertainty_reserve(capacity_snapshot, reserve_policy) + ) + + if uncertainty_reserve_tokens > provider_input_limit: + raise ReserveExceedsCapacity( + "uncertainty reserve " + f"({uncertainty_reserve_tokens}) exceeds provider input limit " + f"({provider_input_limit})" + ) + + hard_input_budget_tokens = provider_input_limit - 
uncertainty_reserve_tokens + if hard_input_budget_tokens <= 0: + raise NoSafeInputCapacity( + "safe input budget is non-positive after applying reserves" + ) + + soft_input_budget_tokens = max( + 1, math.floor(hard_input_budget_tokens * reserve_policy.soft_limit_ratio) + ) + + field_sources = { + "requested_output_tokens": effective_output_source, + "soft_limit_ratio": reserve_policy.soft_limit_ratio_source, + "uncertainty_reserve_tokens": uncertainty_reserve_basis, + "provider_input_limit_tokens": "derived", + "hard_input_budget_tokens": "derived", + "soft_input_budget_tokens": "derived", + } + + fingerprint = compute_w2_fingerprint( + w2_resolver_version=W2_RESOLVER_VERSION, + w1_fingerprint=capacity_snapshot.fingerprint, + provider=capacity_snapshot.provider, + model_name=capacity_snapshot.model_name, + requested_output_tokens=effective_output_tokens, + output_reserve_source=effective_output_source, + uncertainty_reserve_tokens=uncertainty_reserve_tokens, + uncertainty_reserve_basis=uncertainty_reserve_basis, + approved_profile_reserve_tokens=reserve_policy.approved_profile_reserve_tokens, + soft_limit_ratio=reserve_policy.soft_limit_ratio, + soft_limit_ratio_source=reserve_policy.soft_limit_ratio_source, + soft_input_budget_tokens=soft_input_budget_tokens, + hard_input_budget_tokens=hard_input_budget_tokens, + field_sources=field_sources, + warnings=warnings, + ) + + return SafeInputBudgetSnapshot( + w1_fingerprint=capacity_snapshot.fingerprint, + provider=capacity_snapshot.provider, + model_name=capacity_snapshot.model_name, + requested_output_tokens=effective_output_tokens, + output_reserve_source=effective_output_source, + provider_input_limit_tokens=provider_input_limit, + uncertainty_reserve_tokens=uncertainty_reserve_tokens, + uncertainty_reserve_basis=uncertainty_reserve_basis, + approved_profile_reserve_tokens=reserve_policy.approved_profile_reserve_tokens, + soft_limit_ratio=reserve_policy.soft_limit_ratio, + 
soft_limit_ratio_source=reserve_policy.soft_limit_ratio_source, + soft_input_budget_tokens=soft_input_budget_tokens, + hard_input_budget_tokens=hard_input_budget_tokens, + field_sources=field_sources, + warnings=warnings, + resolver_version=W2_RESOLVER_VERSION, + fingerprint=fingerprint, + ) + + @staticmethod + def _provider_input_limit( + *, + capacity_snapshot: ModelCapacitySnapshot, + requested_output_tokens: int, + ) -> int: + derived_limits: list[int] = [] + if capacity_snapshot.max_input_tokens is not None: + derived_limits.append(capacity_snapshot.max_input_tokens) + if capacity_snapshot.context_window_tokens is not None: + derived_limits.append( + capacity_snapshot.context_window_tokens - requested_output_tokens + ) + if not derived_limits: + raise NoSafeInputCapacity("no provider input limit could be derived") + provider_input_limit = min(derived_limits) + if provider_input_limit <= 0: + raise NoSafeInputCapacity( + "provider input limit is non-positive after output reserve" + ) + return provider_input_limit + + def _uncertainty_reserve( + self, + capacity_snapshot: ModelCapacitySnapshot, + reserve_policy: CapacityReservePolicy, + ) -> tuple[int, UncertaintyReserveBasis, list[str]]: + unknown_required_behavior = self._UNKNOWN_CAPABILITIES_REQUIRING_RESERVE.intersection( + capacity_snapshot.unknown_capabilities + ) + + if reserve_policy.approved_profile_reserve_tokens is not None: + return ( + reserve_policy.approved_profile_reserve_tokens, + "approved_profile", + [], + ) + + if not unknown_required_behavior: + return 0, "none", [] + + if capacity_snapshot.context_window_tokens is None: + raise UncertaintyReserveBasisUnknown( + "context_window_tokens is required for the unified 10 percent " + "uncertainty reserve" + ) + + reserve = math.ceil(capacity_snapshot.context_window_tokens * 0.10) + return reserve, "context_window_10pct", ["uncertainty_reserve_active"] diff --git a/sdk/nexent/core/models/capacity_resolver.py 
b/sdk/nexent/core/models/capacity_resolver.py new file mode 100644 index 000000000..cb7af2e4d --- /dev/null +++ b/sdk/nexent/core/models/capacity_resolver.py @@ -0,0 +1,367 @@ +from __future__ import annotations + +import hashlib +import json +import logging +from typing import Any, List, Literal, Mapping, Optional, Sequence, Tuple + +from pydantic import BaseModel, ConfigDict, Field + +logger = logging.getLogger("capacity_resolver") + + +RESOLVER_VERSION = "1.0.0" +FINGERPRINT_SCHEMA_VERSION = 1 + + +CountingMode = Literal["exact", "estimated"] +WindowShape = Literal["combined", "separate"] +CapacitySource = Literal[ + "operator", "profile", "provider_candidate", "legacy", "unknown" +] +ReasoningWindowBehavior = Literal["none", "reserved", "unknown"] +ProviderOverheadBehavior = Literal["negligible", "bounded", "unknown"] +PromptCacheCapability = Literal["none", "supported", "unknown"] + + +ProfileKey = Tuple[str, str] + + +class CapabilityProfile(BaseModel): + """One row in the approved provider/model capability catalog. + + Identity rules and completeness criteria are defined in + `doc/working/context-management-workstreams/W1_ADR_Capability_Catalog_Storage_and_Fingerprint.md`. + """ + + model_config = ConfigDict(frozen=True) + + provider: str = Field(description="Provider identifier (e.g. 'openai', 'dashscope', 'silicon')") + model_name: str = Field(description="Model name as used by the provider API") + capability_profile_version: str = Field( + description="Per-entry version, e.g. 'openai/gpt-4o@1'" + ) + + window_shape: WindowShape + context_window_tokens: Optional[int] = None + max_input_tokens: Optional[int] = None + max_output_tokens: Optional[int] = None + default_output_reserve_tokens: Optional[int] = None + + tokenizer_family: Optional[str] = Field( + default=None, + description=( + "Identifier resolved via `tokenizer_registry.resolve`. None forces " + "counting_mode='estimated'." 
+ ), + ) + reasoning_window_behavior: ReasoningWindowBehavior = "unknown" + provider_overhead_behavior: ProviderOverheadBehavior = "unknown" + prompt_cache: PromptCacheCapability = "unknown" + + +class ModelCapacitySnapshot(BaseModel): + """Immutable per-request capacity resolution result. + + Consumed unchanged by W2 (safe input budget), W3 (final fit), W16 (cache + assembly), monitoring, and provider dispatch. Fingerprint is recomputed from + the contract by trusted dispatch to detect tampering or stale snapshots. + """ + + model_config = ConfigDict(frozen=True) + + model_record_id: Optional[int] = None + provider: str + model_name: str + + context_window_tokens: Optional[int] = None + max_input_tokens: Optional[int] = None + max_output_tokens: Optional[int] = None + default_output_reserve_tokens: Optional[int] = None + + requested_output_tokens: int + provider_input_limit_tokens: int + + tokenizer_family: Optional[str] = None + counting_mode: CountingMode + + unknown_capabilities: List[str] = Field(default_factory=list) + field_sources: Mapping[str, CapacitySource] = Field(default_factory=dict) + + capability_profile_version: Optional[str] = None + resolver_version: str = RESOLVER_VERSION + + warnings: List[str] = Field(default_factory=list) + fingerprint: str + + +class ResolverError(Exception): + """Base class for capacity resolution failures. 
+ + Concrete typed failures (see ADR Decision 1 / W1 spec): + - InvalidCapacityConfiguration + - ProviderCapabilityUnknown + - UncertaintyReserveBasisUnknown + - RequestedOutputExceedsCap + - ProviderMetadataInvalid + """ + + +class InvalidCapacityConfiguration(ResolverError): + pass + + +class ProviderCapabilityUnknown(ResolverError): + pass + + +class UncertaintyReserveBasisUnknown(ResolverError): + pass + + +class RequestedOutputExceedsCap(ResolverError): + pass + + +class ProviderMetadataInvalid(ResolverError): + pass + + +def compute_fingerprint( + *, + resolver_version: str, + provider: str, + model_name: str, + context_window_tokens: Optional[int], + max_input_tokens: Optional[int], + max_output_tokens: Optional[int], + default_output_reserve_tokens: Optional[int], + requested_output_tokens: int, + provider_input_limit_tokens: int, + tokenizer_family: Optional[str], + counting_mode: CountingMode, + capability_profile_version: Optional[str], + unknown_capabilities: Sequence[str], + field_sources: Mapping[str, str], +) -> str: + """Deterministic 128-bit fingerprint of the resolved capacity contract. + + Algorithm is fixed by W1 ADR Decision 3: canonical JSON over the field set + below, SHA-256, hex-encoded, truncated to 32 chars. Any change to participating + fields or serialization requires bumping FINGERPRINT_SCHEMA_VERSION. 
+ """ + payload: dict[str, Any] = { + "v": FINGERPRINT_SCHEMA_VERSION, + "resolver_version": resolver_version, + "provider": provider, + "model_name": model_name, + "context_window_tokens": context_window_tokens, + "max_input_tokens": max_input_tokens, + "max_output_tokens": max_output_tokens, + "default_output_reserve_tokens": default_output_reserve_tokens, + "requested_output_tokens": requested_output_tokens, + "provider_input_limit_tokens": provider_input_limit_tokens, + "tokenizer_family": tokenizer_family, + "counting_mode": counting_mode, + "capability_profile_version": capability_profile_version, + "unknown_capabilities": sorted(unknown_capabilities), + "field_sources": dict(sorted(field_sources.items())), + } + encoded = json.dumps( + payload, + sort_keys=True, + separators=(",", ":"), + ensure_ascii=True, + allow_nan=False, + ).encode("utf-8") + return hashlib.sha256(encoded).hexdigest()[:32] + + +_OVERRIDABLE_FIELDS = ( + "context_window_tokens", + "max_input_tokens", + "max_output_tokens", + "default_output_reserve_tokens", + "tokenizer_family", +) + +# Last-resort fallback when neither the agent nor the model record sets a +# requested_output_tokens / default_output_reserve_tokens. 1024 was too small +# in practice: tool-using agents often write multi-hundred-token JSON tool +# calls plus a few hundred tokens of thought per step, and 1024 produced +# mid-JSON truncation that surfaced to users as "tool failed" instead of a +# capacity-config issue. 4096 covers the median single-turn output reliably +# without overshooting tiny-output models — those still get caught by the +# RequestedOutputExceedsCap check (capacity_resolver line 276-283 and +# the agent-edit form rule). 
+_DEFAULT_REQUESTED_OUTPUT_TOKENS = 4096 + + +def resolve_capacity( + *, + model_id: str, + provider: str, + operator_overrides: Optional[Mapping[str, Any]] = None, + requested_output_tokens: Optional[int] = None, + capability_profiles: Mapping[ProfileKey, CapabilityProfile], +) -> ModelCapacitySnapshot: + """Resolve capacity for one model request. + + Precedence per W1 spec: operator override > approved profile > unknown. + Production dispatch requires known hard capacity; otherwise + `ProviderCapabilityUnknown` is raised. Provider-discovery candidate metadata + is not consulted by this implementation — it is recorded by upstream provider + adapters and surfaced only after operators promote it into an approved + profile. + """ + # Lazy import to avoid a static cycle (tokenizer_registry imports CountingMode). + from . import tokenizer_registry as _tokenizer_registry + + overrides = dict(operator_overrides) if operator_overrides else {} + profile = capability_profiles.get((provider, model_id)) + + field_sources: dict[str, CapacitySource] = {} + + def _pick(field: str) -> Any: + value = overrides.get(field) + if value is not None: + field_sources[field] = "operator" + return value + if profile is not None: + profile_value = getattr(profile, field) + if profile_value is not None: + field_sources[field] = "profile" + return profile_value + field_sources[field] = "unknown" + return None + + context_window_tokens = _pick("context_window_tokens") + max_input_tokens = _pick("max_input_tokens") + max_output_tokens = _pick("max_output_tokens") + default_output_reserve_tokens = _pick("default_output_reserve_tokens") + tokenizer_family = _pick("tokenizer_family") + capability_profile_version = ( + profile.capability_profile_version if profile is not None else None + ) + + if context_window_tokens is None and max_input_tokens is None: + raise ProviderCapabilityUnknown( + f"No known hard capacity for ({provider!r}, {model_id!r}); " + f"set context_window_tokens or 
max_input_tokens via operator override " + f"or add a capability profile entry." + ) + + for name, value in ( + ("context_window_tokens", context_window_tokens), + ("max_input_tokens", max_input_tokens), + ("max_output_tokens", max_output_tokens), + ("default_output_reserve_tokens", default_output_reserve_tokens), + ): + if value is not None and value <= 0: + raise InvalidCapacityConfiguration( + f"{name} must be a positive integer, got {value}" + ) + + if ( + max_output_tokens is not None + and context_window_tokens is not None + and max_output_tokens > context_window_tokens + ): + raise InvalidCapacityConfiguration( + f"max_output_tokens ({max_output_tokens}) exceeds context_window_tokens " + f"({context_window_tokens})" + ) + + if ( + max_input_tokens is not None + and context_window_tokens is not None + and max_input_tokens > context_window_tokens + ): + raise InvalidCapacityConfiguration( + f"max_input_tokens ({max_input_tokens}) exceeds context_window_tokens " + f"({context_window_tokens}); operators who fill an input cap above the " + f"window will be silently clipped by the derived provider_input_limit, " + f"so the override never takes effect" + ) + + if requested_output_tokens is None: + requested_output_tokens = ( + default_output_reserve_tokens + if default_output_reserve_tokens is not None + else _DEFAULT_REQUESTED_OUTPUT_TOKENS + ) + if requested_output_tokens <= 0: + raise InvalidCapacityConfiguration( + f"requested_output_tokens must be positive, got {requested_output_tokens}" + ) + if ( + max_output_tokens is not None + and requested_output_tokens > max_output_tokens + ): + raise RequestedOutputExceedsCap( + f"requested_output_tokens ({requested_output_tokens}) exceeds " + f"max_output_tokens ({max_output_tokens})" + ) + + derived_limits: list[int] = [] + if max_input_tokens is not None: + derived_limits.append(max_input_tokens) + if context_window_tokens is not None: + derived_limits.append(context_window_tokens - requested_output_tokens) + 
provider_input_limit_tokens = min(derived_limits) + if provider_input_limit_tokens <= 0: + raise InvalidCapacityConfiguration( + f"derived provider_input_limit_tokens is non-positive: " + f"{provider_input_limit_tokens}" + ) + + _, counting_mode = _tokenizer_registry.resolve(tokenizer_family) + + unknown_capabilities: list[str] = [] + if profile is None: + unknown_capabilities.append("capability_profile_missing") + else: + if profile.reasoning_window_behavior == "unknown": + unknown_capabilities.append("reasoning_window_behavior") + if profile.provider_overhead_behavior == "unknown": + unknown_capabilities.append("provider_overhead_behavior") + if profile.prompt_cache == "unknown": + unknown_capabilities.append("prompt_cache") + if counting_mode == "estimated": + unknown_capabilities.append("tokenizer") + + fingerprint = compute_fingerprint( + resolver_version=RESOLVER_VERSION, + provider=provider, + model_name=model_id, + context_window_tokens=context_window_tokens, + max_input_tokens=max_input_tokens, + max_output_tokens=max_output_tokens, + default_output_reserve_tokens=default_output_reserve_tokens, + requested_output_tokens=requested_output_tokens, + provider_input_limit_tokens=provider_input_limit_tokens, + tokenizer_family=tokenizer_family, + counting_mode=counting_mode, + capability_profile_version=capability_profile_version, + unknown_capabilities=unknown_capabilities, + field_sources=dict(field_sources), + ) + + return ModelCapacitySnapshot( + provider=provider, + model_name=model_id, + context_window_tokens=context_window_tokens, + max_input_tokens=max_input_tokens, + max_output_tokens=max_output_tokens, + default_output_reserve_tokens=default_output_reserve_tokens, + requested_output_tokens=requested_output_tokens, + provider_input_limit_tokens=provider_input_limit_tokens, + tokenizer_family=tokenizer_family, + counting_mode=counting_mode, + unknown_capabilities=unknown_capabilities, + field_sources=dict(field_sources), + 
capability_profile_version=capability_profile_version, + resolver_version=RESOLVER_VERSION, + warnings=[], + fingerprint=fingerprint, + ) diff --git a/sdk/nexent/core/models/openai_llm.py b/sdk/nexent/core/models/openai_llm.py index a9127595c..d3b0ce518 100644 --- a/sdk/nexent/core/models/openai_llm.py +++ b/sdk/nexent/core/models/openai_llm.py @@ -18,6 +18,13 @@ from smolagents import Tool from smolagents.models import OpenAIServerModel, ChatMessage, MessageRole +from .capacity_budget import ( + CallerMaxTokensOverrideForbidden, + SafeInputBudgetCapacityMismatch, + SafeInputBudgetFingerprintMismatch, + SafeInputBudgetSnapshot, + compute_w2_fingerprint, +) from ..utils.observer import MessageObserver, ProcessType logger = logging.getLogger("openai_llm") @@ -28,7 +35,10 @@ def __init__(self, observer: MessageObserver = MessageObserver, temperature=0.2, ssl_verify=True, model_factory: Optional[str] = None, display_name: Optional[str] = None, extra_body: Optional[Dict[str, Any]] = None, + max_output_tokens: Optional[int] = None, max_tokens: Optional[int] = None, + safe_input_budget_snapshot: Optional[SafeInputBudgetSnapshot | Dict[str, Any]] = None, + capacity_snapshot: Optional[Dict[str, Any]] = None, timeout_seconds: Optional[float] = None, *args, **kwargs): """ Initialize OpenAI Model with observer and SSL verification option. @@ -45,10 +55,14 @@ def __init__(self, observer: MessageObserver = MessageObserver, temperature=0.2, extra_body: Optional dict merged into every chat.completions.create request body. Defaults to None so production behaviour is unchanged for callers that do not opt in. - max_tokens: Per-call completion output cap. Defaults to None so - production keeps the provider default (unbounded / - model max). Benchmarks set this explicitly (e.g. 4096) - to bound degenerate generation loops on long contexts. + max_output_tokens: Per-call completion output cap. Preferred name + per W1 ADR. 
Defaults to None so production keeps the + provider default (unbounded / model max). Benchmarks set + this explicitly (e.g. 4096) to bound degenerate generation + loops on long contexts. + max_tokens: DEPRECATED alias for max_output_tokens retained during + the W1 migration. If max_output_tokens is supplied it + wins; otherwise max_tokens is copied into it. *args: Additional positional arguments for OpenAIServerModel **kwargs: Additional keyword arguments for OpenAIServerModel """ @@ -60,7 +74,18 @@ def __init__(self, observer: MessageObserver = MessageObserver, temperature=0.2, self.model_factory = (model_factory or "").lower() self.display_name = display_name self.extra_body = extra_body or None - self.max_tokens = max_tokens + self.safe_input_budget_snapshot = safe_input_budget_snapshot + self.capacity_snapshot = capacity_snapshot + if max_output_tokens is None and max_tokens is not None: + logger.debug( + "OpenAIModel received legacy max_tokens=%s; treating as max_output_tokens. " + "Update callers to pass max_output_tokens directly.", + max_tokens, + ) + max_output_tokens = max_tokens + self.max_output_tokens = max_output_tokens + # Legacy alias kept readable for any caller still reading .max_tokens. 
+ self.max_tokens = max_output_tokens # Create http_client based on ssl_verify parameter and timeout if not ssl_verify or timeout_seconds is not None: @@ -92,10 +117,15 @@ def __init__(self, observer: MessageObserver = MessageObserver, temperature=0.2, _monitoring_display_name.set(self.display_name) def __call__(self, messages: List[Dict[str, Any]], stop_sequences: Optional[List[str]] = None, - response_format: dict[str, str] | None = None, tools_to_call_from: Optional[List[Tool]] = None, _token_tracker=None, **kwargs, ) -> ChatMessage: + response_format: dict[str, str] | None = None, tools_to_call_from: Optional[List[Tool]] = None, + _token_tracker=None, safe_input_budget_snapshot: Optional[SafeInputBudgetSnapshot] = None, + **kwargs, ) -> ChatMessage: _monitoring_operation.set("chat_completion") if _token_tracker is None: + trusted_budget_snapshot = ( + safe_input_budget_snapshot or self.safe_input_budget_snapshot + ) invocation_parameters = { "temperature": self.temperature, "top_p": self.top_p, @@ -111,6 +141,9 @@ def __call__(self, messages: List[Dict[str, Any]], stop_sequences: Optional[List else "input.value" ) trace_attributes[input_attr_key] = messages or [] + trace_attributes.update( + self._safe_input_budget_trace_attributes(trusted_budget_snapshot) + ) with self._monitoring.trace_llm_request( f"{self.display_name or self.model_id}.generate", @@ -125,6 +158,7 @@ def __call__(self, messages: List[Dict[str, Any]], stop_sequences: Optional[List response_format=response_format, tools_to_call_from=tools_to_call_from, _token_tracker=token_tracker, + safe_input_budget_snapshot=safe_input_budget_snapshot, **kwargs, ) @@ -178,13 +212,30 @@ def __call__(self, messages: List[Dict[str, Any]], stop_sequences: Optional[List if self.extra_body: completion_kwargs["extra_body"] = self.extra_body + trusted_budget_snapshot = ( + safe_input_budget_snapshot or self.safe_input_budget_snapshot + ) + # Bound completion length unless the caller passed their own override # via 
kwargs (which already landed in completion_kwargs above). - if self.max_tokens is not None and "max_tokens" not in completion_kwargs: - completion_kwargs["max_tokens"] = self.max_tokens - - current_request = self.client.chat.completions.create( - stream=True, **completion_kwargs) + # OpenAI wire field stays max_tokens; internal name is max_output_tokens. + # When a W2 snapshot is active, its requested_output_tokens is the sole + # authority per CM-030 — skip the pre-W2 auto-fill so the dispatch + # boundary does not see max_output_tokens masquerading as a caller + # override and reject it via CallerMaxTokensOverrideForbidden. + if ( + self.max_output_tokens is not None + and "max_tokens" not in completion_kwargs + and trusted_budget_snapshot is None + ): + completion_kwargs["max_tokens"] = self.max_output_tokens + + current_request = self._dispatch_chat_completion( + safe_input_budget_snapshot=trusted_budget_snapshot, + capacity_snapshot=self.capacity_snapshot, + stream=True, + **completion_kwargs, + ) # Validate response type: ensure we got a proper iterator, not error strings or dicts # Some APIs return error strings like "error: rate limit" or JSON dicts on failure @@ -327,6 +378,142 @@ def __call__(self, messages: List[Dict[str, Any]], stop_sequences: Optional[List raise ValueError(f"Token limit exceeded: {str(e)}") raise e + def _dispatch_chat_completion( + self, + *, + safe_input_budget_snapshot: Optional[SafeInputBudgetSnapshot | Dict[str, Any]] = None, + capacity_snapshot: Optional[Dict[str, Any]] = None, + **completion_kwargs: Any, + ) -> Any: + """Dispatch the OpenAI chat completion request. + + When W2 supplied a trusted safe-input-budget snapshot, this method is + the provider dispatch boundary: caller `max_tokens` overrides must + match the snapshot, and absent values are filled from the snapshot. 
+ + When the active W1 capacity snapshot is also threaded through, the + boundary additionally verifies W1->W2 fingerprint and provider/model + identity to catch a stale or cross-model W2 snapshot before the + provider call. + """ + snapshot = self._coerce_safe_input_budget_snapshot(safe_input_budget_snapshot) + if snapshot is not None: + self._verify_w1_w2_consistency( + budget_snapshot=snapshot, + capacity_snapshot=capacity_snapshot, + ) + trusted_max_tokens = snapshot.requested_output_tokens + caller_max_tokens = completion_kwargs.get("max_tokens") + if caller_max_tokens is not None and caller_max_tokens != trusted_max_tokens: + raise CallerMaxTokensOverrideForbidden( + snapshot_value=trusted_max_tokens, + caller_value=caller_max_tokens, + ) + completion_kwargs["max_tokens"] = trusted_max_tokens + return self.client.chat.completions.create(**completion_kwargs) + + @staticmethod + def _verify_w1_w2_consistency( + *, + budget_snapshot: SafeInputBudgetSnapshot, + capacity_snapshot: Optional[Dict[str, Any]], + ) -> None: + """Reject a W2 snapshot whose W1 identity disagrees with the active W1. + + Defense-in-depth per CM-013: a W2 snapshot computed from a different + model's W1 capacity (model swap mid-flight, stale cache, cross-tenant + leak) must not be allowed through dispatch even if its own fingerprint + self-checks. + + When the active W1 capacity_snapshot is not threaded through, the + check is skipped. This preserves the migration window for legacy + rows without capacity columns, where W2 already does not produce a + snapshot. 
+ """ + if not capacity_snapshot: + return + w1_fingerprint = capacity_snapshot.get("capacity_fingerprint") + provider = capacity_snapshot.get("provider") + model_name = capacity_snapshot.get("model_name") + if not w1_fingerprint and not provider and not model_name: + return + if w1_fingerprint and w1_fingerprint != budget_snapshot.w1_fingerprint: + raise SafeInputBudgetCapacityMismatch( + field="w1_fingerprint", + expected=w1_fingerprint, + actual=budget_snapshot.w1_fingerprint, + ) + if provider and provider != budget_snapshot.provider: + raise SafeInputBudgetCapacityMismatch( + field="provider", + expected=provider, + actual=budget_snapshot.provider, + ) + if model_name and model_name != budget_snapshot.model_name: + raise SafeInputBudgetCapacityMismatch( + field="model_name", + expected=model_name, + actual=budget_snapshot.model_name, + ) + + @staticmethod + def _coerce_safe_input_budget_snapshot( + snapshot: Optional[SafeInputBudgetSnapshot | Dict[str, Any]], + ) -> Optional[SafeInputBudgetSnapshot]: + if snapshot is None: + return None + if isinstance(snapshot, SafeInputBudgetSnapshot): + resolved = snapshot + elif isinstance(snapshot, dict): + resolved = SafeInputBudgetSnapshot.model_validate(snapshot) + else: + raise TypeError( + "safe_input_budget_snapshot must be a SafeInputBudgetSnapshot or dict" + ) + expected = compute_w2_fingerprint( + w2_resolver_version=resolved.resolver_version, + w1_fingerprint=resolved.w1_fingerprint, + provider=resolved.provider, + model_name=resolved.model_name, + requested_output_tokens=resolved.requested_output_tokens, + output_reserve_source=resolved.output_reserve_source, + uncertainty_reserve_tokens=resolved.uncertainty_reserve_tokens, + uncertainty_reserve_basis=resolved.uncertainty_reserve_basis, + approved_profile_reserve_tokens=resolved.approved_profile_reserve_tokens, + soft_limit_ratio=resolved.soft_limit_ratio, + soft_limit_ratio_source=resolved.soft_limit_ratio_source, + 
soft_input_budget_tokens=resolved.soft_input_budget_tokens, + hard_input_budget_tokens=resolved.hard_input_budget_tokens, + field_sources=resolved.field_sources, + warnings=resolved.warnings, + ) + if resolved.fingerprint != expected: + raise SafeInputBudgetFingerprintMismatch( + expected=expected, + actual=resolved.fingerprint, + ) + return resolved + + @classmethod + def _safe_input_budget_trace_attributes( + cls, + snapshot: Optional[SafeInputBudgetSnapshot | Dict[str, Any]], + ) -> Dict[str, Any]: + snapshot = cls._coerce_safe_input_budget_snapshot(snapshot) + if snapshot is None: + return {} + return { + "w2.budget_fingerprint": snapshot.fingerprint, + "w2.w1_fingerprint": snapshot.w1_fingerprint, + "w2.requested_output_tokens": snapshot.requested_output_tokens, + "w2.output_reserve_source": snapshot.output_reserve_source, + "w2.provider_input_limit_tokens": snapshot.provider_input_limit_tokens, + "w2.soft_input_budget_tokens": snapshot.soft_input_budget_tokens, + "w2.hard_input_budget_tokens": snapshot.hard_input_budget_tokens, + "w2.uncertainty_reserve_tokens": snapshot.uncertainty_reserve_tokens, + "w2.uncertainty_reserve_basis": snapshot.uncertainty_reserve_basis, + } + async def check_connectivity(self) -> bool: """ Test if the connection to the remote OpenAI large model service is normal diff --git a/sdk/nexent/core/models/tokenizer_registry.py b/sdk/nexent/core/models/tokenizer_registry.py new file mode 100644 index 000000000..6a8f7d2e9 --- /dev/null +++ b/sdk/nexent/core/models/tokenizer_registry.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +import json +import logging +import re +from typing import Dict, Optional, Protocol, Sequence, Tuple, runtime_checkable + +from .capacity_resolver import CountingMode + +logger = logging.getLogger("tokenizer_registry") + + +TOKENIZER_FAMILY_PATTERN = re.compile(r"^[a-z][a-z0-9_.]{0,49}$") + + +def is_valid_family_identifier(family: str) -> bool: + """Validate against the naming convention fixed by W1 
ADR Decision 1.""" + return bool(TOKENIZER_FAMILY_PATTERN.match(family)) + + +@runtime_checkable +class TokenizerAdapter(Protocol): + """Contract for a tokenizer-family counting implementation. + + Implementations must be deterministic, side-effect free, and threadsafe. + Promotion from `estimated` to `exact` requires meeting the accuracy gate + defined in W1 ADR Decision 1 (>=100-message fixture, MAE <= 0.5%, max single + error <= 2%). + """ + + family: str + + def count_tokens(self, messages: Sequence[dict]) -> int: ... + + +class FallbackEstimator: + """Generic character-to-token estimator used when no family adapter matches. + + Never marked `exact`. Purpose: avoid hard failures when a catalog entry has + an unknown tokenizer family — operators always see a budget number, just one + that triggers W2's 10% uncertainty reserve. + """ + + family = "_fallback" + + def count_tokens(self, messages: Sequence[dict]) -> int: + encoded = json.dumps(list(messages), ensure_ascii=False) + return max(1, len(encoded) // 4) + + +FALLBACK: TokenizerAdapter = FallbackEstimator() + + +REGISTRY: Dict[str, TokenizerAdapter] = {} + + +def register(adapter: TokenizerAdapter) -> None: + """Register a verified adapter. Called once at import time by adapter modules.""" + family = adapter.family + if not is_valid_family_identifier(family): + raise ValueError( + f"Tokenizer family {family!r} does not match required pattern " + f"{TOKENIZER_FAMILY_PATTERN.pattern}" + ) + if family in REGISTRY: + raise ValueError(f"Tokenizer family {family!r} is already registered") + REGISTRY[family] = adapter + + +def resolve(family: Optional[str]) -> Tuple[TokenizerAdapter, CountingMode]: + """Return (adapter, counting_mode) for the requested tokenizer family. + + Returns FALLBACK with `estimated` when family is None or unmapped. Returns + the registered adapter with `exact` when a verified mapping exists. 
+ """ + if family is None or family not in REGISTRY: + return FALLBACK, "estimated" + return REGISTRY[family], "exact" diff --git a/sdk/nexent/monitor/__init__.py b/sdk/nexent/monitor/__init__.py index 5fc6406df..c1af5e72e 100644 --- a/sdk/nexent/monitor/__init__.py +++ b/sdk/nexent/monitor/__init__.py @@ -20,6 +20,10 @@ is_opentelemetry_available, set_monitoring_context, get_monitoring_context, + set_monitoring_capacity_snapshot, + get_monitoring_capacity_snapshot, + set_monitoring_safe_input_budget_snapshot, + get_monitoring_safe_input_budget_snapshot, set_agent_monitoring_context, get_agent_monitoring_context, agent_monitoring_context, @@ -53,6 +57,10 @@ 'is_opentelemetry_available', 'set_monitoring_context', 'get_monitoring_context', + 'set_monitoring_capacity_snapshot', + 'get_monitoring_capacity_snapshot', + 'set_monitoring_safe_input_budget_snapshot', + 'get_monitoring_safe_input_budget_snapshot', 'set_agent_monitoring_context', 'get_agent_monitoring_context', 'agent_monitoring_context', diff --git a/sdk/nexent/monitor/monitoring.py b/sdk/nexent/monitor/monitoring.py index ebe442901..b3bef9cd0 100644 --- a/sdk/nexent/monitor/monitoring.py +++ b/sdk/nexent/monitor/monitoring.py @@ -72,6 +72,10 @@ # display_name carried from model instance to client-level monitoring wrapper _monitoring_display_name: ContextVar[Optional[str]] = ContextVar( "_monitoring_display_name", default=None) +_monitoring_capacity_snapshot: ContextVar[Optional[Dict[str, Any]]] = ContextVar( + "_monitoring_capacity_snapshot", default=None) +_monitoring_safe_input_budget_snapshot: ContextVar[Optional[Dict[str, Any]]] = ContextVar( + "_monitoring_safe_input_budget_snapshot", default=None) def set_monitoring_context( @@ -111,6 +115,26 @@ def get_monitoring_context() -> Dict[str, Any]: } +def set_monitoring_capacity_snapshot(snapshot: Optional[Dict[str, Any]]) -> None: + """Bind resolved model capacity metadata for the current request scope.""" + _monitoring_capacity_snapshot.set(snapshot) + + 
+def get_monitoring_capacity_snapshot() -> Optional[Dict[str, Any]]: + """Return the resolved capacity metadata bound to the current request.""" + return _monitoring_capacity_snapshot.get() + + +def set_monitoring_safe_input_budget_snapshot(snapshot: Optional[Dict[str, Any]]) -> None: + """Bind resolved W2 safe-input budget metadata for the current request.""" + _monitoring_safe_input_budget_snapshot.set(snapshot) + + +def get_monitoring_safe_input_budget_snapshot() -> Optional[Dict[str, Any]]: + """Return the resolved W2 safe-input budget metadata bound to the current request.""" + return _monitoring_safe_input_budget_snapshot.get() + + F = TypeVar('F', bound=Callable[..., Any]) DEFAULT_OTLP_ENDPOINT = "http://localhost:4318" @@ -1901,6 +1925,121 @@ def _detect_model_type(model_instance: Any) -> str: return "llm" +_CAPACITY_MONITORING_FIELDS = ( + "context_window_tokens", + "default_output_reserve_tokens", + "capability_profile_version", + "capacity_source", + "requested_output_tokens", + "provider_input_limit_tokens", + "tokenizer_family", + "counting_mode", + "unknown_capabilities", + "capacity_fingerprint", +) + + +def _dominant_capacity_source(field_sources: Any) -> Optional[str]: + if not isinstance(field_sources, dict) or not field_sources: + return None + values = [value for value in field_sources.values() if value] + if not values: + return None + for preferred in ("operator", "profile", "provider_candidate", "legacy", "unknown"): + if preferred in values: + return preferred + return str(values[0]) + + +def _normalize_capacity_snapshot(snapshot: Any) -> Dict[str, Any]: + if snapshot is None: + return {} + if hasattr(snapshot, "model_dump"): + snapshot = snapshot.model_dump() + if not isinstance(snapshot, dict): + return {} + + normalized = { + "context_window_tokens": snapshot.get("context_window_tokens"), + "default_output_reserve_tokens": snapshot.get("default_output_reserve_tokens"), + "capability_profile_version": 
snapshot.get("capability_profile_version"), + "capacity_source": snapshot.get("capacity_source") + or _dominant_capacity_source(snapshot.get("field_sources")), + "requested_output_tokens": snapshot.get("requested_output_tokens"), + "provider_input_limit_tokens": snapshot.get("provider_input_limit_tokens"), + "tokenizer_family": snapshot.get("tokenizer_family"), + "counting_mode": snapshot.get("counting_mode"), + "unknown_capabilities": snapshot.get("unknown_capabilities"), + "capacity_fingerprint": snapshot.get("capacity_fingerprint") + or snapshot.get("fingerprint"), + } + return { + key: value + for key, value in normalized.items() + if key in _CAPACITY_MONITORING_FIELDS and value is not None + } + + +def _enrich_record_with_capacity_snapshot(record: Dict[str, Any]) -> None: + capacity_fields = _normalize_capacity_snapshot(get_monitoring_capacity_snapshot()) + if capacity_fields: + record.update(capacity_fields) + + +_BUDGET_MONITORING_FIELDS = frozenset( + { + "budget_fingerprint", + "budget_w1_fingerprint", + "budget_requested_output_tokens", + "budget_output_reserve_source", + "budget_provider_input_limit_tokens", + "budget_uncertainty_reserve_tokens", + "budget_uncertainty_reserve_basis", + "budget_soft_limit_ratio", + "budget_soft_input_budget_tokens", + "budget_hard_input_budget_tokens", + "budget_warnings", + } +) + + +def _normalize_safe_input_budget_snapshot(snapshot: Any) -> Dict[str, Any]: + if snapshot is None: + return {} + if hasattr(snapshot, "model_dump"): + snapshot = snapshot.model_dump() + if not isinstance(snapshot, dict): + return {} + + normalized = { + "budget_fingerprint": snapshot.get("fingerprint") + or snapshot.get("budget_fingerprint"), + "budget_w1_fingerprint": snapshot.get("w1_fingerprint"), + "budget_requested_output_tokens": snapshot.get("requested_output_tokens"), + "budget_output_reserve_source": snapshot.get("output_reserve_source"), + "budget_provider_input_limit_tokens": snapshot.get("provider_input_limit_tokens"), + 
"budget_uncertainty_reserve_tokens": snapshot.get("uncertainty_reserve_tokens"), + "budget_uncertainty_reserve_basis": snapshot.get("uncertainty_reserve_basis"), + "budget_soft_limit_ratio": snapshot.get("soft_limit_ratio"), + "budget_soft_input_budget_tokens": snapshot.get("soft_input_budget_tokens"), + "budget_hard_input_budget_tokens": snapshot.get("hard_input_budget_tokens"), + "budget_warnings": snapshot.get("warnings"), + } + return { + key: value + for key, value in normalized.items() + if key in _BUDGET_MONITORING_FIELDS and value is not None + } + + +def _enrich_record_with_safe_input_budget_snapshot(record: Dict[str, Any]) -> None: + budget_fields = _normalize_safe_input_budget_snapshot( + get_monitoring_safe_input_budget_snapshot() + ) + if budget_fields: + record.update(budget_fields) + + def record_model_call( model_type: str, model_name: str, @@ -1983,6 +2122,9 @@ def __exit__(self, exc_type, exc_val, exc_tb): if self.display_name: record["display_name"] = self.display_name + _enrich_record_with_capacity_snapshot(record) + _enrich_record_with_safe_input_budget_snapshot(record) + buffer = get_monitoring_buffer() if buffer and buffer.is_enabled: buffer.add_record(record) @@ -2211,6 +2353,9 @@ def _enqueue_client_monitoring_record( if display_name: record["display_name"] = display_name + _enrich_record_with_capacity_snapshot(record) + _enrich_record_with_safe_input_budget_snapshot(record) + buffer.add_record(record) except Exception: pass @@ -2296,6 +2441,9 @@ def _enrich_record_with_context(record, tracker, kwargs): if display_name: record["display_name"] = display_name + _enrich_record_with_capacity_snapshot(record) + _enrich_record_with_safe_input_budget_snapshot(record) + return tenant_id @@ -2537,6 +2685,10 @@ async def my_function(): 'is_opentelemetry_available', 'set_monitoring_context', 'get_monitoring_context', + 'set_monitoring_capacity_snapshot', + 'get_monitoring_capacity_snapshot', + 'set_monitoring_safe_input_budget_snapshot', + 
'get_monitoring_safe_input_budget_snapshot', 'set_agent_monitoring_context', 'get_agent_monitoring_context', 'agent_monitoring_context', diff --git a/test/backend/agents/test_create_agent_info.py b/test/backend/agents/test_create_agent_info.py index 6d7fef775..2aa6f14d3 100644 --- a/test/backend/agents/test_create_agent_info.py +++ b/test/backend/agents/test_create_agent_info.py @@ -63,6 +63,10 @@ class MockToolParamsRequest(BaseModel): consts_model_module.AgentToolParamsRequest = MockAgentToolParamsRequest consts_model_module.ToolParamsRequest = MockToolParamsRequest sys.modules["consts.model"] = consts_model_module +sys.modules["consts.capability_profiles"] = types.ModuleType( + "consts.capability_profiles" +) +sys.modules["consts.capability_profiles"].CATALOG = {} # Mock consts.exceptions module with ValidationError consts_exceptions_module = types.ModuleType("consts.exceptions") @@ -77,6 +81,11 @@ class MockToolParamsRequest(BaseModel): if consts_module: setattr(consts_module, "model", consts_model_module) setattr(consts_module, "exceptions", consts_exceptions_module) + setattr( + consts_module, + "capability_profiles", + sys.modules["consts.capability_profiles"], + ) # Also add model to consts module attributes (with AgentToolParamsRequest and ToolParamsRequest) consts_module = sys.modules.get("consts") @@ -249,6 +258,88 @@ def model_validate(cls, value): sys.modules['nexent.core'] = _create_stub_module("nexent.core") sys.modules['nexent.core.agents'] = _create_stub_module("nexent.core.agents") sys.modules['nexent.core.utils'] = _create_stub_module("nexent.core.utils") +sys.modules['nexent.core.models'] = _create_stub_module("nexent.core.models") + + +class MockProviderCapabilityUnknown(Exception): + pass + + +class MockResolverError(Exception): + pass + + +class MockModelCapacitySnapshot: + def __init__(self, **kwargs): + self.provider = kwargs.get("provider", "test") + self.model_name = kwargs.get("model_name", "test-model") + self.context_window_tokens = 
kwargs.get("context_window_tokens", 32768) + self.default_output_reserve_tokens = kwargs.get( + "default_output_reserve_tokens", + 4096, + ) + self.capability_profile_version = kwargs.get("capability_profile_version") + self.field_sources = kwargs.get("field_sources", {}) + self.requested_output_tokens = kwargs.get("requested_output_tokens") + self.provider_input_limit_tokens = kwargs.get( + "provider_input_limit_tokens", + 28672, + ) + self.tokenizer_family = kwargs.get("tokenizer_family") + self.counting_mode = kwargs.get("counting_mode", "estimated") + self.unknown_capabilities = kwargs.get("unknown_capabilities", []) + self.fingerprint = kwargs.get("fingerprint", "test-fingerprint") + + def model_dump(self): + return self.__dict__.copy() + + +class MockRequestBudgetOverrides: + def __init__(self, requested_output_tokens=None): + self.requested_output_tokens = requested_output_tokens + + +class MockSafeInputBudgetSnapshot: + def __init__(self, capacity_snapshot, requested_output_tokens=None): + self.model_name = capacity_snapshot.model_name + self.requested_output_tokens = requested_output_tokens or 4096 + self.soft_input_budget_tokens = 24576 + self.hard_input_budget_tokens = 28672 + self.fingerprint = "safe-budget-fingerprint" + self.warnings = [] + + def model_dump(self): + return self.__dict__.copy() + + +class MockSafeInputBudgetCalculator: + def calculate_safe_input_budget( + self, + capacity_snapshot, + reserve_policy=None, + request_overrides=None, + requested_output_tokens=None, + output_reserve_source="model_default", + ): + override_tokens = getattr(request_overrides, "requested_output_tokens", None) + return MockSafeInputBudgetSnapshot( + capacity_snapshot, + requested_output_tokens=override_tokens or requested_output_tokens, + ) + + +sys.modules['nexent.core.models.capacity_resolver'] = _create_stub_module( + "nexent.core.models.capacity_resolver", + ModelCapacitySnapshot=MockModelCapacitySnapshot, + 
ProviderCapabilityUnknown=MockProviderCapabilityUnknown, + ResolverError=MockResolverError, + resolve_capacity=MagicMock(return_value=MockModelCapacitySnapshot()), +) +sys.modules['nexent.core.models.capacity_budget'] = _create_stub_module( + "nexent.core.models.capacity_budget", + RequestBudgetOverrides=MockRequestBudgetOverrides, + SafeInputBudgetCalculator=MockSafeInputBudgetCalculator, +) # Create mock classes that might be imported mock_agent_config = MagicMock() @@ -1676,12 +1767,15 @@ async def test_create_agent_config_basic(self): prompt_templates={"system_prompt": "populated_system_prompt"}, tools=ANY, max_steps=5, + requested_output_tokens=None, model_name="test_model", provide_run_summary=True, managed_agents=[], external_a2a_agents=[], context_manager_config=ANY, context_components=ANY, + capacity_snapshot=ANY, + safe_input_budget_snapshot=ANY, verification_config=ANY ) @@ -1748,12 +1842,15 @@ async def test_create_agent_config_with_sub_agents(self): "system_prompt": "populated_system_prompt"}, tools=ANY, max_steps=5, + requested_output_tokens=None, model_name="test_model", provide_run_summary=True, managed_agents=[mock_sub_agent_config], external_a2a_agents=[], context_manager_config=ANY, context_components=ANY, + capacity_snapshot=ANY, + safe_input_budget_snapshot=ANY, verification_config=ANY ) @@ -2007,12 +2104,15 @@ async def test_create_agent_config_model_id_none(self): prompt_templates={"system_prompt": "populated_system_prompt"}, tools=ANY, max_steps=5, + requested_output_tokens=None, model_name="main_model", provide_run_summary=True, managed_agents=[], external_a2a_agents=[], context_manager_config=ANY, context_components=ANY, + capacity_snapshot=None, + safe_input_budget_snapshot=None, verification_config=ANY ) @@ -3144,7 +3244,9 @@ async def test_create_agent_run_info_success(self): "transport": "streamable-http" }], history=[], - stop_event="stop_event" + stop_event="stop_event", + capacity_snapshot=None, + safe_input_budget_snapshot=None ) # 
Verify that other functions were called correctly diff --git a/test/backend/app/test_model_managment_app.py b/test/backend/app/test_model_managment_app.py index ade705667..cbdc04c15 100644 --- a/test/backend/app/test_model_managment_app.py +++ b/test/backend/app/test_model_managment_app.py @@ -82,6 +82,194 @@ def sample_model_data(): } +@pytest.mark.asyncio +async def test_suggest_capacity_success(client, auth_header, user_credentials, mocker): + """Test standalone capacity suggestion endpoint.""" + from backend.consts.model import CapacitySuggestionFields, ModelCapacitySuggestionResponse + + mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials) + mock_suggest = mocker.patch( + 'backend.apps.model_managment_app._suggest_capacity_for_request', + return_value=ModelCapacitySuggestionResponse( + suggestions=CapacitySuggestionFields( + context_window_tokens=128000, + max_output_tokens=16384, + default_output_reserve_tokens=4096, + tokenizer_family="o200k_base", + ), + match_kind="catalog_exact", + match_confidence="high", + match_explanation="Matched approved catalog profile openai/gpt-4o@1", + suggested_provider="openai", + canonical_model_name="gpt-4o", + capability_profile_version="openai/gpt-4o@1", + capacity_source_on_accept="operator", + ) + ) + + response = client.post( + "/model/suggest-capacity", + json={ + "model_name": "gpt-4o", + "base_url": "https://api.openai.com/v1", + "model_type": "llm", + }, + headers=auth_header, + ) + + assert response.status_code == HTTPStatus.OK + body = response.json() + # Response uses the shared {message, data} envelope so the frontend + # service layer can unwrap /model/* responses uniformly. See + # suggest_model_capacity for the rationale. 
+ assert body["message"] == "Successfully suggested model capacity" + data = body["data"] + assert data["match_kind"] == "catalog_exact" + assert data["suggestions"]["context_window_tokens"] == 128000 + assert data["suggested_provider"] == "openai" + mock_suggest.assert_called_once() + + +@pytest.mark.asyncio +async def test_suggest_capacity_real_serialization_uses_envelope(client, auth_header, user_credentials, mocker): + """End-to-end serialization test: hit /model/suggest-capacity without + mocking the catalog matcher, so the response goes through the real + Pydantic serializer and JSONResponse envelope. Asserts the {message, + data} envelope shape and the nested catalog match. This is the safety + net for wire-format drift -- the headline W11 V1 bug shipped past + every existing unit test because nothing exercised the real + backend-to-wire format. + """ + mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials) + + response = client.post( + "/model/suggest-capacity", + json={ + "model_name": "gpt-4o", + "base_url": "https://api.openai.com/v1", + "model_type": "llm", + }, + headers=auth_header, + ) + + assert response.status_code == HTTPStatus.OK + body = response.json() + # Envelope must be present at the top level. This is the contract the + # frontend modelService reads (`result.data`); breaking it makes both + # the suggestion alert and the coverage banner dead end-to-end without + # any unit test catching it. 
+ assert isinstance(body, dict) + assert set(body.keys()) >= {"message", "data"} + assert body["message"] == "Successfully suggested model capacity" + + data = body["data"] + assert data["match_kind"] == "catalog_exact" + assert data["match_confidence"] == "high" + assert data["suggested_provider"] == "openai" + assert data["canonical_model_name"] == "gpt-4o" + assert data["capability_profile_version"] == "openai/gpt-4o@1" + assert data["capacity_source_on_accept"] == "operator" + # Nested capacity dict is also envelope-free at this level: it sits + # directly under data.suggestions, mirroring the snake_case wire format + # that mapCapacitySuggestionFromApi expects. + assert data["suggestions"]["context_window_tokens"] > 0 + assert data["suggestions"]["max_output_tokens"] > 0 + + +@pytest.mark.asyncio +async def test_capacity_coverage_real_serialization_uses_envelope(client, auth_header, user_credentials, mocker): + """End-to-end serialization test for /model/capacity-coverage. Mocks the + service layer but lets the route serialize a real dict through + JSONResponse so the envelope contract is enforced at the wire boundary. 
+ """ + mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials) + mocker.patch( + 'backend.apps.model_managment_app.get_capacity_coverage', + return_value={ + "total_llm_vlm": 3, + "bare_count": 1, + "bare_models": [ + { + "model_id": 99, + "model_name": "glm-5", + "model_factory": "OpenAI-API-Compatible", + "model_type": "llm", + "max_tokens": 131072, + "suggestion_available": False, + } + ], + }, + ) + + response = client.get("/model/capacity-coverage", headers=auth_header) + + assert response.status_code == HTTPStatus.OK + body = response.json() + assert isinstance(body, dict) + assert set(body.keys()) >= {"message", "data"} + assert body["message"] == "Successfully retrieved model capacity coverage" + + data = body["data"] + assert data["total_llm_vlm"] == 3 + assert data["bare_count"] == 1 + assert data["bare_models"][0]["model_id"] == 99 + assert data["bare_models"][0]["suggestion_available"] is False + + +@pytest.mark.asyncio +async def test_suggest_capacity_bad_request(client, auth_header, user_credentials, mocker): + """Test standalone capacity suggestion endpoint maps invalid input to 400.""" + mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials) + mocker.patch( + 'backend.apps.model_managment_app._suggest_capacity_for_request', + side_effect=ValueError("model_name is required"), + ) + + response = client.post( + "/model/suggest-capacity", + json={"model_name": "gpt-4o"}, + headers=auth_header, + ) + + assert response.status_code == HTTPStatus.BAD_REQUEST + assert "model_name is required" in response.json()["detail"] + + +@pytest.mark.asyncio +async def test_capacity_coverage_success(client, auth_header, user_credentials, mocker): + """Test capacity coverage endpoint uses current tenant.""" + mocker.patch('backend.apps.model_managment_app.get_current_user_id', return_value=user_credentials) + mock_coverage = mocker.patch( + 
'backend.apps.model_managment_app.get_capacity_coverage', + return_value={ + "total_llm_vlm": 2, + "bare_count": 1, + "bare_models": [ + { + "model_id": 11, + "model_name": "gpt-4o", + "model_factory": "openai", + "model_type": "llm", + "max_tokens": 16384, + "suggestion_available": True, + } + ], + }, + ) + + response = client.get("/model/capacity-coverage", headers=auth_header) + + assert response.status_code == HTTPStatus.OK + body = response.json() + assert body["message"] == "Successfully retrieved model capacity coverage" + data = body["data"] + assert data["total_llm_vlm"] == 2 + assert data["bare_count"] == 1 + assert data["bare_models"][0]["max_tokens"] == 16384 + assert data["bare_models"][0]["suggestion_available"] is True + mock_coverage.assert_called_once_with(user_credentials[1]) + + # Tests for /model/create endpoint @pytest.mark.asyncio async def test_create_model_success(client, auth_header, user_credentials, sample_model_data, mocker): @@ -443,6 +631,13 @@ async def test_verify_model_config_success(client, auth_header, sample_model_dat 'backend.apps.model_managment_app.verify_model_config_connectivity', return_value={"connectivity": True, "model_name": "gpt-4"} ) + mock_suggest = mocker.patch( + 'backend.apps.model_managment_app._capacity_suggestion_for_model_request', + return_value={ + "suggestions": {"context_window_tokens": 128000}, + "match_kind": "catalog_exact", + }, + ) response = client.post( "/model/temporary_healthcheck", json=sample_model_data) @@ -451,9 +646,11 @@ async def test_verify_model_config_success(client, auth_header, sample_model_dat data = response.json() assert data["message"] == "Successfully verified model connectivity" assert data["data"]["connectivity"] is True + assert data["data"]["capacity_suggestion"]["match_kind"] == "catalog_exact" # Success case should not have error field in response assert "error" not in data["data"] mock_verify.assert_called_once() + mock_suggest.assert_called_once() @pytest.mark.asyncio @@ 
-467,6 +664,7 @@ async def test_verify_model_config_failure_with_error(client, auth_header, sampl "error": "Failed to connect to model 'gpt-4' at https://api.openai.com. Please verify the URL, API key, and network connection." } ) + mock_suggest = mocker.patch('backend.apps.model_managment_app._capacity_suggestion_for_model_request') response = client.post( "/model/temporary_healthcheck", json=sample_model_data) @@ -477,9 +675,11 @@ async def test_verify_model_config_failure_with_error(client, auth_header, sampl assert data["data"]["connectivity"] is False # Failure case should have error field with descriptive message assert "error" in data["data"] + assert data["data"]["capacity_suggestion"] is None assert "Failed to connect to model" in data["data"]["error"] assert "Please verify the URL, API key, and network connection" in data["data"]["error"] mock_verify.assert_called_once() + mock_suggest.assert_not_called() @pytest.mark.asyncio diff --git a/test/backend/database/test_agent_db.py b/test/backend/database/test_agent_db.py index 77a1d82a9..ee6605f89 100644 --- a/test/backend/database/test_agent_db.py +++ b/test/backend/database/test_agent_db.py @@ -132,6 +132,7 @@ def __init__(self): self.group_ids = None self.is_new = True self.enable_context_manager = False + self.requested_output_tokens = None self.verification_config = None self.greeting_message = None self.example_questions = None @@ -436,6 +437,36 @@ def test_update_agent_skips_none_and_converts_group_ids(monkeypatch, mock_sessio agent_db_module.convert_list_to_string.assert_called_once_with([1, 2]) assert mock_agent.updated_by == "user1" +def test_update_agent_allows_explicit_requested_output_tokens_null(monkeypatch, mock_session): + """Explicit requested_output_tokens=None should clear the W2 agent override.""" + session, query = mock_session + mock_agent = MockAgent() + mock_agent.requested_output_tokens = 2048 + + mock_first = MagicMock() + mock_first.return_value = mock_agent + mock_filter = 
MagicMock() + mock_filter.first = mock_first + query.filter.return_value = mock_filter + + mock_ctx = MagicMock() + mock_ctx.__enter__.return_value = session + mock_ctx.__exit__.return_value = None + monkeypatch.setattr("backend.database.agent_db.get_db_session", lambda: mock_ctx) + monkeypatch.setattr("backend.database.agent_db.filter_property", lambda data, model: data) + + class AgentInfoUpdate: + def __init__(self): + self.requested_output_tokens = None + self.model_fields_set = {"requested_output_tokens"} + + agent_info = AgentInfoUpdate() + + update_agent(1, agent_info, "user1") + + assert mock_agent.requested_output_tokens is None + assert mock_agent.updated_by == "user1" + def test_update_agent_not_found(monkeypatch, mock_session): """测试更新不存在的agent""" session, query = mock_session diff --git a/test/backend/services/providers/test_dashscope_provider.py b/test/backend/services/providers/test_dashscope_provider.py index 5c6267040..fd7a24ff0 100644 --- a/test/backend/services/providers/test_dashscope_provider.py +++ b/test/backend/services/providers/test_dashscope_provider.py @@ -89,6 +89,44 @@ async def test_get_models_llm_success(self, mocker: MockFixture): assert result[0]["model_type"] == "llm" assert result[0]["model_tag"] == "chat" assert result[0]["max_tokens"] == 4096 + assert "capacity_source" not in result[0] + + @pytest.mark.asyncio + async def test_get_models_llm_surfaces_capacity_hints(self, mocker: MockFixture): + """Provider token metadata is returned as advisory capacity hints.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "output": { + "models": [ + { + "model": "qwen-plus", + "description": "Advanced text generation", + "inference_metadata": { + "request_modality": ["Text"], + "response_modality": ["Text"], + "context_length": 131072, + "max_output_tokens": "8192", + "tokenizer_family": "qwen", + } + } + ] + } + } + mock_response.raise_for_status = MagicMock() + + 
self._setup_mock_client(mocker, mock_response) + + provider = DashScopeModelProvider() + result = await provider.get_models({ + "model_type": "llm", + "api_key": "test-api-key", + }) + + assert result[0]["context_window_tokens"] == 131072 + assert result[0]["max_output_tokens"] == 8192 + assert result[0]["tokenizer_family"] == "qwen" + assert result[0]["capacity_source"] == "provider_candidate" @pytest.mark.asyncio async def test_get_models_embedding_success(self, mocker: MockFixture): diff --git a/test/backend/services/providers/test_modelengine_provider.py b/test/backend/services/providers/test_modelengine_provider.py index 54a3f2957..b5595df3a 100644 --- a/test/backend/services/providers/test_modelengine_provider.py +++ b/test/backend/services/providers/test_modelengine_provider.py @@ -69,6 +69,56 @@ async def test_get_models_success_with_all_types(self, mocker: MockFixture): assert result[0]["model_type"] == "llm" assert result[0]["model_tag"] == "chat" assert result[0]["max_tokens"] > 0 # LLM type should have max_tokens + assert "capacity_source" not in result[0] + + @pytest.mark.asyncio + async def test_get_models_surfaces_capacity_hints(self, mocker: MockFixture): + """Provider token metadata is returned as advisory capacity hints.""" + mock_response_data = { + "data": [ + { + "id": "llm-model-1", + "type": "chat", + "context_window_tokens": 65536, + "max_input_tokens": "60000", + "max_output_tokens": 4096, + "tokenizer_type": "deepseek", + } + ] + } + + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.json = AsyncMock(return_value=mock_response_data) + + mock_get_cm = MagicMock() + mock_get_cm.__aenter__ = AsyncMock(return_value=mock_response) + mock_get_cm.__aexit__ = AsyncMock(return_value=None) + + mock_session_instance = MagicMock() + mock_session_instance.get = MagicMock(return_value=mock_get_cm) + + mock_session_cm = MagicMock() + mock_session_cm.__aenter__ = AsyncMock(return_value=mock_session_instance) + 
mock_session_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.modelengine_provider.aiohttp.ClientSession", + return_value=mock_session_cm + ) + + provider = ModelEngineProvider() + result = await provider.get_models({ + "model_type": "llm", + "base_url": "https://test.example.com", + "api_key": "test-api-key", + }) + + assert result[0]["context_window_tokens"] == 65536 + assert result[0]["max_input_tokens"] == 60000 + assert result[0]["max_output_tokens"] == 4096 + assert result[0]["tokenizer_family"] == "deepseek" + assert result[0]["capacity_source"] == "provider_candidate" @pytest.mark.asyncio async def test_get_models_with_type_filter(self, mocker: MockFixture): diff --git a/test/backend/services/providers/test_silicon_provider.py b/test/backend/services/providers/test_silicon_provider.py index c9fd2b491..570a217d2 100644 --- a/test/backend/services/providers/test_silicon_provider.py +++ b/test/backend/services/providers/test_silicon_provider.py @@ -58,6 +58,48 @@ async def test_get_models_llm_success(self, mocker: MockFixture): assert result[0]["id"] == "gpt-4" assert result[0]["model_type"] == "llm" assert result[0]["model_tag"] == "chat" + assert "capacity_source" not in result[0] + + @pytest.mark.asyncio + async def test_get_models_llm_surfaces_capacity_hints(self, mocker: MockFixture): + """Provider token metadata is returned as advisory capacity hints.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": [ + { + "id": "Qwen/Qwen3-Coder-480B-A35B-Instruct", + "name": "Qwen3 Coder", + "context_length": "262144", + "max_output_tokens": 8192, + "tokenizer": "qwen", + }, + ] + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + 
"backend.services.providers.silicon_provider.httpx.AsyncClient", + return_value=mock_cm + ) + + provider = SiliconModelProvider() + result = await provider.get_models({ + "model_type": "llm", + "api_key": "test-api-key", + }) + + assert result[0]["context_window_tokens"] == 262144 + assert result[0]["max_output_tokens"] == 8192 + assert result[0]["tokenizer_family"] == "qwen" + assert result[0]["capacity_source"] == "provider_candidate" @pytest.mark.asyncio async def test_get_models_vlm_success(self, mocker: MockFixture): diff --git a/test/backend/services/providers/test_tokenpony_provider.py b/test/backend/services/providers/test_tokenpony_provider.py index 58e514dbb..4f7021d0a 100644 --- a/test/backend/services/providers/test_tokenpony_provider.py +++ b/test/backend/services/providers/test_tokenpony_provider.py @@ -69,6 +69,49 @@ async def test_get_models_llm_success(self, mocker: MockFixture): assert result[0]["model_type"] == "llm" assert result[0]["model_tag"] == "chat" assert result[0]["max_tokens"] == 4096 + assert "capacity_source" not in result[0] + + @pytest.mark.asyncio + async def test_get_models_llm_surfaces_capacity_hints(self, mocker: MockFixture): + """Provider token metadata is returned as advisory capacity hints.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "data": [ + { + "id": "claude-3-opus", + "object": "model", + "owned_by": "openai", + "context_window": 128000, + "max_completion_tokens": "16384", + "tokenizer_family": "o200k_base", + } + ] + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_client) + mock_cm.__aexit__ = AsyncMock(return_value=None) + + mocker.patch( + "backend.services.providers.tokenpony_provider.httpx.AsyncClient", + return_value=mock_cm + ) + + provider = TokenPonyModelProvider() + result = await 
provider.get_models({ + "model_type": "llm", + "api_key": "test-api-key", + }) + + assert result[0]["context_window_tokens"] == 128000 + assert result[0]["max_output_tokens"] == 16384 + assert result[0]["tokenizer_family"] == "o200k_base" + assert result[0]["capacity_source"] == "provider_candidate" @pytest.mark.asyncio async def test_get_models_embedding_success(self, mocker: MockFixture): @@ -828,4 +871,3 @@ async def test_get_models_llm_has_max_tokens(self, mocker: MockFixture): assert len(result) == 1 assert result[0]["max_tokens"] == 4096 - diff --git a/test/backend/services/test_agent_service.py b/test/backend/services/test_agent_service.py index 6cd7b5da4..468205286 100644 --- a/test/backend/services/test_agent_service.py +++ b/test/backend/services/test_agent_service.py @@ -632,6 +632,10 @@ async def test_get_creating_sub_agent_info_impl_success(mock_get_current_user_in result = await get_creating_sub_agent_info_impl(authorization="Bearer token") # Assert + # W2 added `requested_output_tokens` to the response shape at + # agent_service.py:1112. The mocked `search_agent_info` payload does not + # include the key, so `agent_info.get("requested_output_tokens")` is None + # in the returned dict. 
expected_result = { "agent_id": 456, "name": "agent_name", @@ -641,6 +645,7 @@ async def test_get_creating_sub_agent_info_impl_success(mock_get_current_user_in "model_name": "test_model", "model_id": None, "max_steps": 5, + "requested_output_tokens": None, "business_description": "Sub agent", "duty_prompt": "Sub duty prompt", "constraint_prompt": "Sub constraint prompt", @@ -3727,6 +3732,7 @@ def mock_agent_request(): query="test query", history=[], minio_files=[], + requested_output_tokens=4096, is_debug=False, ) @@ -3766,7 +3772,21 @@ async def test_prepare_agent_run( assert memory_context == mock_memory_context mock_build_memory_context.assert_called_once_with( "test_user", "test_tenant", 1, skip_query=False) - mock_create_run_info.assert_called_once() + mock_create_run_info.assert_called_once_with( + agent_id=1, + minio_files=[], + query="test query", + history=[], + tenant_id="test_tenant", + user_id="test_user", + language="zh", + allow_memory_search=True, + is_debug=False, + override_version_no=None, + override_model_id=None, + requested_output_tokens=4096, + tool_params=None, + ) mock_agent_run_manager.register_agent_run.assert_called_once_with( 123, mock_run_info, "test_user") @@ -9204,6 +9224,24 @@ def test_get_agent_call_relationship_impl_deep_recursion(mock_query_sub, mock_se assert "sub_agents" in result +# W2 introduced `_validate_requested_output_tokens_for_agent` on the +# update/import path. The existing update_agent_info_impl_* / import_agent_* +# tests build their request via `MagicMock(spec=AgentInfoRequest)` and never +# wire `.requested_output_tokens = None`, so the validator either fails the +# `> max_output_tokens` comparison on two MagicMocks or AttributeErrors on the +# field. None of these tests are about output-reservation behavior, so we +# autouse-stub the validator for this section. Tests that need to exercise +# the validator can still `mock.patch` it locally; module-level autouse loses +# to per-test patches. 
+@pytest.fixture(autouse=True) +def _stub_requested_output_tokens_validator(): + with patch( + "backend.services.agent_service._validate_requested_output_tokens_for_agent", + return_value=None, + ): + yield + + # Tests for update_agent_info_impl skill handling exception @patch("backend.services.agent_service.skill_db.create_or_update_skill_by_skill_info") @patch("backend.services.agent_service.skill_db.query_skill_instances_by_agent_id") @@ -10037,7 +10075,18 @@ async def test_import_agent_by_agent_id_publish_version_error( mock_agent_info.business_logic_model_name = None mock_agent_info.prompt_template_id = None mock_agent_info.prompt_template_name = None - + # W2 added `requested_output_tokens` to ExportAndImportAgentInfo and + # import_agent_by_agent_id reads it directly at agent_service.py:1874. + # MagicMock(spec=...) on a Pydantic v2 model does not always expose + # field-level attributes through dir(), so the access AttributeErrors + # unless we set it explicitly. + mock_agent_info.requested_output_tokens = None + + # Configure the three patched mocks so the flow reaches the publish branch: + # - query_all_tools() must return an iterable (empty list -> no tool loop) + # - create_agent(...) must return a dict so `new_agent["agent_id"]` is an int + # - publish_version_impl(...) 
must raise so the under-test exception handler + # at agent_service.py:1899-1901 actually fires mock_query_tools.return_value = [] mock_create.return_value = {"agent_id": 100} mock_publish.side_effect = Exception("Publish error") diff --git a/test/backend/services/test_model_capacity_suggestion_service.py b/test/backend/services/test_model_capacity_suggestion_service.py new file mode 100644 index 000000000..fc6ffdc67 --- /dev/null +++ b/test/backend/services/test_model_capacity_suggestion_service.py @@ -0,0 +1,181 @@ +import os +import sys + +import pytest + +backend_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../backend")) +if backend_dir not in sys.path: + sys.path.append(backend_dir) + +from services.model_capacity_suggestion_service import ( + CapacitySuggestionMatchKind, + pick_provider, + pick_provider_from_base_url, + suggest_capacity, +) + + +class Profile: + def __init__( + self, + context_window_tokens, + max_output_tokens, + capability_profile_version, + max_input_tokens=None, + default_output_reserve_tokens=4096, + tokenizer_family="test-tokenizer", + ): + self.context_window_tokens = context_window_tokens + self.max_input_tokens = max_input_tokens + self.max_output_tokens = max_output_tokens + self.default_output_reserve_tokens = default_output_reserve_tokens + self.tokenizer_family = tokenizer_family + self.capability_profile_version = capability_profile_version + + +CATALOG = { + ("openai", "gpt-4o"): Profile(128_000, 16_384, "openai/gpt-4o@1"), + ("dashscope", "qwen-plus"): Profile(131_072, 16_384, "dashscope/qwen-plus@1"), + ("other", "qwen-plus"): Profile(131_072, 16_384, "other/qwen-plus@1"), + ("silicon", "deepseek-ai/DeepSeek-V4-Flash"): Profile( + 1_000_000, + 384_000, + "silicon/deepseek-v4-flash@1", + ), + ("silicon", "Pro/moonshotai/Kimi-K2.6"): Profile( + 262_144, + 131_072, + "silicon/kimi-k2.6@1", + ), +} + + +def test_suggest_capacity_catalog_exact_from_base_url(): + result = suggest_capacity( + 
model_name="gpt-4o", + base_url="https://api.openai.com/v1", + model_type="llm", + catalog=CATALOG, + ) + + assert result.match_kind == CapacitySuggestionMatchKind.CATALOG_EXACT + assert result.suggested_provider == "openai" + assert result.canonical_model_name == "gpt-4o" + assert result.capability_profile_version == "openai/gpt-4o@1" + assert result.capacity_source_on_accept == "operator" + assert result.suggestions.context_window_tokens == 128_000 + assert result.suggestions.max_output_tokens == 16_384 + + +def test_suggest_capacity_catalog_exact_case_insensitive(): + result = suggest_capacity( + model_name="GPT-4o", + provider_hint="openai", + model_type="llm", + catalog=CATALOG, + ) + + assert result.match_kind == CapacitySuggestionMatchKind.CATALOG_EXACT + assert result.canonical_model_name == "gpt-4o" + + +def test_suggest_capacity_catalog_fuzzy_normalized_name(): + result = suggest_capacity( + model_name="Deepseek V4 Flash", + provider_hint="silicon", + model_type="llm", + catalog=CATALOG, + ) + + assert result.match_kind == CapacitySuggestionMatchKind.CATALOG_FUZZY + assert result.suggested_provider == "silicon" + assert result.canonical_model_name == "deepseek-ai/DeepSeek-V4-Flash" + assert result.capability_profile_version == "silicon/deepseek-v4-flash@1" + + +def test_suggest_capacity_catalog_fuzzy_unique_final_segment(): + result = suggest_capacity( + model_name="Kimi-K2.6", + provider_hint="silicon", + model_type="llm", + catalog=CATALOG, + ) + + assert result.match_kind == CapacitySuggestionMatchKind.CATALOG_FUZZY + assert result.canonical_model_name == "Pro/moonshotai/Kimi-K2.6" + + +def test_suggest_capacity_rejects_ambiguous_providerless_model(): + result = suggest_capacity( + model_name="qwen-plus", + base_url="http://localhost:8000/v1", + model_type="llm", + catalog=CATALOG, + ) + + assert result.match_kind == CapacitySuggestionMatchKind.NONE + assert result.suggestions is None + + +def test_suggest_capacity_flag_off_returns_none(): + result = 
suggest_capacity( + model_name="gpt-4o", + base_url="https://api.openai.com/v1", + model_type="llm", + catalog=CATALOG, + enabled=False, + ) + + assert result.match_kind == CapacitySuggestionMatchKind.NONE + assert result.suggestions is None + assert "disabled" in result.match_explanation + + +def test_suggest_capacity_unsupported_model_type_returns_none(): + result = suggest_capacity( + model_name="gpt-4o", + base_url="https://api.openai.com/v1", + model_type="embedding", + catalog=CATALOG, + ) + + assert result.match_kind == CapacitySuggestionMatchKind.NONE + assert result.suggestions is None + + +def test_suggest_capacity_empty_model_name_raises(): + with pytest.raises(ValueError, match="model_name is required"): + suggest_capacity(model_name="", base_url="https://api.openai.com/v1", catalog=CATALOG) + + +def test_pick_provider_prefers_hint_then_base_url_then_unique_catalog(): + assert pick_provider("dashscope", "https://api.openai.com/v1", "gpt-4o", CATALOG) == "dashscope" + assert pick_provider(None, "https://api.openai.com/v1", "gpt-4o", CATALOG) == "openai" + assert pick_provider(None, None, "Kimi-K2.6", CATALOG) == "silicon" + + +def test_pick_provider_from_base_url_uses_shared_host_map(): + assert pick_provider_from_base_url("https://dashscope.aliyuncs.com/compatible-mode/v1") == "dashscope" + assert pick_provider_from_base_url("https://api.siliconflow.cn/v1") == "silicon" + assert pick_provider_from_base_url("https://api.tokenpony.ai/v1") == "tokenpony" + assert pick_provider_from_base_url("http://localhost:8000/v1") is None + + +def test_pick_provider_from_base_url_recognises_extended_patterns(): + # Patterns added to mirror frontend PROVIDER_HINTS (modelConfig.ts). + assert pick_provider_from_base_url("https://api.deepseek.com/v1") == "deepseek" + assert pick_provider_from_base_url("https://api.jina.ai/v1") == "jina" + # Broader OpenAI pattern: Azure OpenAI hosted endpoints also resolve. 
+ assert pick_provider_from_base_url("https://myorg.openai.azure.com/v1") == "openai" + # Aliyun generic host without "dashscope" substring still resolves to + # dashscope so capacity lookup can hit the existing dashscope catalog. + assert pick_provider_from_base_url("https://bailian.aliyuncs.com/v1") == "dashscope" + # Full-URL substring matching: self-hosted reverse proxy with the + # provider name in the path is recognised (matches frontend behaviour). + assert pick_provider_from_base_url("https://corp.example.com/openai/v1") == "openai" + + +def test_pick_provider_from_base_url_dashscope_wins_over_aliyuncs(): + # Both substrings present; order in HOST_PROVIDER_PATTERNS makes + # dashscope win, which is the correct (more-specific) routing. + assert pick_provider_from_base_url("https://dashscope.aliyuncs.com/v1") == "dashscope" diff --git a/test/backend/services/test_model_management_service.py b/test/backend/services/test_model_management_service.py index 5bdcb4722..9ea88306a 100644 --- a/test/backend/services/test_model_management_service.py +++ b/test/backend/services/test_model_management_service.py @@ -108,6 +108,8 @@ def model_dump(self, *args, **kwargs): consts_const_mod.LOCALHOST_IP = "127.0.0.1" consts_const_mod.LOCALHOST_NAME = "localhost" consts_const_mod.DOCKER_INTERNAL_HOST = "host.docker.internal" +consts_const_mod.CAPACITY_SUGGESTION_ENABLED = True +consts_const_mod.CAPACITY_VISIBILITY_ENABLED = True consts_const_mod.DATA_PROCESS_SERVICE = "http://data-process" consts_const_mod.FILE_PREVIEW_SIZE_LIMIT = 100 * 1024 * 1024 consts_const_mod.MAX_CONCURRENT_UPLOADS = 5 @@ -1022,6 +1024,57 @@ async def test_update_single_model_for_tenant_success_single_model(): ) +async def test_update_single_model_for_tenant_mirrors_max_output_into_legacy_max_tokens(): + """LLM updates carrying max_output_tokens must mirror into the legacy + max_tokens column so the SDK's pre-W2 auto-fill cannot read a stale value + and trip CallerMaxTokensOverrideForbidden at the W2 
dispatch boundary. + """ + svc = import_svc() + + existing_models = [ + {"model_id": 1, "model_type": "llm", "display_name": "name", "max_tokens": 204800}, + ] + model_data = { + "model_id": 1, + "display_name": "name", + "max_output_tokens": 131072, + # No explicit max_tokens — caller relies on backend coercion. + } + + with mock.patch.object(svc, "get_models_by_display_name", return_value=existing_models), \ + mock.patch.object(svc, "update_model_record") as mock_update: + await svc.update_single_model_for_tenant("u1", "t1", "name", model_data) + + update_args = mock_update.call_args.args[1] + assert update_args["max_output_tokens"] == 131072 + assert update_args["max_tokens"] == 131072 + + +async def test_update_single_model_for_tenant_preserves_embedding_max_tokens(): + """Embedding rows must NOT have max_tokens mirrored from max_output_tokens — + max_tokens is repurposed as the vector dimension on those rows. + """ + svc = import_svc() + + existing_models = [ + {"model_id": 10, "model_type": "embedding", "display_name": "emb", "max_tokens": 4096}, + ] + # Defensive caller accidentally passes max_output_tokens on an embedding row. + model_data = { + "model_id": 10, + "display_name": "emb", + "max_output_tokens": 8192, + } + + with mock.patch.object(svc, "get_models_by_display_name", return_value=existing_models), \ + mock.patch.object(svc, "update_model_record") as mock_update: + await svc.update_single_model_for_tenant("u1", "t1", "emb", model_data) + + update_args = mock_update.call_args.args[1] + # Embedding rows skip the coercion, so legacy max_tokens stays untouched. 
+ assert "max_tokens" not in update_args + + async def test_update_single_model_for_tenant_conflict_new_display_name(): """Updating to a new conflicting display_name raises ValueError.""" svc = import_svc() @@ -1705,3 +1758,268 @@ async def test_create_model_for_tenant_embedding_with_api_key_sets_ssl_verify_tr assert mock_create.call_count == 1 create_args = mock_create.call_args[0][0] assert create_args["ssl_verify"] is True + + +@pytest.mark.asyncio +async def test_batch_create_models_for_tenant_update_branch_persists_operator_capacity(): + """Re-confirming a batch with operator-marked capacity updates W1/W2 columns. + + Regression test for the gap that left glm-5.x style rows with NULL + W2 columns: the batch_create update branch previously only checked + legacy max_tokens for changes, so a user who tweaked the top-level + batch defaults and re-confirmed could not push the new + context_window_tokens / max_output_tokens onto an existing row. + """ + svc = import_svc() + + existing_row = { + "model_id": 42, + "model_repo": "dashscope", + "model_name": "glm-5.2", + "max_tokens": 31920, + "context_window_tokens": None, + "max_output_tokens": None, + "capacity_source": None, + } + + batch_payload = { + "provider": "dashscope", + "type": "llm", + "models": [ + { + "id": "dashscope/glm-5.2", + "max_tokens": 31920, + "context_window_tokens": 200000, + "max_output_tokens": 31920, + "default_output_reserve_tokens": 4096, + "tokenizer_family": "qwen", + "capacity_source": "operator", + } + ], + "api_key": "dash-key", + } + + with mock.patch.object(svc, "get_models_by_tenant_factory_type", return_value=[existing_row]), \ + mock.patch.object(svc, "delete_model_record"), \ + mock.patch.object(svc, "split_repo_name", return_value=("dashscope", "glm-5.2")), \ + mock.patch.object(svc, "add_repo_to_name", return_value="dashscope/glm-5.2"), \ + mock.patch.object(svc, "update_model_record") as mock_update, \ + mock.patch.object(svc, "create_model_record"): + + await 
svc.batch_create_models_for_tenant("u1", "t1", batch_payload) + + mock_update.assert_called_once() + called_model_id, called_update_data, *_ = mock_update.call_args[0] + assert called_model_id == 42 + assert called_update_data["context_window_tokens"] == 200000 + assert called_update_data["max_output_tokens"] == 31920 + assert called_update_data["default_output_reserve_tokens"] == 4096 + assert called_update_data["tokenizer_family"] == "qwen" + assert called_update_data["capacity_source"] == "operator" + + +@pytest.mark.asyncio +async def test_batch_create_models_for_tenant_update_branch_skips_provider_candidate_capacity(): + """Provider-discovered hints must not auto-overwrite an existing row. + + Even when the catalog response contains rich inference_metadata, those + values stay tagged capacity_source="provider_candidate" until the + operator accepts them. Refreshing the provider list must not + silently rewrite a row's operator-set capacity (or its NULLs) with + catalog hints. + """ + svc = import_svc() + + existing_row = { + "model_id": 7, + "model_repo": "dashscope", + "model_name": "glm-5.1", + "max_tokens": 8192, + "context_window_tokens": None, + "max_output_tokens": None, + "capacity_source": None, + } + + batch_payload = { + "provider": "dashscope", + "type": "llm", + "models": [ + { + "id": "dashscope/glm-5.1", + "max_tokens": 8192, + "context_window_tokens": 128000, + "max_output_tokens": 8192, + "tokenizer_family": "qwen", + "capacity_source": "provider_candidate", + } + ], + "api_key": "dash-key", + } + + with mock.patch.object(svc, "get_models_by_tenant_factory_type", return_value=[existing_row]), \ + mock.patch.object(svc, "delete_model_record"), \ + mock.patch.object(svc, "split_repo_name", return_value=("dashscope", "glm-5.1")), \ + mock.patch.object(svc, "add_repo_to_name", return_value="dashscope/glm-5.1"), \ + mock.patch.object(svc, "update_model_record") as mock_update, \ + mock.patch.object(svc, "create_model_record"): + + await 
svc.batch_create_models_for_tenant("u1", "t1", batch_payload) + + # max_tokens didn't change between existing (8192) and incoming + # (8192), so no update is needed at all. If the implementation + # were treating provider_candidate as authoritative, update would + # fire with the W2 fields. + if mock_update.called: + _, called_update_data, *_ = mock_update.call_args[0] + assert "context_window_tokens" not in called_update_data + assert "max_output_tokens" not in called_update_data + assert "tokenizer_family" not in called_update_data + assert called_update_data.get("capacity_source") != "provider_candidate" + + +def test_get_capacity_coverage_filters_bare_llm_vlm_rows(): + svc = import_svc() + + records = [ + { + "model_id": 1, + "model_repo": "", + "model_name": "gpt-4o", + "model_factory": "openai", + "model_type": "llm", + "context_window_tokens": 128000, + "max_output_tokens": 16384, + "max_tokens": 16384, + "base_url": "https://api.openai.com/v1", + }, + { + "model_id": 2, + "model_repo": "", + "model_name": "glm-5", + "model_factory": "OpenAI-API-Compatible", + "model_type": "llm", + "context_window_tokens": None, + "max_output_tokens": None, + "max_tokens": 131072, + "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1", + }, + { + "model_id": 3, + "model_repo": "", + "model_name": "vision-model", + "model_factory": "custom", + "model_type": "vlm", + "context_window_tokens": 32000, + "max_output_tokens": None, + "max_tokens": 8192, + "base_url": "https://example.com/v1", + }, + { + "model_id": 4, + "model_repo": "", + "model_name": "embedding-model", + "model_factory": "openai", + "model_type": "embedding", + "context_window_tokens": None, + "max_output_tokens": None, + "max_tokens": 1536, + "base_url": "https://api.openai.com/v1", + }, + { + "model_id": 5, + "model_repo": "", + "model_name": "rerank-model", + "model_factory": "custom", + "model_type": "rerank", + "context_window_tokens": None, + "max_output_tokens": None, + "max_tokens": 512, + 
"base_url": "https://example.com/v1", + }, + ] + + with mock.patch.object(svc, "get_model_records", return_value=records), \ + mock.patch.object(svc, "_capacity_suggestion_available", side_effect=[True, False]): + result = svc.get_capacity_coverage("tenant-a") + + assert result["total_llm_vlm"] == 3 + assert result["bare_count"] == 2 + assert [model["model_id"] for model in result["bare_models"]] == [2, 3] + assert result["bare_models"][0]["max_tokens"] == 131072 + assert result["bare_models"][0]["suggestion_available"] is True + assert result["bare_models"][1]["suggestion_available"] is False + + +def test_get_capacity_coverage_visibility_flag_off(): + svc = import_svc() + + with mock.patch.object(svc, "CAPACITY_VISIBILITY_ENABLED", False), \ + mock.patch.object(svc, "get_model_records") as mock_get_records: + result = svc.get_capacity_coverage("tenant-a") + + assert result == {"total_llm_vlm": 0, "bare_count": 0, "bare_models": []} + mock_get_records.assert_not_called() + + +def test_capacity_suggestion_available_uses_catalog_matcher(): + svc = import_svc() + + model = { + "model_id": 10, + "model_repo": "", + "model_name": "gpt-4o", + "model_factory": "openai", + "model_type": "llm", + "base_url": "https://api.openai.com/v1", + } + fake_result = mock.MagicMock() + fake_result.match_kind = svc.CapacitySuggestionMatchKind.CATALOG_EXACT + + with mock.patch.object(svc, "suggest_capacity", return_value=fake_result) as mock_suggest: + assert svc._capacity_suggestion_available(model) is True + + mock_suggest.assert_called_once_with( + model_name="gpt-4o", + base_url="https://api.openai.com/v1", + provider_hint="openai", + model_type="llm", + enabled=True, + ) + + +def test_capacity_suggestion_available_records_error_on_exception(): + """A catalog-matcher exception falls back to False AND increments the + coverage-error counter. Without the counter a corrupt catalog entry would + silently flip every row's suggestion_available to False with zero signal. 
+ """ + svc = import_svc() + + model = { + "model_id": 42, + "model_repo": "", + "model_name": "broken-model", + "model_factory": "openai", + "model_type": "llm", + "base_url": "https://api.openai.com/v1", + } + + with mock.patch.object(svc, "suggest_capacity", side_effect=RuntimeError("catalog corrupt")), \ + mock.patch.object(svc, "_record_capacity_coverage_error") as mock_record: + assert svc._capacity_suggestion_available(model) is False + + mock_record.assert_called_once() + recorded_args = mock_record.call_args[0] + assert recorded_args[0] == 42 + assert isinstance(recorded_args[1], RuntimeError) + + +def test_record_capacity_coverage_error_no_op_when_counter_disabled(): + """The recorder must not raise when OpenTelemetry is unavailable; the + counter is None and the call becomes a no-op so coverage scans keep + working in deployments without telemetry installed. + """ + svc = import_svc() + + with mock.patch.object(svc, "_capacity_suggestion_coverage_errors_total", None): + # Should not raise. + svc._record_capacity_coverage_error(7, RuntimeError("boom")) diff --git a/test/backend/services/test_model_provider_service.py b/test/backend/services/test_model_provider_service.py index 1b3af74fc..b88cb38a3 100644 --- a/test/backend/services/test_model_provider_service.py +++ b/test/backend/services/test_model_provider_service.py @@ -138,6 +138,32 @@ def __init__(self): ]: sys.modules.setdefault(module_path, mock.MagicMock()) + +# Provide real implementations for the utils.model_name_utils helpers used by +# the module under test. Without these, attribute access on the MagicMock +# yields a callable that returns yet another MagicMock, which silently breaks +# every dict-key lookup downstream (`existing_model_map[]` never +# matches the string id sent by the provider response). 
+def _real_add_repo_to_name(model_repo, model_name): + if "/" in (model_name or ""): + return model_name + if model_repo: + return f"{model_repo}/{model_name}" + return model_name + + +def _real_split_repo_name(full_name): + if not full_name: + return ("", "") + if "/" in full_name: + head, _, tail = full_name.rpartition("/") + return (head, tail) + return ("", full_name) + + +sys.modules["utils.model_name_utils"].add_repo_to_name = _real_add_repo_to_name +sys.modules["utils.model_name_utils"].split_repo_name = _real_split_repo_name + # services.providers.base should NOT be mocked as it contains _classify_provider_error used in tests # SiliconModelProvider and ModelEngineProvider will be imported from their real modules @@ -211,6 +237,45 @@ class _TimeoutExceptionStub(Exception): ) +# ============================================================================ +# Test helpers +# ============================================================================ + +import contextlib + + +@contextlib.contextmanager +def _patch_provider_module_constant(module_basename: str, attr: str, value): + """Patch a constant on every sys.modules entry that exposes a provider + module under both `services.providers.` and + `backend.services.providers.` keys. + + Production code imports providers via the non-`backend.` path + (`from services.providers.silicon_provider import ...`) while many tests + import via the `backend.` path. When both keys are loaded by an earlier + test, they reference distinct module objects with independent name + bindings for constants such as SILICON_GET_URL, so a mock.patch that + targets only one path silently misses. This helper patches every loaded + path so the test is order-independent. 
+ """ + candidate_paths = ( + f"services.providers.{module_basename}", + f"backend.services.providers.{module_basename}", + ) + patches = [] + for path in candidate_paths: + module = sys.modules.get(path) + if module is not None and hasattr(module, attr): + patcher = mock.patch.object(module, attr, value) + patcher.start() + patches.append(patcher) + try: + yield + finally: + for patcher in reversed(patches): + patcher.stop() + + # ============================================================================ # Test-cases for SiliconModelProvider.get_models # ============================================================================ @@ -221,12 +286,12 @@ async def test_get_models_llm_success(): """Silicon provider should append chat tag/type for LLM models.""" provider_config = {"model_type": "llm", "api_key": "test-key"} - # Patch HTTP client & constant inside the provider module + # Patch HTTP client & constant inside the provider module. + # SILICON_GET_URL is patched on every loaded path (see helper docstring). 
with mock.patch( "backend.services.providers.silicon_provider.httpx.AsyncClient" - ) as mock_client, mock.patch( - "backend.services.providers.silicon_provider.SILICON_GET_URL", - "https://silicon.com", + ) as mock_client, _patch_provider_module_constant( + "silicon_provider", "SILICON_GET_URL", "https://silicon.com" ): # Prepare mocked http client / response behaviour @@ -266,9 +331,8 @@ async def test_get_models_embedding_success(): with mock.patch( "backend.services.providers.silicon_provider.httpx.AsyncClient" - ) as mock_client, mock.patch( - "backend.services.providers.silicon_provider.SILICON_GET_URL", - "https://silicon.com", + ) as mock_client, _patch_provider_module_constant( + "silicon_provider", "SILICON_GET_URL", "https://silicon.com" ): mock_client_instance = mock.AsyncMock() @@ -305,9 +369,8 @@ async def test_get_models_unknown_type(): with mock.patch( "backend.services.providers.silicon_provider.httpx.AsyncClient" - ) as mock_client, mock.patch( - "backend.services.providers.silicon_provider.SILICON_GET_URL", - "https://silicon.com", + ) as mock_client, _patch_provider_module_constant( + "silicon_provider", "SILICON_GET_URL", "https://silicon.com" ): result = await SiliconModelProvider().get_models(provider_config) @@ -322,9 +385,8 @@ async def test_get_models_exception(): with mock.patch( "backend.services.providers.silicon_provider.httpx.AsyncClient" - ) as mock_client, mock.patch( - "backend.services.providers.silicon_provider.SILICON_GET_URL", - "https://silicon.com", + ) as mock_client, _patch_provider_module_constant( + "silicon_provider", "SILICON_GET_URL", "https://silicon.com" ): mock_client_instance = mock.AsyncMock() @@ -401,6 +463,143 @@ async def test_prepare_model_dict_llm(): assert result == expected +@pytest.mark.asyncio +async def test_prepare_model_dict_does_not_persist_provider_capacity_candidates(): + """Provider capacity candidates remain UI hints until an operator saves them. 
+ + Per the W1/W2 plan, _extract_capacity_hints tags provider-discovered + capacity values with capacity_source="provider_candidate" so the + catalog UI can show them as suggestions. They must not auto-persist + on batch_create; only operator acceptance (capacity_source="operator") + can write to the row. The original assertion only checked the dumped + result, which is trivially controlled by the mock; the strengthened + assertion below pins ModelRequest's constructor kwargs so the + contract is enforced regardless of what model_dump returns. + """ + with mock.patch( + "backend.services.model_provider_service.split_repo_name", + return_value=("openai", "gpt-4"), + ), mock.patch( + "backend.services.model_provider_service.add_repo_to_name", + return_value="openai/gpt-4", + ), mock.patch( + "backend.services.model_provider_service.ModelRequest" + ) as mock_model_request: + + mock_model_req_instance = mock.MagicMock() + dump_dict = { + "model_factory": "openai", + "model_name": "gpt-4", + "model_type": "llm", + "api_key": "test-key", + "max_tokens": sys.modules["consts.const"].DEFAULT_LLM_MAX_TOKENS, + "display_name": "openai/gpt-4", + } + mock_model_req_instance.model_dump.return_value = dump_dict + mock_model_request.return_value = mock_model_req_instance + + model = { + "id": "openai/gpt-4", + "model_type": "llm", + "max_tokens": sys.modules["consts.const"].DEFAULT_LLM_MAX_TOKENS, + "context_window_tokens": 128000, + "max_output_tokens": 16384, + "tokenizer_family": "o200k_base", + "capacity_source": "provider_candidate", + } + + result = await prepare_model_dict( + "openai", + model, + "https://api.openai.com/v1", + "test-key", + ) + + # Result-level: the dumped dict (controlled by the mock) doesn't + # carry capacity hints downstream. 
+ assert "context_window_tokens" not in result + assert "max_output_tokens" not in result + assert "tokenizer_family" not in result + assert "capacity_source" not in result + + # Contract-level: prepare_model_dict must NOT thread provider + # candidates into ModelRequest. Without this assertion the bug + # we just fixed -- threading every W2 field through unconditionally + # -- would slip past the result-level check because the mock + # absorbs any kwargs silently. + _, kwargs = mock_model_request.call_args + assert "context_window_tokens" not in kwargs + assert "max_output_tokens" not in kwargs + assert "max_input_tokens" not in kwargs + assert "default_output_reserve_tokens" not in kwargs + assert "tokenizer_family" not in kwargs + assert "capacity_source" not in kwargs + assert "capability_profile_version" not in kwargs + + +@pytest.mark.asyncio +async def test_prepare_model_dict_persists_operator_capacity(): + """Operator-saved capacity reaches ModelRequest and lands on the row. + + Regression test for the glm-5.1/glm-5.2 production incident: the + frontend batch-add path resolves user-typed top-level batch defaults + (or per-row gear values) and submits them with + capacity_source="operator". Before the fix, prepare_model_dict + silently dropped every W1/W2 field on the floor and only the legacy + max_tokens mirror persisted -- leaving DB rows with + context_window_tokens=NULL and max_output_tokens=NULL. 
+ """ + with mock.patch( + "backend.services.model_provider_service.split_repo_name", + return_value=("dashscope", "glm-5.2"), + ), mock.patch( + "backend.services.model_provider_service.add_repo_to_name", + return_value="dashscope/glm-5.2", + ), mock.patch( + "backend.services.model_provider_service.ModelRequest" + ) as mock_model_request: + + mock_model_req_instance = mock.MagicMock() + mock_model_req_instance.model_dump.return_value = { + "model_factory": "dashscope", + "model_name": "glm-5.2", + "model_type": "llm", + "max_tokens": 31920, + "display_name": "dashscope/glm-5.2", + } + mock_model_request.return_value = mock_model_req_instance + + model = { + "id": "dashscope/glm-5.2", + "model_type": "llm", + "max_tokens": 31920, + "context_window_tokens": 200000, + "max_input_tokens": None, + "max_output_tokens": 31920, + "default_output_reserve_tokens": 4096, + "tokenizer_family": "qwen", + "capacity_source": "operator", + } + + await prepare_model_dict( + "dashscope", + model, + "https://dashscope.aliyuncs.com/compatible-mode/v1/", + "dash-key", + ) + + _, kwargs = mock_model_request.call_args + assert kwargs["context_window_tokens"] == 200000 + assert kwargs["max_output_tokens"] == 31920 + assert kwargs["default_output_reserve_tokens"] == 4096 + assert kwargs["tokenizer_family"] == "qwen" + # capacity_source is forced to "operator" by the prepare_model_dict + # contract: only operator-marked values reach the row, and the + # marker itself is normalized to the canonical value rather than + # echoing whatever the caller sent. 
+ assert kwargs["capacity_source"] == "operator" + + @pytest.mark.asyncio async def test_prepare_model_dict_vlm(): """VLM models should behave like LLM: no emb dim check; chunk sizes None; base_url untouched.""" @@ -1182,6 +1381,37 @@ def test_merge_existing_model_tokens_verify_function_call(): tenant_id, provider, model_type) +def test_merge_existing_model_tokens_empty_model_repo_matches_bare_name(): + """Regression: DashScope-style rows have empty model_repo. The lookup key + must use add_repo_to_name so the row matches the bare "glm-4.7" id from + the provider response. The legacy code built "/glm-4.7" via raw + concatenation, so the merge silently no-opped -- same wire-key bug as + batch_create_models_for_tenant's delete loop. + """ + model_list = [{"id": "glm-4.7", "model_type": "llm"}] + tenant_id = "test-tenant" + provider = "dashscope" + model_type = "llm" + + existing_models = [ + { + "model_repo": "", + "model_name": "glm-4.7", + "max_tokens": 131072, + } + ] + + with mock.patch( + "backend.services.model_provider_service.get_models_by_tenant_factory_type", + return_value=existing_models, + ): + result = merge_existing_model_tokens( + model_list, tenant_id, provider, model_type + ) + + assert result[0]["max_tokens"] == 131072 + + # ============================================================================ # Test-cases for get_provider_models # ============================================================================ @@ -1873,9 +2103,8 @@ async def test_silicon_get_models_empty_list(): with mock.patch( "backend.services.providers.silicon_provider.httpx.AsyncClient" - ) as mock_client, mock.patch( - "backend.services.providers.silicon_provider.SILICON_GET_URL", - "https://silicon.com", + ) as mock_client, _patch_provider_module_constant( + "silicon_provider", "SILICON_GET_URL", "https://silicon.com" ): mock_client_instance = mock.AsyncMock() diff --git a/test/backend/utils/test_config_utils.py b/test/backend/utils/test_config_utils.py index 
80fc3d483..6ed928814 100644 --- a/test/backend/utils/test_config_utils.py +++ b/test/backend/utils/test_config_utils.py @@ -1,7 +1,9 @@ import pytest import json import sys +import types from unittest.mock import patch +from pydantic import BaseModel, Field # Setup common mocks from test.common.test_mocks import setup_common_mocks, patch_minio_client_initialization @@ -9,9 +11,25 @@ # Initialize common mocks mocks = setup_common_mocks() + +class InvalidReservePolicy(Exception): + pass + + +class CapacityReservePolicy(BaseModel): + soft_limit_ratio: float = Field(default=0.8, gt=0, le=1) + soft_limit_ratio_source: str = "code_default" + + +capacity_budget_mock = types.ModuleType("nexent.core.models.capacity_budget") +capacity_budget_mock.CapacityReservePolicy = CapacityReservePolicy +capacity_budget_mock.InvalidReservePolicy = InvalidReservePolicy +sys.modules["nexent.core.models.capacity_budget"] = capacity_budget_mock + # Patch storage factory before importing with patch_minio_client_initialization(): from backend.utils.config_utils import ( + CONTEXT_SOFT_LIMIT_RATIO_KEY, safe_value, safe_list, get_env_key, @@ -215,6 +233,38 @@ def test_get_app_config_no_tenant_id(self, config_manager): result = config_manager.get_app_config("key") assert result == "" + @patch('backend.utils.config_utils.get_all_configs_by_tenant_id') + def test_get_capacity_reserve_policy_default(self, mock_get_configs, config_manager): + """Missing W2 soft-limit config should use policy default.""" + mock_get_configs.return_value = [] + + policy = config_manager.get_capacity_reserve_policy("tenant1") + + assert policy.soft_limit_ratio == 0.8 + assert policy.soft_limit_ratio_source == "code_default" + + @patch('backend.utils.config_utils.get_all_configs_by_tenant_id') + def test_get_capacity_reserve_policy_tenant_override(self, mock_get_configs, config_manager): + """Valid tenant W2 soft-limit config should be parsed and sourced.""" + mock_get_configs.return_value = [ + {"config_key": 
CONTEXT_SOFT_LIMIT_RATIO_KEY, "config_value": "0.75"} + ] + + policy = config_manager.get_capacity_reserve_policy("tenant1") + + assert policy.soft_limit_ratio == 0.75 + assert policy.soft_limit_ratio_source == "tenant_config" + + @patch('backend.utils.config_utils.get_all_configs_by_tenant_id') + def test_get_capacity_reserve_policy_invalid_override(self, mock_get_configs, config_manager): + """Invalid W2 soft-limit config should fail closed.""" + mock_get_configs.return_value = [ + {"config_key": CONTEXT_SOFT_LIMIT_RATIO_KEY, "config_value": "1.5"} + ] + + with pytest.raises(Exception, match=CONTEXT_SOFT_LIMIT_RATIO_KEY): + config_manager.get_capacity_reserve_policy("tenant1") + @patch('backend.utils.config_utils.insert_config') @patch('backend.utils.config_utils.get_all_configs_by_tenant_id') def test_set_single_config_success(self, mock_get_configs, mock_insert, config_manager): diff --git a/test/sdk/core/agents/test_agent_context/unit/test_compress_if_needed.py b/test/sdk/core/agents/test_agent_context/unit/test_compress_if_needed.py index 79dfd5a03..04b5950d6 100644 --- a/test/sdk/core/agents/test_agent_context/unit/test_compress_if_needed.py +++ b/test/sdk/core/agents/test_agent_context/unit/test_compress_if_needed.py @@ -65,6 +65,20 @@ def test_over_threshold_triggers_compression(self): ) assert "Summary of earlier steps" in all_text + def test_soft_input_budget_triggers_compression_before_legacy_threshold(self): + cm = make_cm(enabled=True, threshold=999999, keep_recent_steps=2, keep_recent_pairs=1) + cm.config.soft_input_budget_tokens = 10 + cm.config.hard_input_budget_tokens = 999999 + memory = make_memory_mixed(n_prev_pairs=3, n_curr_actions=2) + original = make_original_messages(memory) + current_run_start_idx = 6 + model = make_model('{"task_overview": "summary"}') + + result = cm.compress_if_needed(model, memory, original, current_run_start_idx) + + assert result is not None + model.assert_called_once() + def 
test_run_boundary_clears_current_cache(self): """Switching run (current_run_start_idx changes) and ensuring no current summary triggers, current cache should be cleared.""" cm = make_cm(enabled=True, threshold=1) @@ -186,4 +200,4 @@ def test_mixed_prev_and_curr_over_threshold(self): for m in result for b in (m.content if isinstance(m.content, list) else []) if isinstance(b, dict) ) - assert "Summary of earlier steps" in all_text \ No newline at end of file + assert "Summary of earlier steps" in all_text diff --git a/test/sdk/core/agents/test_context_component.py b/test/sdk/core/agents/test_context_component.py index 860f0ade2..d1bede0f8 100644 --- a/test/sdk/core/agents/test_context_component.py +++ b/test/sdk/core/agents/test_context_component.py @@ -782,6 +782,21 @@ def test_existing_fields_preserved(self): assert config.token_threshold == 5000 assert config.keep_recent_steps == 3 + def test_w2_budget_fields_default_to_legacy_threshold_mode(self): + config = summary_config_module.ContextManagerConfig() + assert config.soft_input_budget_tokens == 0 + assert config.hard_input_budget_tokens == 0 + + def test_w2_budget_fields_can_be_set(self): + config = summary_config_module.ContextManagerConfig( + token_threshold=8000, + soft_input_budget_tokens=7000, + hard_input_budget_tokens=9000, + ) + assert config.token_threshold == 8000 + assert config.soft_input_budget_tokens == 7000 + assert config.hard_input_budget_tokens == 9000 + class TestAgentConfigWithContextComponents: """Tests for AgentConfig with context_components field.""" @@ -812,4 +827,4 @@ def test_agent_config_default_context_components_none(self): if __name__ == "__main__": - pytest.main([__file__]) \ No newline at end of file + pytest.main([__file__]) diff --git a/test/sdk/core/agents/test_nexent_agent.py b/test/sdk/core/agents/test_nexent_agent.py index 882e28514..83512c912 100644 --- a/test/sdk/core/agents/test_nexent_agent.py +++ b/test/sdk/core/agents/test_nexent_agent.py @@ -459,7 +459,9 @@ def 
test_create_model_success(nexent_agent_with_models, mock_model_config): # Verify the result assert result == mock_model_instance - # Verify OpenAIModel was constructed with correct parameters + # Verify OpenAIModel was constructed with correct parameters. + # W1 renamed the SDK's `max_tokens` kwarg to `max_output_tokens`; the + # production code path here builds the same kwarg under the new name. mock_openai_model_class.assert_called_once_with( observer=nexent_agent_with_models.observer, model_id=mock_model_config.model_name, @@ -471,7 +473,7 @@ def test_create_model_success(nexent_agent_with_models, mock_model_config): ssl_verify=True, display_name=mock_model_config.cite_name, extra_body=mock_model_config.extra_body, - max_tokens=mock_model_config.max_tokens, + max_output_tokens=mock_model_config.max_tokens, timeout_seconds=mock_model_config.timeout_seconds, ) @@ -491,7 +493,8 @@ def test_create_model_deep_thinking_success(nexent_agent_with_models, mock_deep_ # Verify the result assert result == mock_model_instance - # Verify OpenAIModel was constructed with correct parameters + # Verify OpenAIModel was constructed with correct parameters. + # W1 renamed the SDK's `max_tokens` kwarg to `max_output_tokens`. 
mock_openai_model_class.assert_called_once_with( observer=nexent_agent_with_models.observer, model_id=mock_deep_thinking_model_config.model_name, @@ -503,7 +506,7 @@ def test_create_model_deep_thinking_success(nexent_agent_with_models, mock_deep_ ssl_verify=True, display_name=mock_deep_thinking_model_config.cite_name, extra_body=mock_deep_thinking_model_config.extra_body, - max_tokens=mock_deep_thinking_model_config.max_tokens, + max_output_tokens=mock_deep_thinking_model_config.max_tokens, timeout_seconds=mock_deep_thinking_model_config.timeout_seconds, ) diff --git a/test/sdk/core/agents/test_run_agent.py b/test/sdk/core/agents/test_run_agent.py index 476337eae..314a43e3d 100644 --- a/test/sdk/core/agents/test_run_agent.py +++ b/test/sdk/core/agents/test_run_agent.py @@ -1,4 +1,5 @@ import types +import json import importlib.machinery import pytest import importlib @@ -283,6 +284,61 @@ def test_agent_run_thread_local_flow(basic_agent_run_info, monkeypatch): mock_nexent_instance.add_history_to_agent.assert_called_once_with(basic_agent_run_info.history) mock_nexent_instance.agent_run_with_observer.assert_called_once_with(query=basic_agent_run_info.query, reset=False) + +def test_agent_run_thread_binds_capacity_and_budget_snapshots(basic_agent_run_info, monkeypatch): + captured = {} + basic_agent_run_info.capacity_snapshot = {"capacity_fingerprint": "w1"} + basic_agent_run_info.safe_input_budget_snapshot = {"fingerprint": "w2"} + + monkeypatch.setattr( + run_agent, + "set_monitoring_capacity_snapshot", + lambda snapshot: captured.setdefault("capacity", snapshot), + ) + monkeypatch.setattr( + run_agent, + "set_monitoring_safe_input_budget_snapshot", + lambda snapshot: captured.setdefault("budget", snapshot), + ) + mock_nexent_instance = MagicMock(name="NexentAgentInstance") + monkeypatch.setattr(run_agent, "NexentAgent", MagicMock(return_value=mock_nexent_instance)) + + run_agent.agent_run_thread(basic_agent_run_info) + + assert captured["capacity"] == 
{"capacity_fingerprint": "w1"} + assert captured["budget"] == {"fingerprint": "w2"} + + +def test_emit_uncertainty_reserve_warning(basic_agent_run_info): + basic_agent_run_info.safe_input_budget_snapshot = { + "warnings": ["uncertainty_reserve_active"], + "fingerprint": "w2", + "w1_fingerprint": "w1", + "uncertainty_reserve_tokens": 12800, + "hard_input_budget_tokens": 114200, + } + + run_agent._emit_uncertainty_reserve_warning(basic_agent_run_info) + + basic_agent_run_info.observer.add_message.assert_called_once() + _, process_type, content = basic_agent_run_info.observer.add_message.call_args[0] + assert process_type == ProcessType.OTHER + payload = json.loads(content) + assert payload["code"] == "uncertainty_reserve_active" + assert payload["budget_fingerprint"] == "w2" + assert payload["uncertainty_reserve_tokens"] == 12800 + + +def test_emit_uncertainty_reserve_warning_noops_without_warning(basic_agent_run_info): + basic_agent_run_info.safe_input_budget_snapshot = { + "warnings": [], + "fingerprint": "w2", + } + + run_agent._emit_uncertainty_reserve_warning(basic_agent_run_info) + + basic_agent_run_info.observer.add_message.assert_not_called() + # Ensure no MCP-specific behaviour occurred basic_agent_run_info.observer.add_message.assert_not_called() diff --git a/test/sdk/core/models/test_capacity_budget.py b/test/sdk/core/models/test_capacity_budget.py new file mode 100644 index 000000000..7f55be097 --- /dev/null +++ b/test/sdk/core/models/test_capacity_budget.py @@ -0,0 +1,267 @@ +"""Unit tests for W2 safe-input-budget type skeleton.""" +from __future__ import annotations + +import importlib.util +import sys +import types +from pathlib import Path + +import pytest +from pydantic import ValidationError + + +_SDK_ROOT = Path(__file__).resolve().parents[4] / "sdk" / "nexent" + +for pkg_name, pkg_path in ( + ("nexent", _SDK_ROOT), + ("nexent.core", _SDK_ROOT / "core"), + ("nexent.core.models", _SDK_ROOT / "core" / "models"), +): + if pkg_name not in sys.modules: 
+ pkg = types.ModuleType(pkg_name) + pkg.__path__ = [str(pkg_path)] + sys.modules[pkg_name] = pkg + + +def _load(module_name: str, file_path: Path): + spec = importlib.util.spec_from_file_location(module_name, file_path) + mod = importlib.util.module_from_spec(spec) + sys.modules[module_name] = mod + spec.loader.exec_module(mod) + return mod + + +_capacity_resolver = _load( + "nexent.core.models.capacity_resolver", + _SDK_ROOT / "core" / "models" / "capacity_resolver.py", +) +_capacity_budget = _load( + "nexent.core.models.capacity_budget", + _SDK_ROOT / "core" / "models" / "capacity_budget.py", +) + +CapacityReservePolicy = _capacity_budget.CapacityReservePolicy +InvalidReservePolicy = _capacity_budget.InvalidReservePolicy +NoSafeInputCapacity = _capacity_budget.NoSafeInputCapacity +RequestedOutputExceedsCapacity = _capacity_budget.RequestedOutputExceedsCapacity +RequestBudgetOverrides = _capacity_budget.RequestBudgetOverrides +ReserveExceedsCapacity = _capacity_budget.ReserveExceedsCapacity +SafeInputBudgetCalculator = _capacity_budget.SafeInputBudgetCalculator +UncertaintyReserveBasisUnknown = _capacity_budget.UncertaintyReserveBasisUnknown +W2_RESOLVER_VERSION = _capacity_budget.W2_RESOLVER_VERSION +compute_w2_fingerprint = _capacity_budget.compute_w2_fingerprint +ModelCapacitySnapshot = _capacity_resolver.ModelCapacitySnapshot + + +def _fingerprint(**overrides) -> str: + payload = { + "w2_resolver_version": W2_RESOLVER_VERSION, + "w1_fingerprint": "w1abc", + "provider": "openai", + "model_name": "gpt-4o", + "requested_output_tokens": 4096, + "output_reserve_source": "model_default", + "uncertainty_reserve_tokens": 12800, + "uncertainty_reserve_basis": "context_window_10pct", + "approved_profile_reserve_tokens": None, + "soft_limit_ratio": 0.8, + "soft_limit_ratio_source": "code_default", + "soft_input_budget_tokens": 88883, + "hard_input_budget_tokens": 111104, + "field_sources": {"soft_limit_ratio": "code_default"}, + "warnings": [], + } + 
payload.update(overrides) + return compute_w2_fingerprint(**payload) + + +def test_capacity_reserve_policy_defaults_to_w2_soft_limit(): + policy = CapacityReservePolicy() + + assert policy.soft_limit_ratio == 0.8 + assert policy.soft_limit_ratio_source == "code_default" + assert policy.approved_profile_reserve_tokens is None + + +def test_capacity_reserve_policy_rejects_invalid_ratio(): + with pytest.raises(ValidationError): + CapacityReservePolicy(soft_limit_ratio=0) + + with pytest.raises(ValidationError): + CapacityReservePolicy(soft_limit_ratio=1.01) + + +def test_compute_w2_fingerprint_is_deterministic_and_ignores_warnings(): + first = _fingerprint(warnings=["observe-only"]) + second = _fingerprint(warnings=["different warning"]) + + assert first == second + assert len(first) == 32 + + +def test_compute_w2_fingerprint_changes_when_contract_changes(): + first = _fingerprint() + second = _fingerprint(requested_output_tokens=8192) + + assert first != second + + +def _capacity_snapshot(**overrides) -> ModelCapacitySnapshot: + payload = { + "provider": "openai", + "model_name": "gpt-4o", + "context_window_tokens": 128_000, + "max_input_tokens": None, + "max_output_tokens": 16_384, + "default_output_reserve_tokens": 4_096, + "requested_output_tokens": 4_096, + "provider_input_limit_tokens": 123_904, + "tokenizer_family": "o200k_base", + "counting_mode": "estimated", + "unknown_capabilities": ["tokenizer"], + "field_sources": { + "context_window_tokens": "profile", + "max_output_tokens": "profile", + }, + "capability_profile_version": "openai/gpt-4o@1", + "fingerprint": "w1fingerprint", + } + payload.update(overrides) + return ModelCapacitySnapshot(**payload) + + +def test_calculator_combined_window_uses_10_percent_uncertainty_reserve(): + calculator = SafeInputBudgetCalculator() + + snap = calculator.calculate_safe_input_budget( + capacity_snapshot=_capacity_snapshot(), + reserve_policy=CapacityReservePolicy(), + ) + + assert snap.provider_input_limit_tokens == 
128_000 - 4_096 + assert snap.uncertainty_reserve_tokens == 12_800 + assert snap.uncertainty_reserve_basis == "context_window_10pct" + assert snap.hard_input_budget_tokens == 111_104 + assert snap.soft_input_budget_tokens == 88_883 + assert snap.requested_output_tokens == 4_096 + assert snap.output_reserve_source == "model_default" + assert snap.w1_fingerprint == "w1fingerprint" + assert "uncertainty_reserve_active" in snap.warnings + assert len(snap.fingerprint) == 32 + + +def test_calculator_recomputes_provider_limit_for_request_override(): + calculator = SafeInputBudgetCalculator() + + snap = calculator.calculate_safe_input_budget( + capacity_snapshot=_capacity_snapshot(), + reserve_policy=CapacityReservePolicy(), + request_overrides=RequestBudgetOverrides(requested_output_tokens=8_192), + ) + + assert snap.requested_output_tokens == 8_192 + assert snap.output_reserve_source == "request" + assert snap.provider_input_limit_tokens == 128_000 - 8_192 + assert snap.hard_input_budget_tokens == (128_000 - 8_192) - 12_800 + + +def test_calculator_rejects_request_override_that_lowers_reserve(): + calculator = SafeInputBudgetCalculator() + + with pytest.raises(InvalidReservePolicy): + calculator.calculate_safe_input_budget( + capacity_snapshot=_capacity_snapshot(), + reserve_policy=CapacityReservePolicy(), + request_overrides=RequestBudgetOverrides(requested_output_tokens=2_048), + ) + + +def test_calculator_allows_agent_override_source(): + calculator = SafeInputBudgetCalculator() + + snap = calculator.calculate_safe_input_budget( + capacity_snapshot=_capacity_snapshot(), + reserve_policy=CapacityReservePolicy(), + requested_output_tokens=2_048, + output_reserve_source="agent", + ) + + assert snap.requested_output_tokens == 2_048 + assert snap.output_reserve_source == "agent" + + +def test_calculator_uses_approved_profile_reserve_for_separate_input_limit(): + calculator = SafeInputBudgetCalculator() + + snap = calculator.calculate_safe_input_budget( + 
capacity_snapshot=_capacity_snapshot( + context_window_tokens=None, + max_input_tokens=32_768, + provider_input_limit_tokens=32_768, + unknown_capabilities=["tokenizer"], + ), + reserve_policy=CapacityReservePolicy(approved_profile_reserve_tokens=512), + ) + + assert snap.provider_input_limit_tokens == 32_768 + assert snap.uncertainty_reserve_tokens == 512 + assert snap.uncertainty_reserve_basis == "approved_profile" + assert snap.hard_input_budget_tokens == 32_256 + + +def test_calculator_requires_context_window_for_10_percent_reserve(): + calculator = SafeInputBudgetCalculator() + + with pytest.raises(UncertaintyReserveBasisUnknown): + calculator.calculate_safe_input_budget( + capacity_snapshot=_capacity_snapshot( + context_window_tokens=None, + max_input_tokens=32_768, + provider_input_limit_tokens=32_768, + unknown_capabilities=["tokenizer"], + ), + reserve_policy=CapacityReservePolicy(), + ) + + +def test_calculator_rejects_requested_output_above_capacity(): + calculator = SafeInputBudgetCalculator() + + with pytest.raises(RequestedOutputExceedsCapacity): + calculator.calculate_safe_input_budget( + capacity_snapshot=_capacity_snapshot(max_output_tokens=8_000), + reserve_policy=CapacityReservePolicy(), + request_overrides=RequestBudgetOverrides(requested_output_tokens=8_192), + ) + + +def test_calculator_rejects_reserve_larger_than_provider_limit(): + calculator = SafeInputBudgetCalculator() + + with pytest.raises(ReserveExceedsCapacity): + calculator.calculate_safe_input_budget( + capacity_snapshot=_capacity_snapshot( + context_window_tokens=10_000, + max_input_tokens=100, + provider_input_limit_tokens=100, + unknown_capabilities=["tokenizer"], + ), + reserve_policy=CapacityReservePolicy(), + ) + + +def test_calculator_rejects_no_safe_input_capacity_after_output_reserve(): + calculator = SafeInputBudgetCalculator() + + with pytest.raises(NoSafeInputCapacity): + calculator.calculate_safe_input_budget( + capacity_snapshot=_capacity_snapshot( + 
context_window_tokens=4_096, + max_input_tokens=None, + max_output_tokens=8_192, + requested_output_tokens=4_096, + provider_input_limit_tokens=1, + unknown_capabilities=[], + ), + reserve_policy=CapacityReservePolicy(), + ) diff --git a/test/sdk/core/models/test_capacity_resolver.py b/test/sdk/core/models/test_capacity_resolver.py new file mode 100644 index 000000000..a81da3862 --- /dev/null +++ b/test/sdk/core/models/test_capacity_resolver.py @@ -0,0 +1,336 @@ +"""Unit tests for ModelCapacityResolver (W1).""" +from __future__ import annotations + +import importlib.util +import sys +import types +from pathlib import Path + +# Build a minimal `nexent.core.models` package skeleton in sys.modules so we can +# import the capacity_resolver and tokenizer_registry modules without triggering +# the SDK's full __init__ chain (which pulls smolagents, mem0, etc.). +_SDK_ROOT = Path(__file__).resolve().parents[4] / "sdk" / "nexent" + +for pkg_name, pkg_path in ( + ("nexent", _SDK_ROOT), + ("nexent.core", _SDK_ROOT / "core"), + ("nexent.core.models", _SDK_ROOT / "core" / "models"), +): + if pkg_name not in sys.modules: + pkg = types.ModuleType(pkg_name) + pkg.__path__ = [str(pkg_path)] + sys.modules[pkg_name] = pkg + + +def _load(module_name: str, file_path: Path): + spec = importlib.util.spec_from_file_location(module_name, file_path) + mod = importlib.util.module_from_spec(spec) + sys.modules[module_name] = mod + spec.loader.exec_module(mod) + return mod + + +_capacity_resolver = _load( + "nexent.core.models.capacity_resolver", + _SDK_ROOT / "core" / "models" / "capacity_resolver.py", +) +_load( + "nexent.core.models.tokenizer_registry", + _SDK_ROOT / "core" / "models" / "tokenizer_registry.py", +) + +CapabilityProfile = _capacity_resolver.CapabilityProfile +InvalidCapacityConfiguration = _capacity_resolver.InvalidCapacityConfiguration +ModelCapacitySnapshot = _capacity_resolver.ModelCapacitySnapshot +ProviderCapabilityUnknown = _capacity_resolver.ProviderCapabilityUnknown 
+RESOLVER_VERSION = _capacity_resolver.RESOLVER_VERSION +RequestedOutputExceedsCap = _capacity_resolver.RequestedOutputExceedsCap +compute_fingerprint = _capacity_resolver.compute_fingerprint +resolve_capacity = _capacity_resolver.resolve_capacity + +import pytest # noqa: E402 +from pydantic import ValidationError # noqa: E402 + + +def _gpt4o_profile() -> CapabilityProfile: + return CapabilityProfile( + provider="openai", + model_name="gpt-4o", + capability_profile_version="openai/gpt-4o@1", + window_shape="combined", + context_window_tokens=128_000, + max_output_tokens=16_384, + default_output_reserve_tokens=4_096, + tokenizer_family="o200k_base", + ) + + +def _separate_limit_profile() -> CapabilityProfile: + """A synthetic profile exercising the separate-input-limit path. + + No real day-one model uses this shape, but the budget code must support it. + """ + return CapabilityProfile( + provider="testprovider", + model_name="separate-limit-model", + capability_profile_version="testprovider/separate@1", + window_shape="separate", + context_window_tokens=None, + max_input_tokens=32_768, + max_output_tokens=4_096, + default_output_reserve_tokens=1_024, + tokenizer_family=None, + ) + + +def _catalog(*profiles: CapabilityProfile) -> dict: + return {(p.provider, p.model_name): p for p in profiles} + + +def test_known_profile_no_overrides_builds_snapshot(): + catalog = _catalog(_gpt4o_profile()) + + snap = resolve_capacity( + model_id="gpt-4o", + provider="openai", + capability_profiles=catalog, + ) + + assert isinstance(snap, ModelCapacitySnapshot) + assert snap.provider == "openai" + assert snap.model_name == "gpt-4o" + assert snap.context_window_tokens == 128_000 + assert snap.max_output_tokens == 16_384 + assert snap.default_output_reserve_tokens == 4_096 + assert snap.requested_output_tokens == 4_096 # defaulted from reserve + assert snap.provider_input_limit_tokens == 128_000 - 4_096 + assert snap.tokenizer_family == "o200k_base" + assert snap.counting_mode == 
"estimated" # no adapter registered yet + assert snap.capability_profile_version == "openai/gpt-4o@1" + assert snap.resolver_version == RESOLVER_VERSION + assert "capability_profile_missing" not in snap.unknown_capabilities + # Fields the profile defined come from "profile"; fields the profile left + # null are tagged "unknown". None should come from "operator" when no + # overrides are supplied. + assert snap.field_sources["context_window_tokens"] == "profile" + assert snap.field_sources["max_output_tokens"] == "profile" + assert snap.field_sources["max_input_tokens"] == "unknown" # gpt-4o has no separate input limit + assert "operator" not in snap.field_sources.values() + assert len(snap.fingerprint) == 32 + + +def test_operator_override_wins_over_profile(): + catalog = _catalog(_gpt4o_profile()) + + snap = resolve_capacity( + model_id="gpt-4o", + provider="openai", + operator_overrides={"max_output_tokens": 8_192}, + capability_profiles=catalog, + ) + + assert snap.max_output_tokens == 8_192 + assert snap.field_sources["max_output_tokens"] == "operator" + assert snap.field_sources["context_window_tokens"] == "profile" + + +def test_uncataloged_model_with_operator_overrides_resolves(): + snap = resolve_capacity( + model_id="custom-model", + provider="self-hosted", + operator_overrides={ + "context_window_tokens": 32_000, + "max_output_tokens": 4_000, + "default_output_reserve_tokens": 1_000, + }, + capability_profiles={}, + ) + + assert snap.context_window_tokens == 32_000 + assert snap.requested_output_tokens == 1_000 + assert snap.provider_input_limit_tokens == 32_000 - 1_000 + assert snap.field_sources["context_window_tokens"] == "operator" + assert snap.capability_profile_version is None + assert "capability_profile_missing" in snap.unknown_capabilities + + +def test_uncataloged_model_without_hard_capacity_is_rejected(): + with pytest.raises(ProviderCapabilityUnknown): + resolve_capacity( + model_id="ghost-model", + provider="unknown-provider", + 
capability_profiles={}, + ) + + +def test_max_output_exceeding_context_window_is_rejected(): + bad_profile = CapabilityProfile( + provider="x", model_name="y", capability_profile_version="x/y@1", + window_shape="combined", context_window_tokens=4_096, + max_output_tokens=8_192, default_output_reserve_tokens=1_024, + ) + with pytest.raises(InvalidCapacityConfiguration): + resolve_capacity( + model_id="y", + provider="x", + capability_profiles=_catalog(bad_profile), + ) + + +def test_requested_output_exceeding_max_output_is_rejected(): + catalog = _catalog(_gpt4o_profile()) + with pytest.raises(RequestedOutputExceedsCap): + resolve_capacity( + model_id="gpt-4o", + provider="openai", + requested_output_tokens=32_000, + capability_profiles=catalog, + ) + + +def test_requested_output_defaults_to_profile_reserve(): + catalog = _catalog(_gpt4o_profile()) + snap = resolve_capacity( + model_id="gpt-4o", + provider="openai", + capability_profiles=catalog, + ) + assert snap.requested_output_tokens == 4_096 + + +def test_separate_input_limit_uses_max_input_tokens(): + catalog = _catalog(_separate_limit_profile()) + snap = resolve_capacity( + model_id="separate-limit-model", + provider="testprovider", + capability_profiles=catalog, + ) + assert snap.max_input_tokens == 32_768 + assert snap.provider_input_limit_tokens == 32_768 + + +def test_separate_input_limit_with_combined_takes_minimum(): + profile = CapabilityProfile( + provider="x", model_name="y", capability_profile_version="x/y@1", + window_shape="combined", context_window_tokens=128_000, + max_input_tokens=16_000, max_output_tokens=4_096, + default_output_reserve_tokens=512, + ) + snap = resolve_capacity( + model_id="y", provider="x", + capability_profiles=_catalog(profile), + ) + assert snap.provider_input_limit_tokens == 16_000 + + +def test_snapshot_is_immutable(): + catalog = _catalog(_gpt4o_profile()) + snap = resolve_capacity( + model_id="gpt-4o", provider="openai", + capability_profiles=catalog, + ) + with 
pytest.raises(ValidationError): + snap.provider = "mutated" + + +def test_fingerprint_recomputes_identically(): + catalog = _catalog(_gpt4o_profile()) + snap = resolve_capacity( + model_id="gpt-4o", provider="openai", + capability_profiles=catalog, + ) + + recomputed = compute_fingerprint( + resolver_version=snap.resolver_version, + provider=snap.provider, + model_name=snap.model_name, + context_window_tokens=snap.context_window_tokens, + max_input_tokens=snap.max_input_tokens, + max_output_tokens=snap.max_output_tokens, + default_output_reserve_tokens=snap.default_output_reserve_tokens, + requested_output_tokens=snap.requested_output_tokens, + provider_input_limit_tokens=snap.provider_input_limit_tokens, + tokenizer_family=snap.tokenizer_family, + counting_mode=snap.counting_mode, + capability_profile_version=snap.capability_profile_version, + unknown_capabilities=snap.unknown_capabilities, + field_sources=dict(snap.field_sources), + ) + + assert snap.fingerprint == recomputed + + +def test_fingerprint_changes_when_request_changes(): + catalog = _catalog(_gpt4o_profile()) + snap_a = resolve_capacity( + model_id="gpt-4o", provider="openai", + requested_output_tokens=2_000, + capability_profiles=catalog, + ) + snap_b = resolve_capacity( + model_id="gpt-4o", provider="openai", + requested_output_tokens=4_000, + capability_profiles=catalog, + ) + assert snap_a.fingerprint != snap_b.fingerprint + + +def test_negative_or_zero_capacity_is_rejected(): + with pytest.raises(InvalidCapacityConfiguration): + resolve_capacity( + model_id="bad", provider="x", + operator_overrides={"context_window_tokens": 0}, + capability_profiles={}, + ) + with pytest.raises(InvalidCapacityConfiguration): + resolve_capacity( + model_id="bad", provider="x", + operator_overrides={"context_window_tokens": -100}, + capability_profiles={}, + ) + + +def test_requested_output_must_be_positive(): + catalog = _catalog(_gpt4o_profile()) + with pytest.raises(InvalidCapacityConfiguration): + 
resolve_capacity( + model_id="gpt-4o", provider="openai", + requested_output_tokens=0, + capability_profiles=catalog, + ) + + +def test_max_input_tokens_above_context_window_is_rejected(): + with pytest.raises(InvalidCapacityConfiguration) as exc_info: + resolve_capacity( + model_id="bad", provider="x", + operator_overrides={ + "context_window_tokens": 128_000, + "max_input_tokens": 200_000, + }, + capability_profiles={}, + ) + assert "max_input_tokens" in str(exc_info.value) + assert "exceeds context_window_tokens" in str(exc_info.value) + + +def test_max_input_tokens_equal_to_context_window_is_allowed(): + snap = resolve_capacity( + model_id="ok", provider="x", + operator_overrides={ + "context_window_tokens": 128_000, + "max_input_tokens": 128_000, + "max_output_tokens": 4_096, + }, + capability_profiles={}, + ) + assert snap.max_input_tokens == 128_000 + + +def test_unknown_capabilities_includes_tokenizer_when_estimated(): + catalog = _catalog(_gpt4o_profile()) + snap = resolve_capacity( + model_id="gpt-4o", provider="openai", + capability_profiles=catalog, + ) + assert "tokenizer" in snap.unknown_capabilities diff --git a/test/sdk/core/models/test_openai_llm.py b/test/sdk/core/models/test_openai_llm.py index af33cc82a..cd6192b7d 100644 --- a/test/sdk/core/models/test_openai_llm.py +++ b/test/sdk/core/models/test_openai_llm.py @@ -86,11 +86,18 @@ def __repr__(self): smol_mod.Tool = object sys.modules["smolagents"] = smol_mod sys.modules["smolagents.models"] = smol_models + smol_memory = types.ModuleType("smolagents.memory") + smol_memory.ActionStep = type("ActionStep", (), {}) + smol_memory.AgentMemory = type("AgentMemory", (), {}) + smol_memory.MemoryStep = type("MemoryStep", (), {}) + sys.modules["smolagents.memory"] = smol_memory smol_monitoring = types.ModuleType("smolagents.monitoring") + class TokenUsage: def __init__(self, input_tokens=0, output_tokens=0): self.input_tokens = input_tokens self.output_tokens = output_tokens + smol_monitoring.TokenUsage = 
TokenUsage sys.modules["smolagents.monitoring"] = smol_monitoring @@ -218,6 +225,10 @@ def from_dict(d): mock_models_module.ChatMessage = SimpleChatMessage mock_models_module.MessageRole = MagicMock() mock_smolagents.models = mock_models_module +mock_memory_module = MagicMock() +mock_memory_module.ActionStep = type("ActionStep", (), {}) +mock_memory_module.AgentMemory = type("AgentMemory", (), {}) +mock_memory_module.MemoryStep = type("MemoryStep", (), {}) mock_smolagents_monitoring = types.ModuleType("smolagents.monitoring") @@ -298,6 +309,7 @@ class MockProcessType: module_mocks = { "smolagents": mock_smolagents, "smolagents.models": mock_models_module, + "smolagents.memory": mock_memory_module, "smolagents.monitoring": mock_smolagents_monitoring, "openai.types": MagicMock(), "openai.types.chat": MagicMock(), @@ -1334,6 +1346,259 @@ def test_call_with_token_tracker_uses_provided_tracker(openai_model_instance): mock_tracker.record_token.assert_called() +def _safe_input_budget_snapshot(requested_output_tokens=128): + payload = { + "w1_fingerprint": "w1fingerprint", + "provider": "openai", + "model_name": "gpt-test", + "requested_output_tokens": requested_output_tokens, + "output_reserve_source": "model_default", + "provider_input_limit_tokens": 1000, + "uncertainty_reserve_tokens": 0, + "uncertainty_reserve_basis": "none", + "approved_profile_reserve_tokens": None, + "soft_limit_ratio": 0.8, + "soft_limit_ratio_source": "code_default", + "soft_input_budget_tokens": 800, + "hard_input_budget_tokens": 1000, + "field_sources": {}, + "warnings": [], + "resolver_version": "1.0.0", + } + payload["fingerprint"] = openai_llm_module.compute_w2_fingerprint( + w2_resolver_version=payload["resolver_version"], + w1_fingerprint=payload["w1_fingerprint"], + provider=payload["provider"], + model_name=payload["model_name"], + requested_output_tokens=payload["requested_output_tokens"], + output_reserve_source=payload["output_reserve_source"], + 
uncertainty_reserve_tokens=payload["uncertainty_reserve_tokens"], + uncertainty_reserve_basis=payload["uncertainty_reserve_basis"], + approved_profile_reserve_tokens=payload["approved_profile_reserve_tokens"], + soft_limit_ratio=payload["soft_limit_ratio"], + soft_limit_ratio_source=payload["soft_limit_ratio_source"], + soft_input_budget_tokens=payload["soft_input_budget_tokens"], + hard_input_budget_tokens=payload["hard_input_budget_tokens"], + field_sources=payload["field_sources"], + warnings=payload["warnings"], + ) + return payload + + +def test_call_with_snapshot_does_not_autofill_max_tokens_from_max_output_tokens( + openai_model_instance, +): + """Regression: when a W2 snapshot is active on self, __call__ must not + auto-fill max_tokens from self.max_output_tokens. The dispatch boundary + treats any caller-supplied max_tokens that disagrees with the snapshot as + CallerMaxTokensOverrideForbidden, so the pre-W2 auto-fill must be gated + on the snapshot being absent. + """ + snapshot = _safe_input_budget_snapshot(requested_output_tokens=8192) + openai_model_instance.max_output_tokens = 131072 + openai_model_instance.safe_input_budget_snapshot = snapshot + + messages = [{"role": "user", "content": [{"text": "Hi"}]}] + + mock_chunk = MagicMock() + mock_chunk.choices = [MagicMock()] + mock_chunk.choices[0].delta.content = "ok" + mock_chunk.choices[0].delta.role = "assistant" + mock_chunk.usage = MagicMock() + mock_chunk.usage.prompt_tokens = 1 + mock_chunk.usage.total_tokens = 2 + mock_chunk.usage.completion_tokens = 1 + mock_stream = [mock_chunk] + + mock_result_message = MagicMock() + mock_result_message.raw = mock_stream + mock_result_message.role = MagicMock() + + with patch.object( + openai_model_instance, "_prepare_completion_kwargs", return_value={} + ), patch.object( + mock_models_module.ChatMessage, "from_dict", return_value=mock_result_message + ): + openai_model_instance.client.chat.completions.create.return_value = mock_stream + 
openai_model_instance.__call__(messages) + + create_kwargs = openai_model_instance.client.chat.completions.create.call_args.kwargs + assert create_kwargs["max_tokens"] == 8192 + + +def test_dispatch_without_w2_snapshot_preserves_existing_max_tokens(openai_model_instance): + openai_model_instance._dispatch_chat_completion( + stream=True, + messages=[], + max_tokens=64, + ) + + openai_model_instance.client.chat.completions.create.assert_called_once_with( + stream=True, + messages=[], + max_tokens=64, + ) + + +def test_dispatch_with_w2_snapshot_sets_requested_output_tokens(openai_model_instance): + openai_model_instance._dispatch_chat_completion( + safe_input_budget_snapshot=_safe_input_budget_snapshot(256), + stream=True, + messages=[], + ) + + openai_model_instance.client.chat.completions.create.assert_called_once_with( + stream=True, + messages=[], + max_tokens=256, + ) + + +def test_dispatch_with_matching_caller_max_tokens_is_allowed(openai_model_instance): + openai_model_instance._dispatch_chat_completion( + safe_input_budget_snapshot=_safe_input_budget_snapshot(256), + stream=True, + messages=[], + max_tokens=256, + ) + + openai_model_instance.client.chat.completions.create.assert_called_once_with( + stream=True, + messages=[], + max_tokens=256, + ) + + +def test_dispatch_rejects_caller_max_tokens_override(openai_model_instance): + with pytest.raises(openai_llm_module.CallerMaxTokensOverrideForbidden): + openai_model_instance._dispatch_chat_completion( + safe_input_budget_snapshot=_safe_input_budget_snapshot(256), + stream=True, + messages=[], + max_tokens=128, + ) + + openai_model_instance.client.chat.completions.create.assert_not_called() + + +def test_dispatch_rejects_tampered_w2_snapshot(openai_model_instance): + snapshot = _safe_input_budget_snapshot(256) + snapshot["hard_input_budget_tokens"] = 999 + + with pytest.raises(openai_llm_module.SafeInputBudgetFingerprintMismatch): + openai_model_instance._dispatch_chat_completion( + 
safe_input_budget_snapshot=snapshot, + stream=True, + messages=[], + ) + + openai_model_instance.client.chat.completions.create.assert_not_called() + + +def _matching_capacity_snapshot(budget_snapshot): + return { + "provider": budget_snapshot["provider"], + "model_name": budget_snapshot["model_name"], + "capacity_fingerprint": budget_snapshot["w1_fingerprint"], + } + + +def test_dispatch_accepts_matching_w1_capacity_snapshot(openai_model_instance): + snapshot = _safe_input_budget_snapshot(256) + openai_model_instance._dispatch_chat_completion( + safe_input_budget_snapshot=snapshot, + capacity_snapshot=_matching_capacity_snapshot(snapshot), + stream=True, + messages=[], + ) + + openai_model_instance.client.chat.completions.create.assert_called_once_with( + stream=True, + messages=[], + max_tokens=256, + ) + + +def test_dispatch_rejects_stale_w1_fingerprint(openai_model_instance): + snapshot = _safe_input_budget_snapshot(256) + capacity = _matching_capacity_snapshot(snapshot) + capacity["capacity_fingerprint"] = "different-w1-fingerprint" + + with pytest.raises(openai_llm_module.SafeInputBudgetCapacityMismatch) as exc_info: + openai_model_instance._dispatch_chat_completion( + safe_input_budget_snapshot=snapshot, + capacity_snapshot=capacity, + stream=True, + messages=[], + ) + + assert exc_info.value.field == "w1_fingerprint" + openai_model_instance.client.chat.completions.create.assert_not_called() + + +def test_dispatch_rejects_cross_provider_w2_snapshot(openai_model_instance): + snapshot = _safe_input_budget_snapshot(256) + capacity = _matching_capacity_snapshot(snapshot) + capacity["provider"] = "dashscope" + + with pytest.raises(openai_llm_module.SafeInputBudgetCapacityMismatch) as exc_info: + openai_model_instance._dispatch_chat_completion( + safe_input_budget_snapshot=snapshot, + capacity_snapshot=capacity, + stream=True, + messages=[], + ) + + assert exc_info.value.field == "provider" + openai_model_instance.client.chat.completions.create.assert_not_called() 
+ + +def test_dispatch_rejects_cross_model_w2_snapshot(openai_model_instance): + snapshot = _safe_input_budget_snapshot(256) + capacity = _matching_capacity_snapshot(snapshot) + capacity["model_name"] = "gpt-other" + + with pytest.raises(openai_llm_module.SafeInputBudgetCapacityMismatch) as exc_info: + openai_model_instance._dispatch_chat_completion( + safe_input_budget_snapshot=snapshot, + capacity_snapshot=capacity, + stream=True, + messages=[], + ) + + assert exc_info.value.field == "model_name" + openai_model_instance.client.chat.completions.create.assert_not_called() + + +def test_dispatch_skips_w1_w2_consistency_when_capacity_snapshot_absent(openai_model_instance): + snapshot = _safe_input_budget_snapshot(256) + + openai_model_instance._dispatch_chat_completion( + safe_input_budget_snapshot=snapshot, + capacity_snapshot=None, + stream=True, + messages=[], + ) + + openai_model_instance.client.chat.completions.create.assert_called_once_with( + stream=True, + messages=[], + max_tokens=256, + ) + + +def test_safe_input_budget_trace_attributes_are_prefixed(): + attrs = ImportedOpenAIModel._safe_input_budget_trace_attributes( + _safe_input_budget_snapshot(256) + ) + + assert len(attrs["w2.budget_fingerprint"]) == 32 + assert attrs["w2.w1_fingerprint"] == "w1fingerprint" + assert attrs["w2.requested_output_tokens"] == 256 + assert attrs["w2.soft_input_budget_tokens"] == 800 + assert attrs["w2.hard_input_budget_tokens"] == 1000 + + def test_call_without_tracker_creates_tracker(openai_model_instance): """When no _token_tracker is passed, __call__ creates one from monitoring manager.""" mock_tracker = MagicMock() diff --git a/test/sdk/monitor/test_monitoring.py b/test/sdk/monitor/test_monitoring.py index c3c5a7ad0..e88632348 100644 --- a/test/sdk/monitor/test_monitoring.py +++ b/test/sdk/monitor/test_monitoring.py @@ -26,6 +26,8 @@ get_monitoring_buffer, set_monitoring_context, get_monitoring_context, + set_monitoring_capacity_snapshot, + 
set_monitoring_safe_input_budget_snapshot, get_agent_monitoring_context, agent_monitoring_context, _monitoring_buffer, @@ -1388,6 +1390,43 @@ def test_all_valid_records(self): assert mock_session.add.call_count == 3 + def test_capacity_snapshot_fields_pass_to_model_monitoring_record(self): + """Capacity snapshot fields are persisted through the ORM row payload.""" + mock_session_fn, mock_model_monitoring_record = self._setup_db_mocks() + mock_session = MagicMock() + mock_session_fn.return_value.__enter__ = Mock(return_value=mock_session) + mock_session_fn.return_value.__exit__ = Mock(return_value=None) + + buf = self._make_buffer() + record = { + "model_name": "m1", + "tenant_id": "t1", + "context_window_tokens": 128000, + "default_output_reserve_tokens": 1024, + "capability_profile_version": "openai/gpt-4o@1", + "capacity_source": "profile", + "requested_output_tokens": 1024, + "provider_input_limit_tokens": 126976, + "tokenizer_family": "o200k_base", + "counting_mode": "exact", + "unknown_capabilities": ["prompt_cache"], + "capacity_fingerprint": "abc123", + "budget_fingerprint": "w2abc", + "budget_w1_fingerprint": "abc123", + "budget_requested_output_tokens": 1024, + "budget_output_reserve_source": "model_default", + "budget_provider_input_limit_tokens": 126976, + "budget_uncertainty_reserve_tokens": 0, + "budget_uncertainty_reserve_basis": "none", + "budget_soft_limit_ratio": 0.8, + "budget_soft_input_budget_tokens": 101580, + "budget_hard_input_budget_tokens": 126976, + "budget_warnings": [], + } + buf._write_batch([record]) + + mock_model_monitoring_record.assert_called_once_with(**record) + def test_all_invalid_records(self): """When every record fails, _write_batch still does not raise.""" mock_session_fn, _ = self._setup_db_mocks() @@ -1415,6 +1454,8 @@ def setup_method(self): _mod._monitoring_user_id.set(None) _mod._monitoring_agent_id.set(None) _mod._monitoring_conversation_id.set(None) + _mod._monitoring_capacity_snapshot.set(None) + 
_mod._monitoring_safe_input_budget_snapshot.set(None) def test_enqueue_with_tenant_id(self): """Record is added to buffer when tenant_id is present.""" @@ -1497,6 +1538,128 @@ def test_snapshot_priority_over_live_context(self): record = mock_buffer.add_record.call_args[0][0] assert record["tenant_id"] == "from-snapshot" + def test_capacity_snapshot_fields_are_enqueued(self): + """Resolved capacity snapshot fields are copied to LLM monitoring rows.""" + mock_buffer = MagicMock() + mock_buffer.is_enabled = True + + tracker = MagicMock() + tracker.start_time = time.time() + tracker.first_token_time = None + tracker.input_tokens = 12 + tracker.output_tokens = 5 + tracker.token_count = 5 + tracker._context_snapshot = {"tenant_id": "t-1"} + tracker._display_name = None + + set_monitoring_capacity_snapshot({ + "context_window_tokens": 128000, + "default_output_reserve_tokens": 1024, + "capability_profile_version": "openai/gpt-4o@1", + "field_sources": { + "context_window_tokens": "profile", + "max_output_tokens": "operator", + }, + "requested_output_tokens": 1024, + "provider_input_limit_tokens": 127000, + "tokenizer_family": "o200k_base", + "counting_mode": "exact", + "unknown_capabilities": ["prompt_cache"], + "fingerprint": "abc123", + }) + + with patch( + "sdk.nexent.monitor.monitoring.get_monitoring_buffer", + return_value=mock_buffer, + ): + _enqueue_monitoring_record(tracker, "model-a", "op", {}) + + record = mock_buffer.add_record.call_args[0][0] + assert record["context_window_tokens"] == 128000 + assert record["default_output_reserve_tokens"] == 1024 + assert record["capability_profile_version"] == "openai/gpt-4o@1" + assert record["capacity_source"] == "operator" + assert record["requested_output_tokens"] == 1024 + assert record["provider_input_limit_tokens"] == 127000 + assert record["tokenizer_family"] == "o200k_base" + assert record["counting_mode"] == "exact" + assert record["unknown_capabilities"] == ["prompt_cache"] + assert record["capacity_fingerprint"] 
== "abc123" + + def test_safe_input_budget_snapshot_fields_are_enqueued(self): + """Resolved W2 budget snapshot fields are copied to LLM monitoring rows.""" + mock_buffer = MagicMock() + mock_buffer.is_enabled = True + + tracker = MagicMock() + tracker.start_time = time.time() + tracker.first_token_time = None + tracker.input_tokens = 12 + tracker.output_tokens = 5 + tracker.token_count = 5 + tracker._context_snapshot = {"tenant_id": "t-1"} + tracker._display_name = None + + set_monitoring_safe_input_budget_snapshot({ + "fingerprint": "w2abc", + "w1_fingerprint": "w1abc", + "requested_output_tokens": 1024, + "output_reserve_source": "model_default", + "provider_input_limit_tokens": 127000, + "uncertainty_reserve_tokens": 12800, + "uncertainty_reserve_basis": "context_window_10pct", + "soft_limit_ratio": 0.8, + "soft_input_budget_tokens": 91360, + "hard_input_budget_tokens": 114200, + "warnings": ["uncertainty_reserve_active"], + }) + + with patch( + "sdk.nexent.monitor.monitoring.get_monitoring_buffer", + return_value=mock_buffer, + ): + _enqueue_monitoring_record(tracker, "model-a", "op", {}) + + record = mock_buffer.add_record.call_args[0][0] + assert record["budget_fingerprint"] == "w2abc" + assert record["budget_w1_fingerprint"] == "w1abc" + assert record["budget_requested_output_tokens"] == 1024 + assert record["budget_output_reserve_source"] == "model_default" + assert record["budget_provider_input_limit_tokens"] == 127000 + assert record["budget_uncertainty_reserve_tokens"] == 12800 + assert record["budget_uncertainty_reserve_basis"] == "context_window_10pct" + assert record["budget_soft_limit_ratio"] == 0.8 + assert record["budget_soft_input_budget_tokens"] == 91360 + assert record["budget_hard_input_budget_tokens"] == 114200 + assert record["budget_warnings"] == ["uncertainty_reserve_active"] + + def test_absent_capacity_snapshot_does_not_add_fields(self): + """Records remain valid when no capacity snapshot is bound.""" + mock_buffer = MagicMock() + 
mock_buffer.is_enabled = True + + tracker = MagicMock() + tracker.start_time = time.time() + tracker.first_token_time = None + tracker.input_tokens = 0 + tracker.output_tokens = 0 + tracker.token_count = 0 + tracker._context_snapshot = {"tenant_id": "t-1"} + tracker._display_name = None + + set_monitoring_capacity_snapshot(None) + + with patch( + "sdk.nexent.monitor.monitoring.get_monitoring_buffer", + return_value=mock_buffer, + ): + _enqueue_monitoring_record(tracker, "model-a", "op", {}) + + record = mock_buffer.add_record.call_args[0][0] + assert "capacity_fingerprint" not in record + assert "provider_input_limit_tokens" not in record + assert "budget_fingerprint" not in record + # ========================================================================= # TestRecordModelCallContext (Task 4.1) @@ -1681,6 +1844,8 @@ def setup_method(self): _mod._monitoring_conversation_id.set(None) _mod._monitoring_operation.set("unknown") _mod._monitoring_display_name.set("TestModel") + _mod._monitoring_capacity_snapshot.set(None) + _mod._monitoring_safe_input_budget_snapshot.set(None) def _make_monitored_client(self): mock_original = MagicMock() @@ -1817,6 +1982,8 @@ def setup_method(self): _mod._monitoring_conversation_id.set(99) _mod._monitoring_operation.set("title_generation") _mod._monitoring_display_name.set("MyModel") + _mod._monitoring_capacity_snapshot.set(None) + _mod._monitoring_safe_input_budget_snapshot.set(None) def test_full_record_fields(self): mock_buffer = MagicMock() @@ -1853,6 +2020,74 @@ def test_full_record_fields(self): assert record["conversation_id"] == 99 assert record["display_name"] == "MyModel" + def test_client_record_includes_capacity_snapshot_fields(self): + mock_buffer = MagicMock() + mock_buffer.is_enabled = True + set_monitoring_capacity_snapshot({ + "capacity_source": "profile", + "requested_output_tokens": 2048, + "provider_input_limit_tokens": 30000, + "counting_mode": "estimated", + "capacity_fingerprint": "def456", + }) + + with patch("sdk.nexent.monitor.monitoring.get_monitoring_buffer",
return_value=mock_buffer): + _enqueue_client_monitoring_record( + model_name="test-model", + model_type="llm", + request_duration_ms=500, + ttft_ms=0, + input_tokens=10, + output_tokens=20, + total_tokens=30, + generation_rate=0.0, + is_streaming=False, + ) + + record = mock_buffer.add_record.call_args[0][0] + assert record["capacity_source"] == "profile" + assert record["requested_output_tokens"] == 2048 + assert record["provider_input_limit_tokens"] == 30000 + assert record["counting_mode"] == "estimated" + assert record["capacity_fingerprint"] == "def456" + + def test_client_record_includes_safe_input_budget_snapshot_fields(self): + mock_buffer = MagicMock() + mock_buffer.is_enabled = True + set_monitoring_safe_input_budget_snapshot({ + "fingerprint": "w2def", + "w1_fingerprint": "def456", + "requested_output_tokens": 2048, + "output_reserve_source": "agent", + "provider_input_limit_tokens": 30000, + "uncertainty_reserve_tokens": 0, + "uncertainty_reserve_basis": "none", + "soft_limit_ratio": 0.75, + "soft_input_budget_tokens": 22500, + "hard_input_budget_tokens": 30000, + }) + + with patch("sdk.nexent.monitor.monitoring.get_monitoring_buffer", return_value=mock_buffer): + _enqueue_client_monitoring_record( + model_name="test-model", + model_type="llm", + request_duration_ms=500, + ttft_ms=0, + input_tokens=10, + output_tokens=20, + total_tokens=30, + generation_rate=0.0, + is_streaming=False, + ) + + record = mock_buffer.add_record.call_args[0][0] + assert record["budget_fingerprint"] == "w2def" + assert record["budget_w1_fingerprint"] == "def456" + assert record["budget_requested_output_tokens"] == 2048 + assert record["budget_output_reserve_source"] == "agent" + assert record["budget_soft_input_budget_tokens"] == 22500 + assert record["budget_hard_input_budget_tokens"] == 30000 + def test_error_record(self): mock_buffer = MagicMock() mock_buffer.is_enabled = True