diff --git a/.env.example b/.env.example index ffb264e..de3f0dc 100644 --- a/.env.example +++ b/.env.example @@ -26,8 +26,9 @@ REDIS_PASSWORD=your_redis_password # 安全配置 (Security) # ========================================== # 【关键】用于解密从数据库 llm_user_config 表取出的 API Key 密文 -# 需要与 Java 管理端的加密 Secret / 离散算法保持完全一致 -API_KEY_ENCRYPTION_SECRET=your_encryption_secret +# 需要与 Java 管理端的 AES/GCM Secret 保持完全一致:64 位 hex,解码后 32 字节 +# 下方为本地占位示例,生产必须覆盖 +API_KEY_ENCRYPTION_SECRET=0000000000000000000000000000000000000000000000000000000000000000 # ========================================== # 系统级兜底 LLM 配置 (Platform Default Fallback LLMs) diff --git a/.gitignore b/.gitignore index cfe5da4..cd7bfd2 100644 --- a/.gitignore +++ b/.gitignore @@ -77,6 +77,9 @@ tests/api_keys.json /requirements +# 本地 Docker Compose 覆盖(KRaft Kafka、无认证),不入 Git +docker-compose.override.yml + /.specs/ /.kiro/ /blog/ diff --git a/docs/api/schemas/mysql.md b/docs/api/schemas/mysql.md index 15dd669..617d564 100644 --- a/docs/api/schemas/mysql.md +++ b/docs/api/schemas/mysql.md @@ -10,12 +10,12 @@ ORM 与 migration 不一致时,以 migration 为准并修正 ORM;scripts/db/ ## 表清单 -按业务域共 12 张表: +按业务域共 14 张表: | 业务域 | 表 | 主键 ID 起始 | | --- | --- | --- | | [用户](#1-用户) | `sys_user` | 10000 | -| [LLM 配置与用量](#2-llm-配置与用量) | `llm_system_provider`, `llm_user_config`, `llm_usage_log` | 10000 | +| [LLM 配置与用量](#2-llm-配置与用量) | `llm_system_provider`, `llm_provider_model`, `llm_system_preset`, `llm_user_config`, `llm_usage_log` | 10000 | | [数据集与对话](#3-数据集与对话) | `dataset`, `chat_conversation`, `chat_message` | 10000 | | [文档解析](#4-文档解析) | `document_original_file`, `document_parse_file`, `document_parsed_log`, `document_parse_pipeline` | 10000 | | [知识索引](#5-知识索引) | `kb_document_chunk` | 10000 | @@ -60,14 +60,47 @@ ORM:[`SystemProviderDB`](../../src/models/db_models.py) | `provider_type` | VARCHAR(32) UNIQUE | `openai` / `claude` / `glm` / `deepseek` 等 | | `provider_name` | VARCHAR(64) | 厂商展示名 | | `api_base_url` | VARCHAR(512) | 官方默认 API 地址 | -| `supported_models` | JSON | 支持模型与能力映射 | -| `config_schema` | JSON | 配置参数 Schema | | `is_active` | BOOLEAN | 是否启用 | | `priority` | INT | 厂商优先级(1-100),默认 50 | | `created_at` / `updated_at` | DATETIME | 创建 / 更新时间 | 索引:`uk_provider_type`。 +### `llm_provider_model` — 厂商模型能力目录 + +ORM:[`ProviderModelDB`](../../src/models/db_models.py) + +| 字段 | 类型 | 说明 | +| --- | --- | --- | +| `id` | BIGINT UNSIGNED PK | 主键 | +| `provider_id` | BIGINT UNSIGNED | 关联 `llm_system_provider.id` | +| `model_name` | VARCHAR(128) | 模型名 | +| `capability` | VARCHAR(32) | 单能力;一模型多能力拆成多行 | +| `is_active` | BOOLEAN | 该模型能力是否上架 | +| `created_at` / `updated_at` | DATETIME | 创建 / 更新时间 | + +索引: +- `uk_provider_model_cap(provider_id, model_name, capability)` +- `idx_provider_cap(provider_id, capability)` + +### `llm_system_preset` — 系统预设模板 + +ORM:[`SystemPresetDB`](../../src/models/db_models.py) + +| 字段 | 类型 | 说明 | +| --- | --- | --- | +| `id` | BIGINT UNSIGNED PK | 主键 | +| `provider_id` | BIGINT UNSIGNED | 关联 `llm_system_provider.id` | +| `model_name` | VARCHAR(128) | 模型名 | +| `capability` | VARCHAR(32) | 能力标识 | +| `api_key` | VARCHAR(512) | 平台 Key,**加密存储** | +| `is_active` | BOOLEAN | 是否对新用户下发 | +| `created_at` / `updated_at` | DATETIME | 创建 / 更新时间 | + +索引:`uk_preset_provider_model_cap(provider_id, model_name, capability)`。 + +说明:Python 运行时不直接读取本表决定生效配置;Java 注册时会将 active 预设复制进 `llm_user_config`。 + ### `llm_user_config` — 用户级 LLM 配置 ORM:[`UserLLMConfigDB`](../../src/models/db_models.py) @@ -77,28 +110,32 @@ ORM:[`UserLLMConfigDB`](../../src/models/db_models.py) | `id` | BIGINT UNSIGNED PK | 配置唯一标识 | | `user_id` | BIGINT UNSIGNED | 所属用户 | | `provider_id` | BIGINT UNSIGNED | 关联 `llm_system_provider.id` | -| `provider_type` | VARCHAR(32) | 厂商类型快照 | -| `provider_name` | VARCHAR(64) | 厂商名快照 | -| `config_name` | VARCHAR(64) | 用户自定义配置名 | +| `provider_type` | VARCHAR(32) | 厂商类型快照,用于下游路由到对应 SDK | | `api_key` | VARCHAR(512) | **加密存储**,由 `API_KEY_ENCRYPTION_SECRET` 解密 | -| `custom_api_base_url` | VARCHAR(512) | 自定义 API 地址 | +| `api_base_url` | VARCHAR(512) | 实际生效地址 | | `model_name` | VARCHAR(128) | 具体模型名 | -| `priority` | INT | 优先级 1-100 | -| `is_active` | BOOLEAN | 是否启用 | -| `is_default` | BOOLEAN | 是否默认配置 | -| `timeout_ms` | INT | 超时(毫秒),默认 60000 | -| `max_retries` | INT | 最大重试次数,默认 3 | -| `stream_enabled` | BOOLEAN | 是否支持流式输出 | -| `capability` | VARCHAR(32) | `CHAT` / `EMBEDDING` / `RERANK` / `OCR`,默认 `CHAT` | -| `default_marker` | INT,生成列 | `default+active` 时为 `1`,否则 `NULL`,仅用于唯一约束(应用层不写入) | -| `extra_config` | JSON | 扩展配置 | +| `capability` | VARCHAR(32) | `CHAT` / `EMBEDDING` / `RERANK` / `OCR` / `VISION` 等,默认 `CHAT` | +| `is_active` | BOOLEAN | 模型启停 + 生效过滤 | +| `is_default` | BOOLEAN | 该能力是否生效 | +| `is_system_preset` | BOOLEAN | 是否系统预设行 | | `created_at` / `updated_at` | DATETIME | 创建 / 更新时间 | 索引: -- `uk_user_provider_model(user_id, provider_id, model_name)` +- `uk_user_provider_model_capability(user_id, provider_id, model_name, capability, is_system_preset)` - `idx_user_active_default(user_id, is_active, is_default)` - `idx_user_provider_cap(user_id, provider_type, capability)` -- `uq_user_default_per_capability(user_id, provider_type, capability, default_marker)` — 唯一键。借助生成列 `default_marker`(默认+启用时为 1,否则 NULL)与 MySQL「唯一索引中 NULL 不计重复」语义,保证每个 `(user_id, provider_type, capability)` 至多一条默认且启用的配置;非默认/停用配置不受限(迁移 0012) + +运行时读取生效配置: + +```sql +SELECT * +FROM llm_user_config +WHERE user_id = :user_id + AND capability = :capability + AND is_default = TRUE + AND is_active = TRUE +LIMIT 1; +``` ### `llm_usage_log` — LLM 调用用量日志 diff --git a/docs/ops/configure.md b/docs/ops/configure.md index 40738d8..a0723d1 100644 --- a/docs/ops/configure.md +++ b/docs/ops/configure.md @@ -30,7 +30,7 @@ | --- | --- | | `DB_HOST` / `DB_PORT` / `DB_USER` / `DB_PASSWORD` / `DB_NAME` | MySQL 连接 | | `REDIS_HOST` / `REDIS_PORT` | Redis 连接 | -| `API_KEY_ENCRYPTION_SECRET` | API Key 加密 Secret,必须与 Java 管理端一致 | +| `API_KEY_ENCRYPTION_SECRET` | API Key 加密 Secret,必须与 Java 管理端一致;64 位 hex,解码后 32 字节,用于 AES-256-GCM | | `SYSTEM_LLM_PROVIDER` / `SYSTEM_LLM_API_KEY` / `SYSTEM_LLM_API_BASE` | 系统级兜底 LLM | | `KAFKA_BOOTSTRAP_SERVERS` 等(若 `MQ_VENDOR=kafka`) | Kafka 接入信息 | | `MINIO_*`(若 `STORAGE_TYPE=minio`) | 对象存储凭据 | diff --git a/docs/ops/deploy.md b/docs/ops/deploy.md index 6b14646..f04a23d 100644 --- a/docs/ops/deploy.md +++ b/docs/ops/deploy.md @@ -57,7 +57,7 @@ uvicorn src.main:app --host 0.0.0.0 --port 8000 常见失败: - **应用启动卡在 Kafka**:通常是 `KAFKA_BOOTSTRAP_SERVERS` 配置错或 broker 未起来。本地用 docker-compose 时此地址应为 `127.0.0.1:9092`(容器内部连接用 `tolink-kafka:29092`)。 -- **API 调用 LLM 报解密失败**:`API_KEY_ENCRYPTION_SECRET` 必须与 Java 管理端的加密 Secret 一致,否则 `llm_user_config` 表中的密文无法解密。 +- **API 调用 LLM 报解密失败**:`API_KEY_ENCRYPTION_SECRET` 必须与 Java 管理端的加密 Secret 一致,格式为 64 位 hex(解码后 32 字节),否则 `llm_user_config` 表中的密文无法解密。 - **解析任务消费不到**:检查 `INIT_KAFKA_TOPICS_ON_STARTUP` 是否被关闭,且 topic(`PARSE_TASK_TOPIC` 默认 `tolink-document-pares`)是否已存在。 ## 生产部署注意事项 diff --git a/migrations/versions/0013_20260606_llm_config_refactor.py b/migrations/versions/0013_20260606_llm_config_refactor.py new file mode 100644 index 0000000..f1340f8 --- /dev/null +++ b/migrations/versions/0013_20260606_llm_config_refactor.py @@ -0,0 +1,253 @@ +"""adapt LLM config tables to provider-model-preset structure + +Java now owns LLM configuration management and writes the effective runtime +configuration into ``llm_user_config``. Align Python's schema chain with that +contract: providers are slimmed down, model capabilities move to a catalog +table, presets are templates copied into user config rows, and user config rows +no longer carry execution parameters or display-only fields. + +Revision ID: 0013 +Revises: 0012 +Create Date: 2026-06-06 +""" + +from __future__ import annotations + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import mysql + +revision: str = "0013" +down_revision: Union[str, None] = "0012" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "llm_provider_model", + sa.Column("id", mysql.BIGINT(unsigned=True), primary_key=True, autoincrement=True), + sa.Column( + "provider_id", + mysql.BIGINT(unsigned=True), + nullable=False, + comment="关联 llm_system_provider.id", + ), + sa.Column("model_name", sa.String(length=128), nullable=False, comment="模型名"), + sa.Column( + "capability", sa.String(length=32), nullable=False, comment="单能力;一模型多能力=多行" + ), + sa.Column( + "is_active", + sa.Boolean(), + nullable=False, + server_default=sa.text("1"), + comment="该模型能力是否上架", + ), + sa.Column( + "created_at", sa.DateTime(), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP") + ), + sa.Column( + "updated_at", sa.DateTime(), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP") + ), + sa.UniqueConstraint( + "provider_id", "model_name", "capability", name="uk_provider_model_cap" + ), + sa.Index("idx_provider_cap", "provider_id", "capability"), + mysql_engine="InnoDB", + mysql_charset="utf8mb4", + mysql_collate="utf8mb4_unicode_ci", + mysql_comment="厂商模型能力目录表", + ) + op.create_table( + "llm_system_preset", + sa.Column("id", mysql.BIGINT(unsigned=True), primary_key=True, autoincrement=True), + sa.Column( + "provider_id", + mysql.BIGINT(unsigned=True), + nullable=False, + comment="关联 llm_system_provider.id", + ), + sa.Column("model_name", sa.String(length=128), nullable=False, comment="模型名"), + sa.Column("capability", sa.String(length=32), nullable=False, comment="能力标识"), + sa.Column("api_key", sa.String(length=512), nullable=False, comment="平台 Key(加密)"), + sa.Column( + "is_active", + sa.Boolean(), + nullable=False, + server_default=sa.text("1"), + comment="是否对新用户下发", + ), + sa.Column( + "created_at", sa.DateTime(), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP") + ), + sa.Column( + "updated_at", sa.DateTime(), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP") + ), + sa.UniqueConstraint( + "provider_id", "model_name", "capability", name="uk_preset_provider_model_cap" + ), + mysql_engine="InnoDB", + mysql_charset="utf8mb4", + mysql_collate="utf8mb4_unicode_ci", + mysql_comment="系统预设表", + ) + + op.drop_column("llm_system_provider", "supported_models") + op.drop_column("llm_system_provider", "config_schema") + + op.drop_constraint("uq_user_default_per_capability", "llm_user_config", type_="unique") + op.drop_column("llm_user_config", "default_marker") + + op.alter_column( + "llm_user_config", + "custom_api_base_url", + existing_type=sa.String(length=512), + new_column_name="api_base_url", + existing_nullable=True, + comment="实际生效地址:用户自定义或厂商默认", + ) + op.add_column( + "llm_user_config", + sa.Column( + "is_system_preset", + sa.Boolean(), + nullable=False, + server_default=sa.text("0"), + comment="系统预设行(只读)", + ), + ) + + op.drop_column("llm_user_config", "provider_name") + op.drop_column("llm_user_config", "config_name") + op.drop_column("llm_user_config", "priority") + op.drop_column("llm_user_config", "timeout_ms") + op.drop_column("llm_user_config", "max_retries") + op.drop_column("llm_user_config", "stream_enabled") + op.drop_column("llm_user_config", "extra_config") + + op.drop_constraint("uk_user_provider_model", "llm_user_config", type_="unique") + op.create_unique_constraint( + "uk_user_provider_model_capability", + "llm_user_config", + ["user_id", "provider_id", "model_name", "capability", "is_system_preset"], + ) + + +def downgrade() -> None: + op.drop_constraint( + "uk_user_provider_model_capability", + "llm_user_config", + type_="unique", + ) + op.create_unique_constraint( + "uk_user_provider_model", + "llm_user_config", + ["user_id", "provider_id", "model_name"], + ) + + op.add_column( + "llm_user_config", + sa.Column("extra_config", mysql.JSON(), nullable=True, comment="扩展配置"), + ) + op.add_column( + "llm_user_config", + sa.Column( + "stream_enabled", + sa.Boolean(), + nullable=True, + server_default=sa.text("1"), + comment="是否支持流式输出", + ), + ) + op.add_column( + "llm_user_config", + sa.Column( + "max_retries", + sa.Integer(), + nullable=True, + server_default=sa.text("3"), + comment="最大重试次数", + ), + ) + op.add_column( + "llm_user_config", + sa.Column( + "timeout_ms", + sa.Integer(), + nullable=True, + server_default=sa.text("60000"), + comment="超时时间(毫秒)", + ), + ) + op.add_column( + "llm_user_config", + sa.Column( + "priority", + sa.Integer(), + nullable=False, + server_default=sa.text("50"), + comment="优先级 1-100", + ), + ) + op.add_column( + "llm_user_config", + sa.Column( + "config_name", + sa.String(length=64), + nullable=False, + server_default="", + comment="用户自定义配置名称", + ), + ) + op.add_column( + "llm_user_config", + sa.Column( + "provider_name", + sa.String(length=64), + nullable=False, + server_default="", + comment="厂商名称快照", + ), + ) + op.drop_column("llm_user_config", "is_system_preset") + op.alter_column( + "llm_user_config", + "api_base_url", + existing_type=sa.String(length=512), + new_column_name="custom_api_base_url", + existing_nullable=True, + comment="自定义 API 地址", + ) + op.add_column( + "llm_user_config", + sa.Column( + "default_marker", + sa.Integer(), + sa.Computed( + "(CASE WHEN is_default = 1 AND is_active = 1 THEN 1 ELSE NULL END)", + persisted=True, + ), + nullable=True, + comment="默认判别生成列:default+active 时为 1,否则 NULL,仅用于唯一约束", + ), + ) + op.create_unique_constraint( + "uq_user_default_per_capability", + "llm_user_config", + ["user_id", "provider_type", "capability", "default_marker"], + ) + + op.add_column( + "llm_system_provider", + sa.Column("config_schema", mysql.JSON(), nullable=True, comment="配置参数 Schema"), + ) + op.add_column( + "llm_system_provider", + sa.Column("supported_models", mysql.JSON(), nullable=True, comment="支持模型与能力映射"), + ) + + op.drop_table("llm_system_preset") + op.drop_table("llm_provider_model") diff --git a/scripts/db/init.sql b/scripts/db/init.sql index b39e2e0..5cf3dad 100644 --- a/scripts/db/init.sql +++ b/scripts/db/init.sql @@ -8,7 +8,7 @@ -- - schema 演进的唯一权威源是 src/models/**.py + migrations/versions/*.py; -- - 修改字段必须先改 ORM 模型并新增 migration,再同步本文件。 -- 同步时机:每条会改动表结构的 migration 落库时一并更新本文件。 --- 末次同步:migration 0011_20260530_add_java_soft_delete_columns +-- 末次同步:migration 0013_20260606_llm_config_refactor -- =============================================== CREATE DATABASE IF NOT EXISTS tolink_rag_db DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; @@ -40,8 +40,6 @@ CREATE TABLE IF NOT EXISTS llm_system_provider ( provider_type VARCHAR(32) NOT NULL COMMENT '厂商类型:openai/claude/glm/deepseek', provider_name VARCHAR(64) NOT NULL COMMENT '厂商展示名称,如 "OpenAI"', api_base_url VARCHAR(512) NOT NULL COMMENT '官方默认 API 地址', - supported_models JSON COMMENT '支持模型与能力映射', - config_schema JSON COMMENT '配置参数 Schema', is_active BOOLEAN NOT NULL DEFAULT TRUE COMMENT '是否启用', priority INT NOT NULL DEFAULT 50 COMMENT '厂商优先级(1-100)', created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, @@ -50,31 +48,51 @@ CREATE TABLE IF NOT EXISTS llm_system_provider ( UNIQUE KEY uk_provider_type (provider_type) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 AUTO_INCREMENT=10000 COMMENT 'LLM 系统级厂商配置表'; --- 3. 用户级 LLM 配置表 +-- 2.1 厂商模型能力目录表 +CREATE TABLE IF NOT EXISTS llm_provider_model ( + id BIGINT UNSIGNED AUTO_INCREMENT PRIMARY KEY COMMENT '主键', + provider_id BIGINT UNSIGNED NOT NULL COMMENT '关联 llm_system_provider.id', + model_name VARCHAR(128) NOT NULL COMMENT '模型名', + capability VARCHAR(32) NOT NULL COMMENT '单能力;一模型多能力=多行', + is_active BOOLEAN NOT NULL DEFAULT TRUE COMMENT '该模型能力是否上架', + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + + UNIQUE KEY uk_provider_model_cap (provider_id, model_name, capability), + INDEX idx_provider_cap (provider_id, capability) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 AUTO_INCREMENT=10000 COMMENT '厂商模型能力目录表'; + +-- 2.2 系统预设表 +CREATE TABLE IF NOT EXISTS llm_system_preset ( + id BIGINT UNSIGNED AUTO_INCREMENT PRIMARY KEY COMMENT '主键', + provider_id BIGINT UNSIGNED NOT NULL COMMENT '关联 llm_system_provider.id', + model_name VARCHAR(128) NOT NULL COMMENT '模型名', + capability VARCHAR(32) NOT NULL COMMENT '能力标识', + api_key VARCHAR(512) NOT NULL COMMENT '平台 Key(加密)', + is_active BOOLEAN NOT NULL DEFAULT TRUE COMMENT '是否对新用户下发', + created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + + UNIQUE KEY uk_preset_provider_model_cap (provider_id, model_name, capability) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 AUTO_INCREMENT=10000 COMMENT '系统预设表'; + +-- 3. 用户级 LLM 配置表(下游唯一生效源) CREATE TABLE IF NOT EXISTS llm_user_config ( id BIGINT UNSIGNED AUTO_INCREMENT PRIMARY KEY COMMENT '配置唯一标识', user_id BIGINT UNSIGNED NOT NULL COMMENT '用户 ID', provider_id BIGINT UNSIGNED NOT NULL COMMENT '关联 SystemProvider ID', - provider_type VARCHAR(32) NOT NULL COMMENT '厂商类型快照', - provider_name VARCHAR(64) NOT NULL COMMENT '厂商名称快照', - config_name VARCHAR(64) NOT NULL COMMENT '用户自定义配置名称', - api_key VARCHAR(512) NOT NULL COMMENT 'API Key(加密存储)', - custom_api_base_url VARCHAR(512) COMMENT '自定义 API 地址', + provider_type VARCHAR(32) NOT NULL COMMENT '厂商类型快照,下游路由 SDK', + api_key VARCHAR(512) NOT NULL COMMENT '厂商级 API Key(加密存储)', + api_base_url VARCHAR(512) COMMENT '实际生效地址:用户自定义或厂商默认', model_name VARCHAR(128) NOT NULL COMMENT '具体模型名', - priority INT NOT NULL DEFAULT 50 COMMENT '优先级 1-100', - is_active BOOLEAN NOT NULL DEFAULT TRUE COMMENT '是否启用', - is_default BOOLEAN NOT NULL DEFAULT FALSE COMMENT '是否为默认配置', - timeout_ms INT DEFAULT 60000 COMMENT '超时时间(毫秒)', - max_retries INT DEFAULT 3 COMMENT '最大重试次数', - stream_enabled BOOLEAN DEFAULT TRUE COMMENT '是否支持流式输出', - capability VARCHAR(32) NOT NULL DEFAULT 'CHAT' COMMENT '专用能力标识:CHAT/EMBEDDING/RERANK/OCR', - default_marker INT GENERATED ALWAYS AS (CASE WHEN is_default = 1 AND is_active = 1 THEN 1 ELSE NULL END) STORED COMMENT '默认判别生成列:default+active 时为 1,否则 NULL,仅用于唯一约束', - extra_config JSON COMMENT '扩展配置', + capability VARCHAR(32) NOT NULL DEFAULT 'CHAT' COMMENT '专用能力标识:CHAT/EMBEDDING/RERANK/OCR 等', + is_active BOOLEAN NOT NULL DEFAULT TRUE COMMENT '模型启停 + 生效过滤', + is_default BOOLEAN NOT NULL DEFAULT FALSE COMMENT '该能力是否生效(单用户单能力唯一)', + is_system_preset BOOLEAN NOT NULL DEFAULT FALSE COMMENT '系统预设行(只读)', created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, - UNIQUE KEY uk_user_provider_model (user_id, provider_id, model_name), - UNIQUE KEY uq_user_default_per_capability (user_id, provider_type, capability, default_marker), + UNIQUE KEY uk_user_provider_model_capability (user_id, provider_id, model_name, capability, is_system_preset), INDEX idx_user_active_default (user_id, is_active, is_default), INDEX idx_user_provider_cap (user_id, provider_type, capability) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 AUTO_INCREMENT=10000 COMMENT '用户级 LLM 配置表'; diff --git a/scripts/db/schema.sql b/scripts/db/schema.sql index 4924ce7..0fd97e3 100644 --- a/scripts/db/schema.sql +++ b/scripts/db/schema.sql @@ -52,171 +52,59 @@ VALUES -- =============================================== -- 2. LLM 系统级厂商配置 llm_system_provider -- =============================================== --- provider_type 取值与 src/core/llm/factory.py 注册键严格对齐: --- openai / anthropic / glm / deepseek / qwen --- supported_models 为 Dict[str, List[str]](模型名 -> 能力标签列表), --- 能力标签集合:CHAT / EMBEDDING / RERANK / VISION / OCR / TOOL_CALLING --- api_base_url 取各 Provider 实现的 DEFAULT_API_BASE。 --- qwen 为系统默认 provider(config.py SYSTEM_LLM_PROVIDER="qwen"),优先级最高。 +-- 运行时生效配置由 llm_user_config 承载;模型能力目录见 llm_provider_model。 -- =============================================== INSERT INTO llm_system_provider - (id, provider_type, provider_name, api_base_url, supported_models, config_schema, is_active, priority, created_at, updated_at) + (id, provider_type, provider_name, api_base_url, is_active, priority, created_at, updated_at) VALUES - -- 10001 OpenAI:对话(含视觉/工具) + 文本向量化 - (10001, 'openai', 'OpenAI', 'https://api.openai.com/v1', - JSON_OBJECT( - 'gpt-4o', JSON_ARRAY('CHAT','VISION','TOOL_CALLING'), - 'gpt-4o-mini', JSON_ARRAY('CHAT','TOOL_CALLING'), - 'gpt-4-turbo', JSON_ARRAY('CHAT','VISION','TOOL_CALLING'), - 'gpt-4', JSON_ARRAY('CHAT','TOOL_CALLING'), - 'o3', JSON_ARRAY('CHAT','TOOL_CALLING'), - 'o4-mini', JSON_ARRAY('CHAT','TOOL_CALLING'), - 'text-embedding-3-small', JSON_ARRAY('EMBEDDING'), - 'text-embedding-3-large', JSON_ARRAY('EMBEDDING') - ), - JSON_OBJECT( - 'api_key', JSON_OBJECT('type','string', 'required', TRUE, 'label','API Key'), - 'api_base_url', JSON_OBJECT('type','string', 'required', FALSE, 'default','https://api.openai.com/v1'), - 'temperature', JSON_OBJECT('type','number', 'default', 0.7, 'min', 0, 'max', 2), - 'max_tokens', JSON_OBJECT('type','integer','default', 4096), - 'top_p', JSON_OBJECT('type','number', 'default', 1.0), - 'dimensions', JSON_OBJECT('type','integer','required', FALSE, 'note','仅 EMBEDDING 模型生效'), - 'timeout_ms', JSON_OBJECT('type','integer','default', 60000) - ), - TRUE, 85, '2026-04-01 10:00:00', '2026-04-01 10:00:00'), - - -- 10002 Anthropic Claude:对话(含视觉/工具),无官方 embedding - (10002, 'anthropic', 'Anthropic Claude', 'https://api.anthropic.com/v1', - JSON_OBJECT( - 'claude-opus-4-8', JSON_ARRAY('CHAT','VISION','TOOL_CALLING'), - 'claude-sonnet-4-6', JSON_ARRAY('CHAT','VISION','TOOL_CALLING'), - 'claude-haiku-4-5', JSON_ARRAY('CHAT','TOOL_CALLING'), - 'claude-3-5-sonnet-20241022', JSON_ARRAY('CHAT','VISION','TOOL_CALLING'), - 'claude-3-sonnet-20240229', JSON_ARRAY('CHAT','VISION') - ), - JSON_OBJECT( - 'api_key', JSON_OBJECT('type','string', 'required', TRUE, 'label','API Key'), - 'api_base_url', JSON_OBJECT('type','string', 'required', FALSE, 'default','https://api.anthropic.com/v1'), - 'temperature', JSON_OBJECT('type','number', 'default', 1.0, 'min', 0, 'max', 1), - 'max_tokens', JSON_OBJECT('type','integer','default', 8192), - 'top_p', JSON_OBJECT('type','number', 'default', 1.0), - 'timeout_ms', JSON_OBJECT('type','integer','default', 90000) - ), - TRUE, 80, '2026-04-01 10:00:00', '2026-04-01 10:00:00'), - - -- 10003 智谱 GLM:对话 + 多模态(视觉/OCR) + 向量化 - (10003, 'glm', '智谱 GLM', 'https://open.bigmodel.cn/api/paas/v1', - JSON_OBJECT( - 'glm-4-plus', JSON_ARRAY('CHAT','TOOL_CALLING'), - 'glm-4-air', JSON_ARRAY('CHAT','TOOL_CALLING'), - 'glm-4-flash', JSON_ARRAY('CHAT'), - 'glm-4', JSON_ARRAY('CHAT','TOOL_CALLING'), - 'glm-4v-plus', JSON_ARRAY('CHAT','VISION','OCR'), - 'embedding-3', JSON_ARRAY('EMBEDDING') - ), - JSON_OBJECT( - 'api_key', JSON_OBJECT('type','string', 'required', TRUE, 'label','API Key'), - 'api_base_url', JSON_OBJECT('type','string', 'required', FALSE, 'default','https://open.bigmodel.cn/api/paas/v1'), - 'temperature', JSON_OBJECT('type','number', 'default', 0.6, 'min', 0, 'max', 1), - 'max_tokens', JSON_OBJECT('type','integer','default', 4096), - 'dimensions', JSON_OBJECT('type','integer','default', 2048, 'note','embedding-3 维度'), - 'timeout_ms', JSON_OBJECT('type','integer','default', 60000) - ), - TRUE, 70, '2026-04-02 11:30:00', '2026-04-02 11:30:00'), - - -- 10004 DeepSeek:对话(含推理),OpenAI 兼容协议 - (10004, 'deepseek', 'DeepSeek', 'https://api.deepseek.com/v1', - JSON_OBJECT( - 'deepseek-chat', JSON_ARRAY('CHAT','TOOL_CALLING'), - 'deepseek-reasoner', JSON_ARRAY('CHAT','TOOL_CALLING') - ), - JSON_OBJECT( - 'api_key', JSON_OBJECT('type','string', 'required', TRUE, 'label','API Key'), - 'api_base_url', JSON_OBJECT('type','string', 'required', FALSE, 'default','https://api.deepseek.com/v1'), - 'temperature', JSON_OBJECT('type','number', 'default', 0.5, 'min', 0, 'max', 2), - 'max_tokens', JSON_OBJECT('type','integer','default', 4096), - 'timeout_ms', JSON_OBJECT('type','integer','default', 60000) - ), - TRUE, 60, '2026-04-05 09:00:00', '2026-04-05 09:00:00'), - - -- 10005 Qwen(通义千问):系统默认 provider,覆盖 对话/视觉/OCR/重排/向量化 全能力 - (10005, 'qwen', '通义千问 Qwen', 'https://dashscope.aliyuncs.com/compatible-mode/v1', - JSON_OBJECT( - 'qwen-plus', JSON_ARRAY('CHAT','TOOL_CALLING'), - 'qwen-max', JSON_ARRAY('CHAT','TOOL_CALLING'), - 'qwen3.5-flash', JSON_ARRAY('CHAT','TOOL_CALLING'), - 'qwen-vl-plus', JSON_ARRAY('CHAT','VISION','OCR'), - 'qwen-vl-max', JSON_ARRAY('CHAT','VISION','OCR'), - 'qwen3-vl-rerank', JSON_ARRAY('RERANK'), - 'text-embedding-v4', JSON_ARRAY('EMBEDDING') - ), - JSON_OBJECT( - 'api_key', JSON_OBJECT('type','string', 'required', TRUE, 'label','DashScope API Key'), - 'api_base_url', JSON_OBJECT('type','string', 'required', FALSE, 'default','https://dashscope.aliyuncs.com/compatible-mode/v1'), - 'temperature', JSON_OBJECT('type','number', 'default', 0.7, 'min', 0, 'max', 2), - 'max_tokens', JSON_OBJECT('type','integer','default', 4096), - 'top_p', JSON_OBJECT('type','number', 'default', 0.8), - 'dimensions', JSON_OBJECT('type','integer','default', 1024, 'note','text-embedding-v4 维度'), - 'timeout_ms', JSON_OBJECT('type','integer','default', 60000) - ), - TRUE, 95, '2026-04-01 10:00:00', '2026-04-01 10:00:00'), - - -- 10006 OpenAI 兼容自建网关:演示一条 is_active=FALSE 的停用厂商 - (10006, 'openai', 'OpenAI 兼容网关(停用)', 'https://llm-gateway.internal.example.com/v1', - JSON_OBJECT( - 'qwen2.5-72b-instruct', JSON_ARRAY('CHAT','TOOL_CALLING'), - 'bge-large-zh-v1.5', JSON_ARRAY('EMBEDDING') - ), - JSON_OBJECT( - 'api_key', JSON_OBJECT('type','string', 'required', TRUE), - 'api_base_url', JSON_OBJECT('type','string', 'required', TRUE), - 'temperature', JSON_OBJECT('type','number', 'default', 0.7), - 'timeout_ms', JSON_OBJECT('type','integer','default', 120000) - ), - FALSE, 30, '2026-05-06 14:00:00', '2026-05-18 09:20:00'); + (10001, 'openai', 'OpenAI', 'https://api.openai.com/v1', TRUE, 85, '2026-04-01 10:00:00', '2026-04-01 10:00:00'), + (10002, 'claude', 'Anthropic Claude', 'https://api.anthropic.com/v1', TRUE, 80, '2026-04-01 10:00:00', '2026-04-01 10:00:00'), + (10003, 'glm', '智谱 GLM', 'https://open.bigmodel.cn/api/paas/v1', TRUE, 70, '2026-04-02 11:30:00', '2026-04-02 11:30:00'), + (10004, 'deepseek', 'DeepSeek', 'https://api.deepseek.com/v1', TRUE, 60, '2026-04-05 09:00:00', '2026-04-05 09:00:00'), + (10005, 'aliyun', '通义千问 Qwen', 'https://dashscope.aliyuncs.com/compatible-mode/v1', TRUE, 95, '2026-04-01 10:00:00', '2026-04-01 10:00:00'); + +INSERT INTO llm_provider_model + (id, provider_id, model_name, capability, is_active, created_at, updated_at) +VALUES + (10001, 10001, 'gpt-4o', 'CHAT', TRUE, '2026-04-01 10:00:00', '2026-04-01 10:00:00'), + (10002, 10001, 'text-embedding-3-small', 'EMBEDDING', TRUE, '2026-04-01 10:00:00', '2026-04-01 10:00:00'), + (10003, 10002, 'claude-sonnet-4-6', 'CHAT', TRUE, '2026-04-01 10:00:00', '2026-04-01 10:00:00'), + (10004, 10003, 'glm-4-plus', 'CHAT', TRUE, '2026-04-02 11:30:00', '2026-04-02 11:30:00'), + (10005, 10003, 'embedding-3', 'EMBEDDING', TRUE, '2026-04-02 11:30:00', '2026-04-02 11:30:00'), + (10006, 10004, 'deepseek-chat', 'CHAT', TRUE, '2026-04-05 09:00:00', '2026-04-05 09:00:00'), + (10007, 10005, 'qwen-plus', 'CHAT', TRUE, '2026-04-01 10:00:00', '2026-04-01 10:00:00'), + (10008, 10005, 'text-embedding-v4', 'EMBEDDING', TRUE, '2026-04-01 10:00:00', '2026-04-01 10:00:00'), + (10009, 10005, 'qwen3-vl-rerank', 'RERANK', TRUE, '2026-04-01 10:00:00', '2026-04-01 10:00:00'), + (10010, 10005, 'qwen-vl-plus', 'VISION', TRUE, '2026-04-01 10:00:00', '2026-04-01 10:00:00'), + (10011, 10005, 'qwen-vl-plus', 'OCR', TRUE, '2026-04-01 10:00:00', '2026-04-01 10:00:00'); + +INSERT INTO llm_system_preset + (id, provider_id, model_name, capability, api_key, is_active, created_at, updated_at) +VALUES + (10001, 10005, 'qwen-plus', 'CHAT', 'demo-encrypted-key-chat', TRUE, '2026-04-01 10:00:00', '2026-04-01 10:00:00'), + (10002, 10005, 'text-embedding-v4', 'EMBEDDING', 'demo-encrypted-key-embedding', TRUE, '2026-04-01 10:00:00', '2026-04-01 10:00:00'), + (10003, 10005, 'qwen3-vl-rerank', 'RERANK', 'demo-encrypted-key-rerank', TRUE, '2026-04-01 10:00:00', '2026-04-01 10:00:00'); -- =============================================== -- 3. 用户级 LLM 配置 llm_user_config --- api_key 为加密后密文(演示用占位 Fernet 风格 token) +-- api_key 为加密后密文;演示数据使用占位密文,真实种子由 Java/Python 对齐后写入。 -- =============================================== INSERT INTO llm_user_config - (id, user_id, provider_id, provider_type, provider_name, config_name, api_key, custom_api_base_url, - model_name, priority, is_active, is_default, timeout_ms, max_retries, stream_enabled, capability, extra_config, created_at, updated_at) + (id, user_id, provider_id, provider_type, api_key, api_base_url, model_name, + capability, is_active, is_default, is_system_preset, created_at, updated_at) VALUES - -- 张伟:OpenAI 对话(默认)+ OpenAI 向量化(默认) - (10001, 10002, 10001, 'openai', 'OpenAI', '我的GPT-4o对话', - 'gAAAAABm1cQ8xJ3kL5mN7oP9qZl2QyB6cR4sD8fG1hJ3kL5mN7oP9q-encrypted-key-01', NULL, - 'gpt-4o', 90, TRUE, TRUE, 60000, 3, TRUE, 'CHAT', - JSON_OBJECT('temperature', 0.7, 'top_p', 1.0), '2026-04-03 15:00:00', '2026-05-28 10:11:00'), - - (10002, 10002, 10001, 'openai', 'OpenAI', 'OpenAI向量化', - 'gAAAAABm1cQ8aD5fG8hK0lM2nP4qS6tU8wY0a7sH2kLpQ9mNvR4tYuB-encrypted-key-02', NULL, - 'text-embedding-3-small', 90, TRUE, TRUE, 30000, 3, FALSE, 'EMBEDDING', - JSON_OBJECT('dimensions', 1536), '2026-04-03 15:05:00', '2026-04-03 15:05:00'), - - -- 张伟:Claude 对话(备用,非默认) - (10003, 10002, 10002, 'claude', 'Anthropic Claude', 'Claude兜底', - 'gAAAAABm1cQ8nP4qS6tU8wY0a7sH2kLpQ9mNvR4tYuB6cXeJ1oZ3aD-encrypted-key-03', NULL, - 'claude-sonnet-4-6', 70, TRUE, FALSE, 90000, 2, TRUE, 'CHAT', - JSON_OBJECT('temperature', 1.0), '2026-04-10 09:20:00', '2026-04-10 09:20:00'), - - -- 李芳:GLM 对话(默认)+ GLM 向量化(默认) - (10004, 10003, 10003, 'glm', '智谱 GLM', '智谱GLM对话', - 'gAAAAABm1cQ8tYuB6cXeJ1oZ3aD5fG8hK0lM2nP4qS6tU8wY0a7sH2-encrypted-key-04', NULL, - 'glm-4-plus', 80, TRUE, TRUE, 60000, 3, TRUE, 'CHAT', - JSON_OBJECT('temperature', 0.6), '2026-04-08 10:00:00', '2026-05-29 14:00:00'), - - (10005, 10003, 10003, 'glm', '智谱 GLM', '智谱向量化', - 'gAAAAABm1cQ8cXeJ1oZ3aD5fG8hK0lM2nP4qS6tU8wY0a7sH2kLpQ9-encrypted-key-05', NULL, - 'embedding-3', 80, TRUE, TRUE, 30000, 3, FALSE, 'EMBEDDING', - JSON_OBJECT('dimensions', 2048), '2026-04-08 10:05:00', '2026-04-08 10:05:00'), - - -- 李芳:DeepSeek 自建网关对话(默认未启用) - (10006, 10003, 10004, 'deepseek', 'DeepSeek', 'DeepSeek私有网关', - 'gAAAAABm1cQ8B6cXeJ1oZ3aD5fG8hK0lM2nP4qS6tU8wY0a7sH2kLp-encrypted-key-06', - 'https://llm-gateway.internal.lifang.com/v1', - 'deepseek-chat', 50, FALSE, FALSE, 60000, 3, TRUE, 'CHAT', - JSON_OBJECT('temperature', 0.5), '2026-05-01 16:00:00', '2026-05-15 09:30:00'); + (10001, 10002, 10005, 'aliyun', 'demo-encrypted-key-chat', 'https://dashscope.aliyuncs.com/compatible-mode/v1', + 'qwen-plus', 'CHAT', TRUE, TRUE, TRUE, '2026-04-03 15:00:00', '2026-05-28 10:11:00'), + (10002, 10002, 10005, 'aliyun', 'demo-encrypted-key-embedding', 'https://dashscope.aliyuncs.com/compatible-mode/v1', + 'text-embedding-v4', 'EMBEDDING', TRUE, TRUE, TRUE, '2026-04-03 15:05:00', '2026-04-03 15:05:00'), + (10003, 10002, 10002, 'claude', 'demo-encrypted-key-claude', 'https://api.anthropic.com/v1', + 'claude-sonnet-4-6', 'CHAT', TRUE, FALSE, FALSE, '2026-04-10 09:20:00', '2026-04-10 09:20:00'), + (10004, 10003, 10003, 'glm', 'demo-encrypted-key-glm-chat', 'https://open.bigmodel.cn/api/paas/v1', + 'glm-4-plus', 'CHAT', TRUE, TRUE, FALSE, '2026-04-08 10:00:00', '2026-05-29 14:00:00'), + (10005, 10003, 10003, 'glm', 'demo-encrypted-key-glm-embedding', 'https://open.bigmodel.cn/api/paas/v1', + 'embedding-3', 'EMBEDDING', TRUE, TRUE, FALSE, '2026-04-08 10:05:00', '2026-04-08 10:05:00'), + (10006, 10003, 10004, 'deepseek', 'demo-encrypted-key-deepseek', 'https://llm-gateway.internal.lifang.com/v1', + 'deepseek-chat', 'CHAT', FALSE, FALSE, FALSE, '2026-05-01 16:00:00', '2026-05-15 09:30:00'); -- =============================================== -- 4. 数据集 dataset(含软删列) @@ -513,4 +401,4 @@ SET FOREIGN_KEY_CHECKS = 1; -- document_parsed_log : 5 (含 1 失败 + 1 重试链) -- document_parse_pipeline : 5 (成功 / 失败被接班 / 进行中) -- kb_document_chunk : 10 (含 sparse PENDING 与 1 个 REMOVED) --- =============================================== \ No newline at end of file +-- =============================================== diff --git a/src/api/routes/internal.py b/src/api/routes/internal.py index 9f54952..20259c6 100644 --- a/src/api/routes/internal.py +++ b/src/api/routes/internal.py @@ -2,6 +2,7 @@ 内部接口路由 供 Java 管理端查询配置和用量(不暴露给外部) """ + from typing import Optional from datetime import datetime @@ -40,8 +41,7 @@ async def get_system_providers( "provider_type": p.get("provider_type"), "provider_name": p.get("provider_name"), "api_base_url": p.get("api_base_url"), - "supported_models": p.get("supported_models", []), - "config_schema": p.get("config_schema"), + "models": p.get("models", {}), "is_active": p.get("is_active", True), } for p in providers @@ -78,17 +78,14 @@ async def get_user_configs( items = [ { "id": c.get("id"), - "config_name": c.get("config_name"), "provider_type": c.get("provider_type"), - "provider_name": c.get("provider_name"), "model_name": c.get("model_name"), + "capability": c.get("capability"), "api_key_masked": mask_api_key(c.get("api_key", "")), - "custom_api_base_url": c.get("custom_api_base_url"), - "priority": c.get("priority"), + "api_base_url": c.get("api_base_url"), "is_active": c.get("is_active"), "is_default": c.get("is_default"), - "stream_enabled": c.get("stream_enabled"), - "extra_config": c.get("extra_config"), + "is_system_preset": c.get("is_system_preset"), } for c in configs ] @@ -137,4 +134,4 @@ async def get_user_usage( ) except Exception as e: - return APIResponse(code=500, message=str(e), data=None) \ No newline at end of file + return APIResponse(code=500, message=str(e), data=None) diff --git a/src/config.py b/src/config.py index 7775028..71fe34a 100644 --- a/src/config.py +++ b/src/config.py @@ -61,7 +61,11 @@ def assemble_redis_url(cls, v: Optional[str], info) -> str: # ========================================== # 安全配置 (Security) # ========================================== - API_KEY_ENCRYPTION_SECRET: str = "default-secret" + # 64-character hex string; decoded to 32 bytes for AES-256-GCM. + # Local placeholder only; production must override it with the Java-side secret. + API_KEY_ENCRYPTION_SECRET: str = ( + "0000000000000000000000000000000000000000000000000000000000000000" + ) # ========================================== # 内部召回 API 配置 (Internal Recall API) diff --git a/src/core/llm/encryption.py b/src/core/llm/encryption.py index 5162260..6bcab62 100644 --- a/src/core/llm/encryption.py +++ b/src/core/llm/encryption.py @@ -2,12 +2,27 @@ API Key 加密工具 使用 AES-256-GCM 对用户 API Key 进行加密存储 """ + import base64 import os + from cryptography.hazmat.primitives.ciphers.aead import AESGCM + from src.config import settings +def _get_aes_key() -> bytes: + """读取 Java 对齐的 64 位 hex AES-256 key。""" + secret = settings.API_KEY_ENCRYPTION_SECRET.strip() + try: + key = bytes.fromhex(secret) + except ValueError as exc: + raise ValueError("API_KEY_ENCRYPTION_SECRET must be a 64-character hex string") from exc + if len(key) != 32: + raise ValueError("API_KEY_ENCRYPTION_SECRET must decode to 32 bytes") + return key + + def encrypt_api_key(api_key: str) -> str: """加密 API Key @@ -17,8 +32,7 @@ def encrypt_api_key(api_key: str) -> str: Returns: Base64 编码的加密字符串 (IV + ciphertext) """ - key = settings.API_KEY_ENCRYPTION_SECRET.encode()[:32].ljust(32, b'\0') - aesgcm = AESGCM(key) + aesgcm = AESGCM(_get_aes_key()) iv = os.urandom(12) # 96-bit IV for GCM ciphertext = aesgcm.encrypt(iv, api_key.encode(), None) # 拼接 IV + ciphertext @@ -35,8 +49,7 @@ def decrypt_api_key(encrypted: str) -> str: Returns: 明文 API Key """ - key = settings.API_KEY_ENCRYPTION_SECRET.encode()[:32].ljust(32, b'\0') - aesgcm = AESGCM(key) + aesgcm = AESGCM(_get_aes_key()) data = base64.b64decode(encrypted.encode()) iv = data[:12] ciphertext = data[12:] diff --git a/src/core/llm/factory.py b/src/core/llm/factory.py index 5f9a48c..9e77be8 100644 --- a/src/core/llm/factory.py +++ b/src/core/llm/factory.py @@ -2,6 +2,7 @@ ModelFactory 注册式工厂 按 Capability 分发 Provider,支持动态注册新 Provider """ + from typing import Dict, Type, Optional, Any from src.core.llm.base_provider import BaseProvider @@ -20,6 +21,7 @@ class ModelFactory: _instance: Optional["ModelFactory"] = None _providers: Dict[str, Type[BaseProvider]] = {} + _provider_aliases = {"claude": "anthropic", "aliyun": "qwen"} _default_provider_types = {"openai", "anthropic", "glm", "deepseek", "qwen"} def __new__(cls) -> "ModelFactory": @@ -53,6 +55,12 @@ def _ensure_default_provider_available(self, provider_type: str) -> None: if provider_type in self._default_provider_types and provider_type not in self._providers: self._register_default_providers() + @classmethod + def normalize_provider_type(cls, provider_type: str | None) -> str: + """归一化 Java/DB provider_type 到 Python provider 注册键。""" + raw = (provider_type or "openai").lower() + return cls._provider_aliases.get(raw, raw) + def register_provider(self, provider_type: str, provider_cls: Type[BaseProvider]) -> None: """注册 Provider @@ -79,10 +87,11 @@ def get_provider_class(self, provider_type: str) -> Type[BaseProvider]: Raises: KeyError: 如果该类型未注册 """ - self._ensure_default_provider_available(provider_type) - if provider_type not in self._providers: + normalized = self.normalize_provider_type(provider_type) + self._ensure_default_provider_available(normalized) + if normalized not in self._providers: raise KeyError(f"Provider type '{provider_type}' is not registered") - return self._providers[provider_type] + return self._providers[normalized] def create_client( self, @@ -90,7 +99,7 @@ def create_client( api_key: str, api_base_url: Optional[str] = None, model_name: Optional[str] = None, - **kwargs + **kwargs, ) -> BaseProvider: """创建 Provider 实例 @@ -104,14 +113,15 @@ def create_client( Returns: Provider 实例 """ - provider_cls = self.get_provider_class(provider_type) + normalized = self.normalize_provider_type(provider_type) + provider_cls = self.get_provider_class(normalized) return provider_cls( - provider_type=provider_type, - provider_name=provider_type, + provider_type=normalized, + provider_name=normalized, api_key=api_key, api_base_url=api_base_url, model_name=model_name, - **kwargs + **kwargs, ) def list_registered_providers(self) -> list[str]: @@ -127,16 +137,17 @@ def get_provider_info(self, provider_type: str) -> Dict[str, Any]: Returns: Provider 信息字典 """ - provider_cls = self.get_provider_class(provider_type) + normalized = self.normalize_provider_type(provider_type) + provider_cls = self.get_provider_class(normalized) # 创建临时实例获取能力信息 temp_instance = provider_cls( - provider_type=provider_type, - provider_name=provider_type, + provider_type=normalized, + provider_name=normalized, api_key="", ) return { - "type": provider_type, + "type": normalized, "capabilities": [c.value for c in temp_instance.get_capabilities()], } diff --git a/src/core/llm/user_model_resolver.py b/src/core/llm/user_model_resolver.py index ce73d14..9201e1d 100644 --- a/src/core/llm/user_model_resolver.py +++ b/src/core/llm/user_model_resolver.py @@ -13,6 +13,7 @@ 缓存策略:本期不启用 Redis 配置缓存,读配置统一 ``use_cache=False`` 直读 DB。 """ + from __future__ import annotations from dataclasses import dataclass @@ -41,6 +42,12 @@ } +def normalize_provider_type(provider_type: str | None) -> str: + """归一化 Java/DB provider_type 到 Python provider 注册键。""" + raw = (provider_type or "openai").lower() + return {"claude": "anthropic", "aliyun": "qwen"}.get(raw, raw) + + @dataclass class ResolvedModel: """一次解析的产物:可直接使用的 Provider + 元信息。""" @@ -64,7 +71,7 @@ def build_provider_from_config( Args: config: 配置字典,形如 ``ConfigReaderService`` 返回结构(含 provider_type / - api_key / custom_api_base_url / model_name;系统兜底配置带 ``is_system_fallback``)。 + api_key / api_base_url / model_name;系统兜底配置带 ``is_system_fallback``)。 capability: 能力字符串(CHAT/EMBEDDING/RERANK/VISION/OCR),用于 ``has_capability`` 校验。 fallback_model: 配置未指定 ``model_name`` 时的回退模型名。 override_model: 调用方显式指定、优先级最高的模型名(如 ``/llm`` 路由的 ``request.model``)。 @@ -86,13 +93,13 @@ def build_provider_from_config( raw_key = config.get("api_key", "") api_key = decrypt_api_key(raw_key) if raw_key else "" - provider_type = config.get("provider_type") or "openai" + provider_type = normalize_provider_type(config.get("provider_type")) model_name = override_model or config.get("model_name") or fallback_model provider = ModelFactory().create_client( provider_type=provider_type, api_key=api_key or "", - api_base_url=config.get("custom_api_base_url"), + api_base_url=config.get("api_base_url"), model_name=model_name, timeout_ms=settings.MARKDOWN_PARSER_LLM_TIMEOUT_MS, ) diff --git a/src/models/db_models.py b/src/models/db_models.py index 9642e3e..77a4620 100644 --- a/src/models/db_models.py +++ b/src/models/db_models.py @@ -2,20 +2,18 @@ SQLAlchemy ORM 模型 对应 MySQL 数据库表结构 """ + from datetime import datetime -from typing import Dict, List, Optional +from typing import List, Optional from sqlalchemy import ( BigInteger, Boolean, - Computed, DateTime, ForeignKey, Index, Integer, - JSON, String, - Text, UniqueConstraint, func, ) @@ -24,6 +22,7 @@ class Base(DeclarativeBase): """SQLAlchemy 声明式基类""" + pass @@ -32,14 +31,13 @@ class SystemProviderDB(Base): 表:llm_system_provider """ + __tablename__ = "llm_system_provider" id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True) provider_type: Mapped[str] = mapped_column(String(32), nullable=False, unique=True) provider_name: Mapped[str] = mapped_column(String(64), nullable=False) api_base_url: Mapped[str] = mapped_column(String(512), nullable=False) - supported_models: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True) - config_schema: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True) is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False) priority: Mapped[int] = mapped_column(Integer, default=50, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, default=func.now(), nullable=False) @@ -51,6 +49,82 @@ class SystemProviderDB(Base): user_configs: Mapped[List["UserLLMConfigDB"]] = relationship( "UserLLMConfigDB", back_populates="provider" ) + provider_models: Mapped[List["ProviderModelDB"]] = relationship( + "ProviderModelDB", back_populates="provider" + ) + system_presets: Mapped[List["SystemPresetDB"]] = relationship( + "SystemPresetDB", back_populates="provider" + ) + + +class ProviderModelDB(Base): + """厂商模型能力目录 + + 表:llm_provider_model + """ + + __tablename__ = "llm_provider_model" + + id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True) + provider_id: Mapped[int] = mapped_column( + BigInteger, ForeignKey("llm_system_provider.id"), nullable=False + ) + model_name: Mapped[str] = mapped_column(String(128), nullable=False) + capability: Mapped[str] = mapped_column(String(32), nullable=False) + is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, default=func.now(), nullable=False) + updated_at: Mapped[datetime] = mapped_column( + DateTime, default=func.now(), onupdate=func.now(), nullable=False + ) + + provider: Mapped["SystemProviderDB"] = relationship( + "SystemProviderDB", back_populates="provider_models" + ) + + __table_args__ = ( + UniqueConstraint( + "provider_id", + "model_name", + "capability", + name="uk_provider_model_cap", + ), + Index("idx_provider_cap", "provider_id", "capability"), + ) + + +class SystemPresetDB(Base): + """系统预设模板 + + 表:llm_system_preset + """ + + __tablename__ = "llm_system_preset" + + id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True) + provider_id: Mapped[int] = mapped_column( + BigInteger, ForeignKey("llm_system_provider.id"), nullable=False + ) + model_name: Mapped[str] = mapped_column(String(128), nullable=False) + capability: Mapped[str] = mapped_column(String(32), nullable=False) + api_key: Mapped[str] = mapped_column(String(512), nullable=False) + is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, default=func.now(), nullable=False) + updated_at: Mapped[datetime] = mapped_column( + DateTime, default=func.now(), onupdate=func.now(), nullable=False + ) + + provider: Mapped["SystemProviderDB"] = relationship( + "SystemProviderDB", back_populates="system_presets" + ) + + __table_args__ = ( + UniqueConstraint( + "provider_id", + "model_name", + "capability", + name="uk_preset_provider_model_cap", + ), + ) class UserLLMConfigDB(Base): @@ -58,10 +132,10 @@ class UserLLMConfigDB(Base): 表:llm_user_config - 新增 capability 字段支持按能力配置不同模型: - - 同一用户可以为 CHAT、EMBEDDING、RERANK 等配置不同的模型 - - is_default 在 (user_id, provider_type, capability) 范围内唯一 + 系统预设与用户自配统一汇入本表,Python 按 + (user_id, capability, is_default, is_active) 读取生效配置。 """ + __tablename__ = "llm_user_config" id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True) @@ -70,32 +144,13 @@ class UserLLMConfigDB(Base): BigInteger, ForeignKey("llm_system_provider.id"), nullable=False ) provider_type: Mapped[str] = mapped_column(String(32), nullable=False) - provider_name: Mapped[str] = mapped_column(String(64), nullable=False) - config_name: Mapped[str] = mapped_column(String(64), nullable=False) api_key: Mapped[str] = mapped_column(String(512), nullable=False) - custom_api_base_url: Mapped[Optional[str]] = mapped_column(String(512), nullable=True) + api_base_url: Mapped[Optional[str]] = mapped_column(String(512), nullable=True) model_name: Mapped[str] = mapped_column(String(128), nullable=False) - priority: Mapped[int] = mapped_column(Integer, default=50, nullable=False) + capability: Mapped[str] = mapped_column(String(32), default="CHAT", nullable=False) is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False) is_default: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) - timeout_ms: Mapped[int] = mapped_column(Integer, default=60000, nullable=False) - max_retries: Mapped[int] = mapped_column(Integer, default=3, nullable=False) - stream_enabled: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False) - extra_config: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True) - # 新增字段:主要能力类型 - capability: Mapped[str] = mapped_column(String(32), default="CHAT", nullable=False) - # 唯一约束判别列(生成列):仅当 is_default 且 is_active 时为 1,否则 NULL。 - # MySQL 唯一索引中 NULL 不算重复,故与 (user_id, provider_type, capability) 组成 - # 唯一键后,效果是「每个 (user_id, provider_type, capability) 至多一条默认且启用的配置」, - # 非默认/停用配置(值为 NULL)不受限制。应用层不应写入此列(GENERATED ALWAYS)。 - default_marker: Mapped[Optional[int]] = mapped_column( - Integer, - Computed( - "(CASE WHEN is_default = 1 AND is_active = 1 THEN 1 ELSE NULL END)", - persisted=True, - ), - nullable=True, - ) + is_system_preset: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, default=func.now(), nullable=False) updated_at: Mapped[datetime] = mapped_column( DateTime, default=func.now(), onupdate=func.now(), nullable=False @@ -105,19 +160,19 @@ class UserLLMConfigDB(Base): provider: Mapped["SystemProviderDB"] = relationship( "SystemProviderDB", back_populates="user_configs" ) - usage_logs: Mapped[List["UsageLogDB"]] = relationship( - "UsageLogDB", back_populates="config" - ) + usage_logs: Mapped[List["UsageLogDB"]] = relationship("UsageLogDB", back_populates="config") __table_args__ = ( - Index("idx_user_provider_cap", "user_id", "provider_type", "capability"), UniqueConstraint( "user_id", - "provider_type", + "provider_id", + "model_name", "capability", - "default_marker", - name="uq_user_default_per_capability", + "is_system_preset", + name="uk_user_provider_model_capability", ), + Index("idx_user_active_default", "user_id", "is_active", "is_default"), + Index("idx_user_provider_cap", "user_id", "provider_type", "capability"), ) @@ -126,6 +181,7 @@ class UsageLogDB(Base): 表:llm_usage_log """ + __tablename__ = "llm_usage_log" id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True) @@ -146,12 +202,10 @@ class UsageLogDB(Base): created_at: Mapped[datetime] = mapped_column(DateTime, default=func.now(), nullable=False) # 关系 - config: Mapped["UserLLMConfigDB"] = relationship( - "UserLLMConfigDB", back_populates="usage_logs" - ) + config: Mapped["UserLLMConfigDB"] = relationship("UserLLMConfigDB", back_populates="usage_logs") __table_args__ = ( Index("idx_user_date", "user_id", "created_at"), Index("idx_config_date", "config_id", "created_at"), Index("idx_conversation_id", "conversation_id"), - ) \ No newline at end of file + ) diff --git a/src/models/system_provider.py b/src/models/system_provider.py index ae6db9b..35d51b6 100644 --- a/src/models/system_provider.py +++ b/src/models/system_provider.py @@ -2,8 +2,8 @@ SystemProvider ORM 模型 对应 llm_system_provider 表 """ + from datetime import datetime -from typing import Dict, List, Optional from pydantic import BaseModel, Field @@ -13,15 +13,11 @@ class SystemProvider(BaseModel): 表:llm_system_provider """ + id: str = Field(..., description="厂商唯一标识 (UUID)") - provider_type: str = Field(..., description="厂商类型:openai/claude/glm/deepseek") + provider_type: str = Field(..., description="厂商类型:openai/claude/aliyun/glm/deepseek") provider_name: str = Field(..., description="厂商展示名称") api_base_url: str = Field(..., description="官方默认 API 地址") - supported_models: Dict[str, List[str]] = Field( - default_factory=dict, - description="支持模型与能力映射,如 {\"gpt-4\":[\"CHAT\",\"OCR\"]}" - ) - config_schema: Optional[dict] = Field(None, description="配置参数 Schema") is_active: bool = Field(True, description="是否启用") priority: int = Field(50, description="厂商优先级") created_at: datetime = Field(default_factory=datetime.now) diff --git a/src/models/user_llm_config.py b/src/models/user_llm_config.py index 7b697d3..292e7c0 100644 --- a/src/models/user_llm_config.py +++ b/src/models/user_llm_config.py @@ -2,8 +2,9 @@ UserLLMConfig ORM 模型 对应 llm_user_config 表 """ + from datetime import datetime -from typing import List, Optional +from typing import Optional from pydantic import BaseModel, Field @@ -13,20 +14,18 @@ class UserLLMConfig(BaseModel): 表:llm_user_config """ + id: str = Field(..., description="配置唯一标识 (UUID)") user_id: str = Field(..., description="用户 ID") provider_id: str = Field(..., description="关联 SystemProvider ID") - config_name: str = Field(..., description="用户自定义配置名称") + provider_type: str = Field(..., description="厂商类型快照") api_key: str = Field(..., description="用户提供的 API Key(加密存储)") - custom_api_base_url: Optional[str] = Field(None, description="自定义 API 地址") + api_base_url: Optional[str] = Field(None, description="实际生效 API 地址") model_name: str = Field(..., description="具体模型名") - priority: int = Field(50, ge=1, le=100, description="优先级 1-100") + capability: str = Field("CHAT", description="能力类型") is_active: bool = Field(True, description="是否启用") - is_default: bool = Field(False, description="是否为用户默认模型") - timeout_ms: int = Field(60000, description="超时时间(ms)") - max_retries: int = Field(3, description="最大重试次数") - stream_enabled: bool = Field(True, description="是否支持流式输出") - extra_config: Optional[dict] = Field(None, description="扩展配置") + is_default: bool = Field(False, description="该能力是否生效") + is_system_preset: bool = Field(False, description="是否为系统预设行") created_at: datetime = Field(default_factory=datetime.now) updated_at: datetime = Field(default_factory=datetime.now) diff --git a/src/services/config_reader_service.py b/src/services/config_reader_service.py index f2cce24..82a312a 100644 --- a/src/services/config_reader_service.py +++ b/src/services/config_reader_service.py @@ -1,9 +1,10 @@ +"""ConfigReaderService 配置读取服务。 + +Java 管理端负责写入 LLM 配置,Python 只读取运行时生效配置: +``llm_user_config`` 中 ``user_id + capability + is_default + is_active`` 命中的记录。 """ -ConfigReaderService 配置读取服务 -从数据库读取 LLM 配置,支持 Redis 缓存 -""" -import json -from typing import Any, Dict, List, Optional, Union + +from typing import Any, Dict, List, Optional from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession @@ -14,15 +15,21 @@ from src.models.db_models import SystemProviderDB, UserLLMConfigDB -def _parse_json_field(value: Union[str, dict, list, None]) -> Optional[Any]: - """解析 JSON 字段,兼容字符串和已转换的字典类型""" - if value is None: - return None - if isinstance(value, (dict, list)): - return value - if isinstance(value, str): - return json.loads(value) - return value +def _user_config_to_dict(cfg: UserLLMConfigDB) -> Dict[str, Any]: + """将 ORM 用户配置行转换为运行时配置字典。""" + return { + "id": cfg.id, + "user_id": cfg.user_id, + "provider_id": cfg.provider_id, + "provider_type": cfg.provider_type, + "api_key": cfg.api_key, + "api_base_url": cfg.api_base_url, + "model_name": cfg.model_name, + "capability": cfg.capability, + "is_active": cfg.is_active, + "is_default": cfg.is_default, + "is_system_preset": cfg.is_system_preset, + } class ConfigReaderService: @@ -57,9 +64,7 @@ def set_db(self, db: AsyncSession) -> None: """设置数据库 Session""" self._db = db - async def get_user_configs( - self, user_id: int, use_cache: bool = True - ) -> List[Dict[str, Any]]: + async def get_user_configs(self, user_id: int, use_cache: bool = True) -> List[Dict[str, Any]]: """获取用户的所有 LLM 配置 Args: @@ -83,36 +88,14 @@ async def get_user_configs( stmt = ( select(UserLLMConfigDB) - .options(selectinload(UserLLMConfigDB.provider)) .where(UserLLMConfigDB.user_id == user_id) .where(UserLLMConfigDB.is_active == True) - .order_by(UserLLMConfigDB.priority.desc()) + .order_by(UserLLMConfigDB.id.desc()) ) result = await self._db.execute(stmt) configs_db = result.scalars().all() - configs = [] - for cfg in configs_db: - provider = cfg.provider - configs.append({ - "id": cfg.id, - "user_id": cfg.user_id, - "provider_id": cfg.provider_id, - "provider_type": provider.provider_type if provider else None, - "provider_name": provider.provider_name if provider else None, - "config_name": cfg.config_name, - "api_key": cfg.api_key, # 加密存储 - "custom_api_base_url": cfg.custom_api_base_url, - "model_name": cfg.model_name, - "priority": cfg.priority, - "is_active": cfg.is_active, - "is_default": cfg.is_default, - "timeout_ms": cfg.timeout_ms, - "max_retries": cfg.max_retries, - "stream_enabled": cfg.stream_enabled, - "extra_config": _parse_json_field(cfg.extra_config), - "capability": cfg.capability, # 新增字段 - }) + configs = [_user_config_to_dict(cfg) for cfg in configs_db] # 回填缓存 if use_cache: @@ -120,81 +103,12 @@ async def get_user_configs( return configs - async def get_user_default_config( - self, user_id: int, use_cache: bool = True - ) -> Optional[Dict[str, Any]]: - """获取用户默认 LLM 配置 - - Args: - user_id: 用户 ID - use_cache: 是否使用缓存 - - Returns: - 默认配置,未设置则返回 None - """ - cache_key = self._cache.user_default_key(str(user_id)) - - # 先查缓存 - if use_cache: - cached = await self._cache.get(cache_key) - if cached is not None: - return cached - - # 缓存未命中,从数据库查询 - if self._db is None: - return None - - # 默认配置理论上 (user_id) 维度唯一,但 schema 未强制;用 order_by + limit(1) - # 确定性取一条(priority 高者优先),避免脏数据下 scalar_one_or_none 抛 - # MultipleResultsFound 被上层误判为「读取失败(可重试)」。 - stmt = ( - select(UserLLMConfigDB) - .options(selectinload(UserLLMConfigDB.provider)) - .where(UserLLMConfigDB.user_id == user_id) - .where(UserLLMConfigDB.is_default == True) - .where(UserLLMConfigDB.is_active == True) - .order_by(UserLLMConfigDB.priority.desc(), UserLLMConfigDB.id.desc()) - .limit(1) - ) - result = await self._db.execute(stmt) - cfg = result.scalars().first() - - if cfg is None: - return None - - provider = cfg.provider - config = { - "id": cfg.id, - "user_id": cfg.user_id, - "provider_id": cfg.provider_id, - "provider_type": provider.provider_type if provider else None, - "provider_name": provider.provider_name if provider else None, - "config_name": cfg.config_name, - "api_key": cfg.api_key, - "custom_api_base_url": cfg.custom_api_base_url, - "model_name": cfg.model_name, - "priority": cfg.priority, - "is_active": cfg.is_active, - "is_default": cfg.is_default, - "timeout_ms": cfg.timeout_ms, - "max_retries": cfg.max_retries, - "stream_enabled": cfg.stream_enabled, - "extra_config": _parse_json_field(cfg.extra_config), - "capability": cfg.capability, # 新增字段 - } - - # 回填缓存 - if use_cache: - await self._cache.set(cache_key, config) - - return config - async def get_user_default_config_by_capability( self, user_id: int, capability: str, provider_type: Optional[str] = None, - use_cache: bool = True + use_cache: bool = True, ) -> Optional[Dict[str, Any]]: """获取用户指定能力的默认 LLM 配置 @@ -207,7 +121,8 @@ async def get_user_default_config_by_capability( Returns: 该能力的默认配置,未设置则返回 None """ - cache_key = f"llm:user:{user_id}:default:{capability}" + capability_upper = capability.upper() + cache_key = f"llm:user:{user_id}:default:{capability_upper}" if provider_type: cache_key = f"{cache_key}:{provider_type}" @@ -221,49 +136,27 @@ async def get_user_default_config_by_capability( if self._db is None: return None - # 同 get_user_default_config:(user_id, provider_type, capability) 默认唯一靠 schema - # 约束保证;为防约束缺位/脏数据,查询侧用 order_by + limit(1) 确定性取一条, + # 默认配置的业务唯一性由 Java 管理端保证;为防脏数据,查询侧用 limit(1) + # 确定性取最新一条, # 不让 MultipleResultsFound 冒泡成「读取失败(可重试)」误判。 stmt = ( select(UserLLMConfigDB) - .options(selectinload(UserLLMConfigDB.provider)) .where(UserLLMConfigDB.user_id == user_id) - .where(UserLLMConfigDB.capability == capability.upper()) + .where(UserLLMConfigDB.capability == capability_upper) .where(UserLLMConfigDB.is_default == True) .where(UserLLMConfigDB.is_active == True) ) if provider_type: stmt = stmt.where(UserLLMConfigDB.provider_type == provider_type) - stmt = stmt.order_by( - UserLLMConfigDB.priority.desc(), UserLLMConfigDB.id.desc() - ).limit(1) + stmt = stmt.order_by(UserLLMConfigDB.id.desc()).limit(1) result = await self._db.execute(stmt) cfg = result.scalars().first() if cfg is None: return None - provider = cfg.provider - config = { - "id": cfg.id, - "user_id": cfg.user_id, - "provider_id": cfg.provider_id, - "provider_type": provider.provider_type if provider else None, - "provider_name": provider.provider_name if provider else None, - "config_name": cfg.config_name, - "api_key": cfg.api_key, - "custom_api_base_url": cfg.custom_api_base_url, - "model_name": cfg.model_name, - "priority": cfg.priority, - "is_active": cfg.is_active, - "is_default": cfg.is_default, - "timeout_ms": cfg.timeout_ms, - "max_retries": cfg.max_retries, - "stream_enabled": cfg.stream_enabled, - "extra_config": _parse_json_field(cfg.extra_config), - "capability": cfg.capability, - } + config = _user_config_to_dict(cfg) # 回填缓存 if use_cache: @@ -298,7 +191,6 @@ async def get_user_config_by_id( stmt = ( select(UserLLMConfigDB) - .options(selectinload(UserLLMConfigDB.provider)) .where(UserLLMConfigDB.id == config_id) .where(UserLLMConfigDB.user_id == user_id) .where(UserLLMConfigDB.is_active == True) @@ -309,26 +201,7 @@ async def get_user_config_by_id( if cfg is None: return None - provider = cfg.provider - config = { - "id": cfg.id, - "user_id": cfg.user_id, - "provider_id": cfg.provider_id, - "provider_type": provider.provider_type if provider else None, - "provider_name": provider.provider_name if provider else None, - "config_name": cfg.config_name, - "api_key": cfg.api_key, - "custom_api_base_url": cfg.custom_api_base_url, - "model_name": cfg.model_name, - "priority": cfg.priority, - "is_active": cfg.is_active, - "is_default": cfg.is_default, - "timeout_ms": cfg.timeout_ms, - "max_retries": cfg.max_retries, - "stream_enabled": cfg.stream_enabled, - "extra_config": _parse_json_field(cfg.extra_config), - "capability": cfg.capability, # 新增字段 - } + config = _user_config_to_dict(cfg) # 回填缓存 if use_cache: @@ -337,10 +210,7 @@ async def get_user_config_by_id( return config async def get_user_configs_by_capability( - self, - user_id: int, - capability: str, - use_cache: bool = True + self, user_id: int, capability: str, use_cache: bool = True ) -> List[Dict[str, Any]]: """获取用户指定能力的所有配置 @@ -352,7 +222,8 @@ async def get_user_configs_by_capability( Returns: 该能力的所有配置列表 """ - cache_key = f"llm:user:{user_id}:configs:{capability}" + capability_upper = capability.upper() + cache_key = f"llm:user:{user_id}:configs:{capability_upper}" # 先查缓存 if use_cache: @@ -366,37 +237,15 @@ async def get_user_configs_by_capability( stmt = ( select(UserLLMConfigDB) - .options(selectinload(UserLLMConfigDB.provider)) .where(UserLLMConfigDB.user_id == user_id) - .where(UserLLMConfigDB.capability == capability.upper()) + .where(UserLLMConfigDB.capability == capability_upper) .where(UserLLMConfigDB.is_active == True) - .order_by(UserLLMConfigDB.priority.desc()) + .order_by(UserLLMConfigDB.id.desc()) ) result = await self._db.execute(stmt) configs_db = result.scalars().all() - configs = [] - for cfg in configs_db: - provider = cfg.provider - configs.append({ - "id": cfg.id, - "user_id": cfg.user_id, - "provider_id": cfg.provider_id, - "provider_type": provider.provider_type if provider else None, - "provider_name": provider.provider_name if provider else None, - "config_name": cfg.config_name, - "api_key": cfg.api_key, - "custom_api_base_url": cfg.custom_api_base_url, - "model_name": cfg.model_name, - "priority": cfg.priority, - "is_active": cfg.is_active, - "is_default": cfg.is_default, - "timeout_ms": cfg.timeout_ms, - "max_retries": cfg.max_retries, - "stream_enabled": cfg.stream_enabled, - "extra_config": _parse_json_field(cfg.extra_config), - "capability": cfg.capability, - }) + configs = [_user_config_to_dict(cfg) for cfg in configs_db] # 回填缓存 if use_cache: @@ -431,7 +280,11 @@ async def get_system_providers( if self._db is None: return [] - stmt = select(SystemProviderDB).where(SystemProviderDB.is_active == True) + stmt = ( + select(SystemProviderDB) + .options(selectinload(SystemProviderDB.provider_models)) + .where(SystemProviderDB.is_active == True) + ) if provider_type: stmt = stmt.where(SystemProviderDB.provider_type == provider_type) stmt = stmt.order_by(SystemProviderDB.priority.desc()) @@ -441,16 +294,22 @@ async def get_system_providers( providers = [] for p in providers_db: - providers.append({ - "id": p.id, - "provider_type": p.provider_type, - "provider_name": p.provider_name, - "api_base_url": p.api_base_url, - "supported_models": _parse_json_field(p.supported_models) or {}, - "config_schema": _parse_json_field(p.config_schema), - "is_active": p.is_active, - "priority": p.priority, - }) + models: Dict[str, List[str]] = {} + for model in p.provider_models: + if not model.is_active: + continue + models.setdefault(model.model_name, []).append(model.capability) + providers.append( + { + "id": p.id, + "provider_type": p.provider_type, + "provider_name": p.provider_name, + "api_base_url": p.api_base_url, + "models": models, + "is_active": p.is_active, + "priority": p.priority, + } + ) # 回填缓存 if use_cache: @@ -470,7 +329,9 @@ async def get_system_provider_by_type( Returns: 厂商详情 """ - providers = await self.get_system_providers(provider_type=provider_type, use_cache=use_cache) + providers = await self.get_system_providers( + provider_type=provider_type, use_cache=use_cache + ) return providers[0] if providers else None async def clear_cache(self, user_id: Optional[str] = None) -> None: @@ -501,10 +362,10 @@ async def decrypt_api_key(self, encrypted_key: str) -> str: def get_system_fallback_config_by_capability(self, capability: str) -> Optional[Dict[str, Any]]: """获取从系统环境变量中读取的兜底 LLM 配置""" from src.config import settings - + if not settings.SYSTEM_LLM_API_KEY: return None - + model_name = None cap_upper = capability.upper() if cap_upper == "CHAT": @@ -515,27 +376,21 @@ def get_system_fallback_config_by_capability(self, capability: str) -> Optional[ model_name = settings.SYSTEM_LLM_MODEL_RERANK elif cap_upper in ["VISION", "OCR"]: model_name = settings.SYSTEM_LLM_MODEL_VISION - + if not model_name: return None - + return { "id": "system-default", "user_id": "system", "provider_id": "system", "provider_type": settings.SYSTEM_LLM_PROVIDER, - "provider_name": "System Default", - "config_name": f"System Default {cap_upper}", "api_key": settings.SYSTEM_LLM_API_KEY, - "custom_api_base_url": settings.SYSTEM_LLM_API_BASE, + "api_base_url": settings.SYSTEM_LLM_API_BASE, "model_name": model_name, - "priority": 0, "is_active": True, "is_default": True, - "timeout_ms": 60000, - "max_retries": 3, - "stream_enabled": True, - "extra_config": {}, "capability": cap_upper, - "is_system_fallback": True, # 特殊标识,免于解密 - } \ No newline at end of file + "is_system_preset": False, + "is_system_fallback": True, # 特殊标识,免于解密 + } diff --git a/tests/integration/core/llm/test_system_fallback_integration.py b/tests/integration/core/llm/test_system_fallback_integration.py index 0f4231c..00d3570 100644 --- a/tests/integration/core/llm/test_system_fallback_integration.py +++ b/tests/integration/core/llm/test_system_fallback_integration.py @@ -3,26 +3,29 @@ from src.core.llm.factory import ModelFactory from src.core.llm.interfaces import ITextGenerator, IEmbedder + @pytest.fixture def config_service(): # 测试兜底不需要真正的 DB 连接实例 return ConfigReaderService(db=None) + @pytest.fixture def factory(): return ModelFactory() + @pytest.mark.asyncio async def test_system_fallback_chat(config_service, factory): """测试系统兜底配置加载与 Qwen CHAT 调用""" config = config_service.get_system_fallback_config_by_capability("CHAT") - + assert config is not None, "兜底配置缺失,请检查环境变量(SYSTEM_LLM_API_KEY)" assert config["is_system_fallback"] is True assert config["capability"] == "CHAT" assert config["provider_type"] == "qwen" assert config["model_name"] == "qwen3.5-flash" - + if not config["api_key"] or config["api_key"] == "your_qwen_api_key_here": pytest.skip("检测到使用的是占位 API_KEY,跳过真实的网络请求阶段") @@ -30,45 +33,48 @@ async def test_system_fallback_chat(config_service, factory): client = factory.create_client( provider_type=config["provider_type"], api_key=config["api_key"], - api_base_url=config["custom_api_base_url"], - model_name=config["model_name"] + api_base_url=config["api_base_url"], + model_name=config["model_name"], ) - + # [真实网络请求] 发起生成调用 print(f"\n>>> 发起 CHAT 请求 (Model: {config['model_name']})") result = await client.generate(prompt="你好,请只用 10 个字介绍你自己。") - + print(f"\n<<< 收到响应: {result.content}") print(f"<<< 用量统计: {result.usage}") assert result.content is not None assert result.usage.total_tokens > 0 + @pytest.mark.asyncio async def test_system_fallback_embedding(config_service, factory): """测试系统兜底配置加载与 Qwen EMBEDDING 调用""" config = config_service.get_system_fallback_config_by_capability("EMBEDDING") - + assert config is not None assert config["capability"] == "EMBEDDING" - + if not config["api_key"] or config["api_key"] == "your_qwen_api_key_here": pytest.skip("检测到使用的是占位 API_KEY,跳过真实的网络请求阶段") client = factory.create_client( provider_type=config["provider_type"], api_key=config["api_key"], - api_base_url=config["custom_api_base_url"], - model_name=config["model_name"] + api_base_url=config["api_base_url"], + model_name=config["model_name"], ) - + # [真实网络请求] 发起向量化调用 print(f"\n>>> 发起 EMBEDDING 请求 (Model: {config['model_name']})") - result = await client.embed(texts=["这是一个用于测试的自然语言问句,请转化为向量分布"], model=config["model_name"]) - + result = await client.embed( + texts=["这是一个用于测试的自然语言问句,请转化为向量分布"], model=config["model_name"] + ) + vectors = result.embeddings print(f"\n<<< 收到响应: 共 {len(vectors)} 条向量") print(f"<<< 向量唯度: {len(vectors[0])} d") print(f"<<< 用量统计: {result.usage}") - + assert len(vectors) == 1 assert len(vectors[0]) > 0 # 例如通常是 1024, 1536 等 diff --git a/tests/integration/services/test_config_reader_integration.py b/tests/integration/services/test_config_reader_integration.py index 6eae3cd..2d502c1 100644 --- a/tests/integration/services/test_config_reader_integration.py +++ b/tests/integration/services/test_config_reader_integration.py @@ -6,8 +6,8 @@ - 测试时注入 NullCacheBackend,不依赖 Redis - 生产时使用 RedisCacheBackend """ + import asyncio -import json import time import uuid import pytest @@ -26,6 +26,7 @@ def create_unique_user_id(): """生成唯一的用户 ID""" import random + return int(f"2{random.randint(100000000, 999999999)}") # 以2开头的10位数字 @@ -60,6 +61,7 @@ def create_unique_provider_type(): def reset_db_engine(): """每个测试前重置数据库引擎连接池""" import src.database as db_module + if db_module._async_engine is not None: try: loop = asyncio.get_event_loop() @@ -118,43 +120,55 @@ async def setup_test_data(self, db_session: AsyncSession): try: with conn.cursor() as cursor: # 插入测试 SystemProvider - cursor.execute(""" + cursor.execute( + """ INSERT INTO llm_system_provider - (provider_type, provider_name, api_base_url, supported_models, config_schema, is_active, priority) - VALUES (%s, %s, %s, %s, %s, %s, %s) - """, ( - provider_type, - "OpenAI Test", - "https://api.openai.com/v1", - json.dumps({"gpt-4": ["CHAT", "OCR"], "gpt-3.5-turbo": ["CHAT"]}), - json.dumps({"temperature": {"type": "float", "default": 0.7}}), - 1, - 100 - )) + (provider_type, provider_name, api_base_url, is_active, priority) + VALUES (%s, %s, %s, %s, %s) + """, + (provider_type, "OpenAI Test", "https://api.openai.com/v1", 1, 100), + ) test_ids["provider_id"] = cursor.lastrowid test_ids["provider_type"] = provider_type - # 插入测试 UserLLMConfig(带 capability 字段) - cursor.execute(""" + cursor.execute( + """ + INSERT INTO llm_provider_model + (provider_id, model_name, capability, is_active) + VALUES (%s, %s, %s, %s), (%s, %s, %s, %s) + """, + ( + test_ids["provider_id"], + "gpt-4", + "CHAT", + 1, + test_ids["provider_id"], + "gpt-4", + "OCR", + 1, + ), + ) + + # 插入测试 UserLLMConfig(Java 写入后的运行时生效结构) + cursor.execute( + """ INSERT INTO llm_user_config - (user_id, provider_id, provider_type, provider_name, config_name, api_key, model_name, priority, is_active, is_default, timeout_ms, max_retries, stream_enabled, capability) - VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) - """, ( - test_user_id, - test_ids["provider_id"], - provider_type, - "OpenAI Test", - "Test GPT-4 Config", - "encrypted_test_key", - "gpt-4", - 50, - 1, - 1, - 60000, - 3, - 1, - "CHAT" # 新增 capability 字段 - )) + (user_id, provider_id, provider_type, api_key, api_base_url, model_name, capability, is_active, is_default, is_system_preset) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s) + """, + ( + test_user_id, + test_ids["provider_id"], + provider_type, + "encrypted_test_key", + "https://api.openai.com/v1", + "gpt-4", + "CHAT", + 1, + 1, + 0, + ), + ) test_ids["config_id"] = cursor.lastrowid test_ids["user_id"] = test_user_id finally: @@ -168,14 +182,18 @@ async def setup_test_data(self, db_session: AsyncSession): try: with conn.cursor() as cursor: cursor.execute(f"DELETE FROM llm_user_config WHERE id = {test_ids['config_id']}") - cursor.execute(f"DELETE FROM llm_system_provider WHERE id = {test_ids['provider_id']}") + cursor.execute( + f"DELETE FROM llm_provider_model WHERE provider_id = {test_ids['provider_id']}" + ) + cursor.execute( + f"DELETE FROM llm_system_provider WHERE id = {test_ids['provider_id']}" + ) finally: conn.close() @pytest_asyncio.fixture async def setup_multi_capability_test_data(self, db_session: AsyncSession): """创建多种 capability 的测试数据""" - import random provider_type1 = create_unique_provider_type() provider_type2 = f"anthropic_test_{uuid.uuid4().hex[:8]}" test_user_id = create_unique_user_id() @@ -185,101 +203,115 @@ async def setup_multi_capability_test_data(self, db_session: AsyncSession): try: with conn.cursor() as cursor: # 插入两个测试 SystemProvider - cursor.execute(""" + cursor.execute( + """ INSERT INTO llm_system_provider - (provider_type, provider_name, api_base_url, supported_models, is_active, priority) - VALUES (%s, %s, %s, %s, %s, %s) - """, ( - provider_type1, - "OpenAI Test", - "https://api.openai.com/v1", - json.dumps({"gpt-4": ["CHAT"], "text-embedding-3": ["EMBEDDING"]}), - 1, - 100 - )) + (provider_type, provider_name, api_base_url, is_active, priority) + VALUES (%s, %s, %s, %s, %s) + """, + (provider_type1, "OpenAI Test", "https://api.openai.com/v1", 1, 100), + ) provider_id1 = cursor.lastrowid - cursor.execute(""" + cursor.execute( + """ INSERT INTO llm_system_provider - (provider_type, provider_name, api_base_url, supported_models, is_active, priority) - VALUES (%s, %s, %s, %s, %s, %s) - """, ( - provider_type2, - "Anthropic Test", - "https://api.anthropic.com", - json.dumps({"claude-3": ["CHAT", "VISION"]}), - 1, - 90 - )) + (provider_type, provider_name, api_base_url, is_active, priority) + VALUES (%s, %s, %s, %s, %s) + """, + (provider_type2, "Anthropic Test", "https://api.anthropic.com", 1, 90), + ) provider_id2 = cursor.lastrowid + cursor.execute( + """ + INSERT INTO llm_provider_model + (provider_id, model_name, capability, is_active) + VALUES + (%s, %s, %s, %s), + (%s, %s, %s, %s), + (%s, %s, %s, %s) + """, + ( + provider_id1, + "gpt-4", + "CHAT", + 1, + provider_id1, + "text-embedding-3", + "EMBEDDING", + 1, + provider_id2, + "claude-3-rerank", + "RERANK", + 1, + ), + ) + # 插入 CHAT 配置(默认) - cursor.execute(""" + cursor.execute( + """ INSERT INTO llm_user_config - (user_id, provider_id, provider_type, provider_name, config_name, api_key, model_name, priority, is_active, is_default, timeout_ms, max_retries, stream_enabled, capability) - VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) - """, ( - test_user_id, - provider_id1, - provider_type1, - "OpenAI Test", - "Chat Config", - "encrypted_key_1", - "gpt-4", - 50, - 1, - 1, # is_default - 60000, - 3, - 1, - "CHAT" - )) + (user_id, provider_id, provider_type, api_key, api_base_url, model_name, capability, is_active, is_default, is_system_preset) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s) + """, + ( + test_user_id, + provider_id1, + provider_type1, + "encrypted_key_1", + "https://api.openai.com/v1", + "gpt-4", + "CHAT", + 1, + 1, # is_default + 0, + ), + ) chat_config_id = cursor.lastrowid # 插入 EMBEDDING 配置 - cursor.execute(""" + cursor.execute( + """ INSERT INTO llm_user_config - (user_id, provider_id, provider_type, provider_name, config_name, api_key, model_name, priority, is_active, is_default, timeout_ms, max_retries, stream_enabled, capability) - VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) - """, ( - test_user_id, - provider_id1, - provider_type1, - "OpenAI Test", - "Embedding Config", - "encrypted_key_2", - "text-embedding-3", - 40, - 1, - 1, # is_default for EMBEDDING - 60000, - 3, - 1, - "EMBEDDING" - )) + (user_id, provider_id, provider_type, api_key, api_base_url, model_name, capability, is_active, is_default, is_system_preset) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s) + """, + ( + test_user_id, + provider_id1, + provider_type1, + "encrypted_key_2", + "https://api.openai.com/v1", + "text-embedding-3", + "EMBEDDING", + 1, + 1, # is_default for EMBEDDING + 0, + ), + ) embedding_config_id = cursor.lastrowid # 插入 RERANK 配置 - cursor.execute(""" + cursor.execute( + """ INSERT INTO llm_user_config - (user_id, provider_id, provider_type, provider_name, config_name, api_key, model_name, priority, is_active, is_default, timeout_ms, max_retries, stream_enabled, capability) - VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) - """, ( - test_user_id, - provider_id2, - provider_type2, - "Anthropic Test", - "Rerank Config", - "encrypted_key_3", - "claude-3-rerank", - 30, - 1, - 1, # is_default for RERANK - 60000, - 3, - 1, - "RERANK" - )) + (user_id, provider_id, provider_type, api_key, api_base_url, model_name, capability, is_active, is_default, is_system_preset) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s) + """, + ( + test_user_id, + provider_id2, + provider_type2, + "encrypted_key_3", + "https://api.anthropic.com", + "claude-3-rerank", + "RERANK", + 1, + 1, # is_default for RERANK + 0, + ), + ) rerank_config_id = cursor.lastrowid test_ids = { @@ -303,7 +335,12 @@ async def setup_multi_capability_test_data(self, db_session: AsyncSession): try: with conn.cursor() as cursor: cursor.execute(f"DELETE FROM llm_user_config WHERE user_id = {test_user_id}") - cursor.execute(f"DELETE FROM llm_system_provider WHERE id IN ({provider_id1}, {provider_id2})") + cursor.execute( + f"DELETE FROM llm_provider_model WHERE provider_id IN ({provider_id1}, {provider_id2})" + ) + cursor.execute( + f"DELETE FROM llm_system_provider WHERE id IN ({provider_id1}, {provider_id2})" + ) finally: conn.close() @@ -338,23 +375,21 @@ async def test_GetSystemProviders_FilterByType_Should_Return_Filtered_Providers( assert p["provider_type"] == provider_type @pytest.mark.asyncio - async def test_GetSystemProviders_SupportedModels_Should_Be_Parsed_Correctly( + async def test_GetSystemProviders_Models_Should_Be_Aggregated_Correctly( self, service: ConfigReaderService, setup_test_data ): - """get_system_providers 返回的 supported_models 应正确解析为 dict""" + """get_system_providers 返回的 models 应按模型聚合能力""" providers = await service.get_system_providers() provider_type = setup_test_data["provider_type"] - test_provider = next( - (p for p in providers if p["provider_type"] == provider_type), None - ) + test_provider = next((p for p in providers if p["provider_type"] == provider_type), None) assert test_provider is not None - supported_models = test_provider["supported_models"] - assert isinstance(supported_models, dict) - assert "gpt-4" in supported_models - assert "CHAT" in supported_models["gpt-4"] - assert "OCR" in supported_models["gpt-4"] + models = test_provider["models"] + assert isinstance(models, dict) + assert "gpt-4" in models + assert "CHAT" in models["gpt-4"] + assert "OCR" in models["gpt-4"] @pytest.mark.asyncio async def test_GetSystemProviderByType_Should_Return_Single_Provider( @@ -387,9 +422,7 @@ async def test_GetUserConfigs_Should_Return_User_Configs( assert isinstance(configs, list) assert len(configs) > 0 - test_config = next( - (c for c in configs if c["id"] == setup_test_data["config_id"]), None - ) + test_config = next((c for c in configs if c["id"] == setup_test_data["config_id"]), None) assert test_config is not None, f"测试配置 {setup_test_data['config_id']} 未找到" assert test_config["user_id"] == user_id assert test_config["provider_id"] == setup_test_data["provider_id"] @@ -397,12 +430,12 @@ async def test_GetUserConfigs_Should_Return_User_Configs( assert test_config["is_default"] is True @pytest.mark.asyncio - async def test_GetUserDefaultConfig_Should_Return_Default_Config( + async def test_GetUserDefaultConfigByCapability_Should_Return_Default_Config( self, service: ConfigReaderService, setup_test_data ): - """get_user_default_config 应返回用户的默认配置""" + """get_user_default_config_by_capability 应返回用户该能力默认配置""" user_id = setup_test_data["user_id"] - config = await service.get_user_default_config(user_id) + config = await service.get_user_default_config_by_capability(user_id, "CHAT") assert config is not None assert config["user_id"] == user_id @@ -485,7 +518,10 @@ async def test_GetUserDefaultConfigByCapability_WithProviderType_Should_Filter( user_id, "CHAT", provider_type=setup_multi_capability_test_data["provider_type1"] ) assert config_correct_provider is not None - assert config_correct_provider["provider_type"] == setup_multi_capability_test_data["provider_type1"] + assert ( + config_correct_provider["provider_type"] + == setup_multi_capability_test_data["provider_type1"] + ) @pytest.mark.asyncio async def test_GetUserDefaultConfigByCapability_NonExistent_Should_Return_None( @@ -523,15 +559,15 @@ async def test_GetUserConfigsByCapability_Should_Return_All_Matching_Configs( assert vision_configs == [] @pytest.mark.asyncio - async def test_GetUserConfigsByCapability_Should_Order_By_Priority( + async def test_GetUserConfigsByCapability_Should_Order_By_Id_Desc( self, service: ConfigReaderService, setup_multi_capability_test_data: dict ): - """get_user_configs_by_capability 返回的配置应按优先级降序排列""" + """get_user_configs_by_capability 返回的配置应按 id 降序排列""" user_id = setup_multi_capability_test_data["user_id"] configs = await service.get_user_configs_by_capability(user_id, "CHAT") - # 验证按优先级降序排列 + # priority 已由 Java 侧配置模型移除;读取侧只保证稳定的 id 降序。 if len(configs) > 1: for i in range(len(configs) - 1): - assert configs[i]["priority"] >= configs[i + 1]["priority"] + assert configs[i]["id"] >= configs[i + 1]["id"] diff --git a/tests/unit/core/llm/test_encryption.py b/tests/unit/core/llm/test_encryption.py new file mode 100644 index 0000000..6fe65c8 --- /dev/null +++ b/tests/unit/core/llm/test_encryption.py @@ -0,0 +1,20 @@ +import pytest + +from src.config import settings +from src.core.llm.encryption import decrypt_api_key, encrypt_api_key + + +def test_encrypt_decrypt_api_key_with_java_compatible_hex_secret(monkeypatch): + monkeypatch.setattr(settings, "API_KEY_ENCRYPTION_SECRET", "01" * 32) + + encrypted = encrypt_api_key("sk-test-value") + + assert encrypted != "sk-test-value" + assert decrypt_api_key(encrypted) == "sk-test-value" + + +def test_encryption_secret_must_be_64_hex_chars(monkeypatch): + monkeypatch.setattr(settings, "API_KEY_ENCRYPTION_SECRET", "default-secret") + + with pytest.raises(ValueError, match="64-character hex"): + encrypt_api_key("sk-test-value") diff --git a/tests/unit/core/llm/test_user_model_resolver.py b/tests/unit/core/llm/test_user_model_resolver.py index c018dc7..3a4a2b7 100644 --- a/tests/unit/core/llm/test_user_model_resolver.py +++ b/tests/unit/core/llm/test_user_model_resolver.py @@ -10,6 +10,7 @@ - 能力不支持 → ValueError;能力字符串未知 → ValueError; - override_model 优先级最高。 """ + from __future__ import annotations from unittest.mock import MagicMock @@ -18,6 +19,7 @@ import src.core.llm.user_model_resolver as umr from src.core.llm.exceptions import UserModelConfigMissingError +from src.core.llm.factory import ModelFactory from src.core.llm.interfaces import CapabilityType from src.core.llm.user_model_resolver import ( aresolve_user_model, @@ -65,7 +67,7 @@ def test_build_provider_from_config_user_decrypts(monkeypatch): { "provider_type": "qwen", "api_key": "ENC", - "custom_api_base_url": "https://u/v1", + "api_base_url": "https://u/v1", "model_name": "m-user", }, capability="CHAT", @@ -77,6 +79,31 @@ def test_build_provider_from_config_user_decrypts(monkeypatch): provider.has_capability.assert_called_with(CapabilityType.TEXT) +def test_build_provider_normalizes_java_provider_type_alias(monkeypatch): + captured, _ = _patch_factory(monkeypatch) + rm = build_provider_from_config( + { + "provider_type": "aliyun", + "api_key": "ENC", + "api_base_url": "https://dashscope.example/v1", + "model_name": "qwen-plus", + }, + capability="CHAT", + ) + assert rm.provider_type == "qwen" + assert captured["provider_type"] == "qwen" + + +def test_model_factory_normalizes_provider_type_aliases(): + factory = ModelFactory() + + qwen_client = factory.create_client(provider_type="aliyun", api_key="k") + anthropic_client = factory.create_client(provider_type="claude", api_key="k") + + assert qwen_client.provider_type == "qwen" + assert anthropic_client.provider_type == "anthropic" + + def test_build_provider_from_config_system_fallback_skips_decrypt(monkeypatch): captured, _ = _patch_factory(monkeypatch) rm = build_provider_from_config( diff --git a/tests/unit/core/markdown_parser/test_provider_resolution.py b/tests/unit/core/markdown_parser/test_provider_resolution.py index ac5028d..cf493dc 100644 --- a/tests/unit/core/markdown_parser/test_provider_resolution.py +++ b/tests/unit/core/markdown_parser/test_provider_resolution.py @@ -84,7 +84,7 @@ async def test_table_client_uses_user_chat_config(monkeypatch): config={ "provider_type": "qwen", "api_key": "enc-key", - "custom_api_base_url": "https://user.example.com/v1", + "api_base_url": "https://user.example.com/v1", "model_name": "qwen-max", }, ) @@ -109,7 +109,7 @@ async def test_system_fallback_config_skips_decrypt(monkeypatch): config={ "provider_type": "openai", "api_key": "plain-key", - "custom_api_base_url": None, + "api_base_url": None, "model_name": "gpt-x", "is_system_fallback": True, }, diff --git a/tests/unit/core/splitter/test_user_embedding_resolution.py b/tests/unit/core/splitter/test_user_embedding_resolution.py index 5159640..97801d9 100644 --- a/tests/unit/core/splitter/test_user_embedding_resolution.py +++ b/tests/unit/core/splitter/test_user_embedding_resolution.py @@ -40,9 +40,7 @@ def has_capability(self, _cap): def _patch_session_factory(monkeypatch): - monkeypatch.setattr( - "src.database.get_async_session_factory", lambda: _FakeSessionFactory() - ) + monkeypatch.setattr("src.database.get_async_session_factory", lambda: _FakeSessionFactory()) def _patch_config_reader(monkeypatch, *, config): @@ -94,7 +92,7 @@ async def test_resolve_user_embedding_client_uses_user_config(monkeypatch): config={ "provider_type": "qwen", "api_key": "ENC", - "custom_api_base_url": "https://user.example/v1", + "api_base_url": "https://user.example/v1", "model_name": "user-embed-model", }, ) @@ -121,7 +119,7 @@ async def test_resolve_user_chunk_embedding_pipeline_uses_user_model_and_batch_c config={ "provider_type": "qwen", "api_key": "ENC", - "custom_api_base_url": None, + "api_base_url": None, # DashScope text-embedding-v4 已知单次上限 10 "model_name": "text-embedding-v4", },