diff --git a/.gitignore b/.gitignore index e73a0a9b8..aa0280a2d 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ # https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore # # Binaries for programs and plugins +bin/ *.exe *.exe~ *.dll diff --git a/Dockerfile.agent b/Dockerfile.agent index 81ac00366..33da774f0 100644 --- a/Dockerfile.agent +++ b/Dockerfile.agent @@ -1,5 +1,5 @@ # 基于已构建的基础镜像 -FROM crpi-6pj79y7ddzdpexs8.cn-hangzhou.personal.cr.aliyuncs.com/wanwulite/agent-base:v0.2.2-0773787a +FROM wanwulite/agent-base:v0.3.3 # 工作目录 WORKDIR /agent/agent_open_source diff --git a/Dockerfile.backend b/Dockerfile.backend index 0df8f2580..a0fffa379 100644 --- a/Dockerfile.backend +++ b/Dockerfile.backend @@ -1,69 +1,29 @@ -ARG WANWU_ARCH - # --- 第一阶段:构建阶段 --- -FROM --platform=linux/$WANWU_ARCH golang:1.24.6-bookworm AS builder -ARG WANWU_ARCH -WORKDIR /app -COPY . . - -RUN make build-tidb-setup-$WANWU_ARCH -RUN make build-bff-$WANWU_ARCH -RUN make build-iam-$WANWU_ARCH -RUN make build-model-$WANWU_ARCH -RUN make build-mcp-$WANWU_ARCH -RUN make build-knowledge-$WANWU_ARCH -RUN make build-rag-$WANWU_ARCH -RUN make build-assistant-$WANWU_ARCH -RUN make build-agent-$WANWU_ARCH -RUN make build-app-$WANWU_ARCH -RUN make build-operate-$WANWU_ARCH - -# --- 第二阶段:运行阶段 --- -FROM --platform=linux/$WANWU_ARCH golang:1.24-alpine -ARG WANWU_ARCH +FROM golang:1.24.6-bookworm AS builder WORKDIR /app +# 复制go.mod和go.sum文件 +COPY go.mod go.sum ./ -COPY configs ./configs - -# tidb-setup -COPY --from=builder /app/bin/$WANWU_ARCH/tidb-setup ./bin/tidb-setup +# 设置GOPROXY +ARG GOPROXY=https://goproxy.cn,direct +ENV GOPROXY=${GOPROXY} -# bff-service -COPY --from=builder /app/bin/$WANWU_ARCH/bff-service ./bin/bff-service -EXPOSE 6668 +# 下载依赖 +RUN go mod download -# iam-servie -COPY --from=builder /app/bin/$WANWU_ARCH/iam-service ./bin/iam-service -EXPOSE 8888 - -# model-servie -COPY --from=builder /app/bin/$WANWU_ARCH/model-service ./bin/model-service -EXPOSE 8989 - -# mcp-servie -COPY --from=builder /app/bin/$WANWU_ARCH/mcp-service ./bin/mcp-service -EXPOSE 9898 - -# knowledge-servie -COPY --from=builder /app/bin/$WANWU_ARCH/knowledge-service ./bin/knowledge-service -EXPOSE 8889 +# 复制源代码 +COPY . . -# rag-servie -COPY --from=builder /app/bin/$WANWU_ARCH/rag-service ./bin/rag-service -EXPOSE 9640 +RUN make build -# assistant-servie -COPY --from=builder /app/bin/$WANWU_ARCH/assistant-service ./bin/assistant-service -EXPOSE 8890 +# --- 第二阶段:运行阶段 --- +FROM alpine:3.23 -# agent-servie -COPY --from=builder /app/bin/$WANWU_ARCH/agent-service ./bin/agent-service -EXPOSE 8990 +WORKDIR /app -# app-servie -COPY --from=builder /app/bin/$WANWU_ARCH/app-service ./bin/app-service -EXPOSE 9988 +ENV ZONEINFO=/zoneinfo.zip +ENV TZ=Asia/Shanghai -# operate-servie -COPY --from=builder /app/bin/$WANWU_ARCH/operate-service ./bin/operate-service -EXPOSE 9797 \ No newline at end of file +COPY configs ./configs +COPY --from=builder /usr/local/go/lib/time/zoneinfo.zip / +COPY --from=builder /app/bin/ ./bin/ diff --git a/Dockerfile.callback b/Dockerfile.callback index 7eb2b0224..6923294b3 100644 --- a/Dockerfile.callback +++ b/Dockerfile.callback @@ -1,5 +1,5 @@ # 基于已构建的基础镜像 -FROM crpi-6pj79y7ddzdpexs8.cn-hangzhou.personal.cr.aliyuncs.com/wanwulite/callback-base:v0.3.0-6456201f +FROM wanwulite/callback-base:v0.3.3 # 工作目录 WORKDIR /callback diff --git a/Dockerfile.frontend b/Dockerfile.frontend index c5d35571a..48a5a08eb 100644 --- a/Dockerfile.frontend +++ b/Dockerfile.frontend @@ -1,20 +1,15 @@ -ARG WANWU_ARCH - # --- 第一阶段:构建阶段 --- -FROM --platform=linux/$WANWU_ARCH node:14 AS builder -ARG WANWU_ARCH +FROM node:14 AS builder WORKDIR /app COPY web . -ENV npm_config_registry=https://registry.npmmirror.com ENV npm_config_unsafe_perm=true - +RUN npm config set registry https://registry.npmmirror.com RUN set -euo && npm install RUN set -euo && npm run build # --- 第二阶段:运行阶段 --- -FROM --platform=linux/$WANWU_ARCH nginx:1.27 -ARG WANWU_ARCH +FROM nginx:1.27 COPY ./configs/middleware/nginx/conf.d /etc/nginx/conf.d diff --git a/Dockerfile.rag b/Dockerfile.rag index 3d50839b2..2b04cb9a5 100644 --- a/Dockerfile.rag +++ b/Dockerfile.rag @@ -1,5 +1,5 @@ # 使用 rag-base 镜像作为基础镜像 -FROM crpi-6pj79y7ddzdpexs8.cn-hangzhou.personal.cr.aliyuncs.com/wanwulite/rag-base:v1.1.1-20251114 +FROM wanwulite/rag-base:v1.1.1 # 设置工作目录 WORKDIR /model_extend diff --git a/Makefile b/Makefile index 843a5e95e..bd2fe3002 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,4 @@ -include .env -include .env.image.${WANWU_ARCH} +WANWU_VERSION := v0.3.3 LDFLAGS := -X main.buildTime=$(shell date +%Y-%m-%d,%H:%M:%S) \ -X main.buildVersion=${WANWU_VERSION} \ @@ -7,74 +6,40 @@ LDFLAGS := -X main.buildTime=$(shell date +%Y-%m-%d,%H:%M:%S) \ -X main.gitBranch=$(shell git --git-dir=./.git for-each-ref --format='%(refname:short)->%(upstream:short)' $(shell git --git-dir=./.git symbolic-ref -q HEAD)) \ -X main.builder=$(shell git config user.name) -build-tidb-setup-amd64: - CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -mod vendor -ldflags "$(LDFLAGS)" -o ./bin/amd64/ ./cmd/tidb-setup +build: build-tidb-setup build-bff build-iam build-model build-mcp build-knowledge build-rag build-app build-operate build-assistant build-agent -build-tidb-setup-arm64: - CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -mod vendor -ldflags "$(LDFLAGS)" -o ./bin/arm64/ ./cmd/tidb-setup +build-tidb-setup: + CGO_ENABLED=0 go build -ldflags "$(LDFLAGS)" -o ./bin/ ./cmd/tidb-setup -build-bff-amd64: - CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -mod vendor -ldflags "$(LDFLAGS)" -o ./bin/amd64/ ./cmd/bff-service +build-bff: + CGO_ENABLED=0 go build -ldflags "$(LDFLAGS)" -o ./bin/ ./cmd/bff-service -build-bff-arm64: - CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -mod vendor -ldflags "$(LDFLAGS)" -o ./bin/arm64/ ./cmd/bff-service +build-iam: + CGO_ENABLED=0 go build -ldflags "$(LDFLAGS)" -o ./bin/ ./cmd/iam-service -build-iam-amd64: - CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -mod vendor -ldflags "$(LDFLAGS)" -o ./bin/amd64/ ./cmd/iam-service +build-model: + CGO_ENABLED=0 go build -ldflags "$(LDFLAGS)" -o ./bin/ ./cmd/model-service -build-iam-arm64: - CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -mod vendor -ldflags "$(LDFLAGS)" -o ./bin/arm64/ ./cmd/iam-service +build-mcp: + CGO_ENABLED=0 go build -ldflags "$(LDFLAGS)" -o ./bin/ ./cmd/mcp-service -build-model-amd64: - CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -mod vendor -ldflags "$(LDFLAGS)" -o ./bin/amd64/ ./cmd/model-service +build-knowledge: + CGO_ENABLED=0 go build -ldflags "$(LDFLAGS)" -o ./bin/ ./cmd/knowledge-service -build-model-arm64: - CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -mod vendor -ldflags "$(LDFLAGS)" -o ./bin/arm64/ ./cmd/model-service +build-rag: + CGO_ENABLED=0 go build -ldflags "$(LDFLAGS)" -o ./bin/ ./cmd/rag-service -build-mcp-amd64: - CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -mod vendor -ldflags "$(LDFLAGS)" -o ./bin/amd64/ ./cmd/mcp-service +build-app: + CGO_ENABLED=0 go build -ldflags "$(LDFLAGS)" -o ./bin/ ./cmd/app-service -build-mcp-arm64: - CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -mod vendor -ldflags "$(LDFLAGS)" -o ./bin/arm64/ ./cmd/mcp-service +build-operate: + CGO_ENABLED=0 go build -ldflags "$(LDFLAGS)" -o ./bin/ ./cmd/operate-service -build-knowledge-amd64: - CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -mod vendor -ldflags "$(LDFLAGS)" -o ./bin/amd64/ ./cmd/knowledge-service +build-assistant: + CGO_ENABLED=0 go build -ldflags "$(LDFLAGS)" -o ./bin/ ./cmd/assistant-service -build-knowledge-arm64: - CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -mod vendor -ldflags "$(LDFLAGS)" -o ./bin/arm64/ ./cmd/knowledge-service - -build-rag-amd64: - CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -mod vendor -ldflags "$(LDFLAGS)" -o ./bin/amd64/ ./cmd/rag-service - -build-rag-arm64: - CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -mod vendor -ldflags "$(LDFLAGS)" -o ./bin/arm64/ ./cmd/rag-service - -build-app-amd64: - CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -mod vendor -ldflags "$(LDFLAGS)" -o ./bin/amd64/ ./cmd/app-service - -build-app-arm64: - CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -mod vendor -ldflags "$(LDFLAGS)" -o ./bin/arm64/ ./cmd/app-service - -build-operate-amd64: - CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -mod vendor -ldflags "$(LDFLAGS)" -o ./bin/amd64/ ./cmd/operate-service - -build-operate-arm64: - CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -mod vendor -ldflags "$(LDFLAGS)" -o ./bin/arm64/ ./cmd/operate-service - -build-assistant-amd64: - CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -mod vendor -ldflags "$(LDFLAGS)" -o ./bin/amd64/ ./cmd/assistant-service - -build-assistant-arm64: - CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -mod vendor -ldflags "$(LDFLAGS)" -o ./bin/arm64/ ./cmd/assistant-service - -build-agent-amd64: - CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -mod vendor -ldflags "$(LDFLAGS)" -o ./bin/amd64/ ./cmd/agent-service - -build-agent-arm64: - CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build -mod vendor -ldflags "$(LDFLAGS)" -o ./bin/arm64/ ./cmd/agent-service - -create-docker-net: - docker network create ${WANWU_DOCKER_NETWORK} +build-agent: + CGO_ENABLED=0 go build -ldflags "$(LDFLAGS)" -o ./bin/ ./cmd/agent-service check: go vet ./... @@ -85,9 +50,6 @@ check-callback: docker run --rm -t -v $(PWD)/callback:/callback -w /callback crpi-6pj79y7ddzdpexs8.cn-hangzhou.personal.cr.aliyuncs.com/gromitlee/python:3.12-slim-isort7.0.0 isort --check-only --diff --color . docker run --rm -t -v $(PWD)/callback:/callback -w /callback pyfound/black:25.11.0 black -t py312 --check --diff --color . -doc: - docker run --name golang-swag --privileged=true --rm -v $(PWD):/app -w /app crpi-6pj79y7ddzdpexs8.cn-hangzhou.personal.cr.aliyuncs.com/gromitlee/golang:1.24.6-bookworm-swag1.16.6 bash -c 'make doc-swag' - doc-swag: # swag version v1.16.4 # v1 @@ -103,166 +65,33 @@ doc-swag: swag fmt -g openurl.go -d internal/bff-service/server/http/handler/openurl swag init -g openurl.go -d internal/bff-service/server/http/handler/openurl -o docs/openurl --pd +docker: docker-image-backend docker-image-frontend docker-image-rag docker-image-agent docker-image-callback + +docker-base: docker-image-agent-base docker-image-callback-base + docker-image-backend: - docker build -f Dockerfile.backend --build-arg WANWU_ARCH=${WANWU_ARCH} -t wanwulite/wanwu-backend:${WANWU_VERSION}-$(shell git rev-parse --short HEAD)-${WANWU_ARCH} . + docker build -f Dockerfile.backend -t wanwulite/wanwu-backend:${WANWU_VERSION}-$(shell git rev-parse --short HEAD) . docker-image-frontend: - docker build -f Dockerfile.frontend --build-arg WANWU_ARCH=${WANWU_ARCH} -t wanwulite/wanwu-frontend:${WANWU_VERSION}-$(shell git rev-parse --short HEAD)-${WANWU_ARCH} . + docker build -f Dockerfile.frontend -t wanwulite/wanwu-frontend:${WANWU_VERSION}-$(shell git rev-parse --short HEAD) . docker-image-rag: - docker build -f Dockerfile.rag --build-arg WANWU_ARCH=${WANWU_ARCH} -t wanwulite/rag:${WANWU_VERSION}-$(shell git rev-parse --short HEAD)-${WANWU_ARCH} . + docker build -f Dockerfile.rag -t wanwulite/rag:${WANWU_VERSION}-$(shell git rev-parse --short HEAD) . docker-image-agent: - docker build -f Dockerfile.agent --build-arg WANWU_ARCH=${WANWU_ARCH} -t wanwulite/agent:${WANWU_VERSION}-$(shell git rev-parse --short HEAD)-${WANWU_ARCH} . + docker build -f Dockerfile.agent -t wanwulite/agent:${WANWU_VERSION}-$(shell git rev-parse --short HEAD) . docker-image-agent-base: - docker build -f Dockerfile.agent-base --build-arg WANWU_ARCH=${WANWU_ARCH} -t wanwulite/agent-base:${WANWU_VERSION}-$(shell git rev-parse --short HEAD)-${WANWU_ARCH} . + docker build -f Dockerfile.agent-base -t wanwulite/agent-base:${WANWU_VERSION}-$(shell git rev-parse --short HEAD) . docker-image-callback: - docker build -f Dockerfile.callback --build-arg WANWU_ARCH=${WANWU_ARCH} -t wanwulite/callback:${WANWU_VERSION}-$(shell git rev-parse --short HEAD)-${WANWU_ARCH} . + docker build -f Dockerfile.callback -t wanwulite/callback:${WANWU_VERSION}-$(shell git rev-parse --short HEAD) . docker-image-callback-base: - docker build -f Dockerfile.callback-base --build-arg WANWU_ARCH=${WANWU_ARCH} -t wanwulite/callback-base:${WANWU_VERSION}-$(shell git rev-parse --short HEAD)-${WANWU_ARCH} . + docker build -f Dockerfile.callback-base -t wanwulite/callback-base:${WANWU_VERSION}-$(shell git rev-parse --short HEAD) . grpc-protoc: protoc --proto_path=. --go_out=paths=source_relative:api --go-grpc_out=paths=source_relative:api proto/*/*.proto i18n-jsonl: go test ./pkg/i18n -run TestI18nConvertXlsx2Jsonl - -init: - go mod tidy - go mod vendor - -pb: - docker run --name golang-grpc --privileged=true --rm -v $(PWD):/app -w /app crpi-6pj79y7ddzdpexs8.cn-hangzhou.personal.cr.aliyuncs.com/gromitlee/golang:1.24.6-bookworm-protoc29.4-gengo1.34.1-gengrpc1.5.1-gengw2.20.0-genapi2.20.0 bash -c 'make grpc-protoc' - -# --- mysql --- -run-mysql: - docker-compose -f docker-compose.yaml \ - --env-file .env.image.${WANWU_ARCH} \ - --env-file .env \ - up -d mysql - -stop-mysql: - docker-compose -f docker-compose.yaml \ - --env-file .env.image.${WANWU_ARCH} \ - --env-file .env \ - down mysql - -# --- mysql-setup --- -run-mysql-setup: - docker-compose -f docker-compose.yaml \ - --env-file .env.image.${WANWU_ARCH} \ - --env-file .env \ - up mysql-setup - -stop-mysql-setup: - docker-compose -f docker-compose.yaml \ - --env-file .env.image.${WANWU_ARCH} \ - --env-file .env \ - down mysql-setup - -# --- tidb --- -run-tidb: - docker-compose -f docker-compose.tidb.yaml \ - --env-file .env.image.${WANWU_ARCH} \ - --env-file .env \ - up -d tidb - -stop-tidb: - docker-compose -f docker-compose.tidb.yaml \ - --env-file .env.image.${WANWU_ARCH} \ - --env-file .env \ - down tidb - -# --- tidb-setup --- -run-tidb-setup: - docker-compose -f docker-compose.tidb.yaml \ - --env-file .env.image.${WANWU_ARCH} \ - --env-file .env \ - up tidb-setup - -stop-tidb-setup: - docker-compose -f docker-compose.tidb.yaml \ - --env-file .env.image.${WANWU_ARCH} \ - --env-file .env \ - down tidb-setup - -# --- oceanbase --- -run-oceanbase: - docker-compose -f docker-compose.oceanbase.yaml \ - --env-file .env.image.${WANWU_ARCH} \ - --env-file .env \ - up -d oceanbase - -stop-oceanbase: - docker-compose -f docker-compose.oceanbase.yaml \ - --env-file .env.image.${WANWU_ARCH} \ - --env-file .env \ - down oceanbase - -# --- redis --- -run-redis: - docker-compose -f docker-compose.yaml \ - --env-file .env.image.${WANWU_ARCH} \ - --env-file .env \ - up -d redis - -stop-redis: - docker-compose -f docker-compose.yaml \ - --env-file .env.image.${WANWU_ARCH} \ - --env-file .env \ - down redis - -# --- minio --- -run-minio: - docker-compose -f docker-compose.yaml \ - --env-file .env.image.${WANWU_ARCH} \ - --env-file .env \ - up -d minio - -stop-minio: - docker-compose -f docker-compose.yaml \ - --env-file .env.image.${WANWU_ARCH} \ - --env-file .env \ - down minio - -# --- kafka --- -run-kafka: - docker-compose -f docker-compose.yaml \ - --env-file .env.image.${WANWU_ARCH} \ - --env-file .env \ - up -d kafka - -stop-kafka: - docker-compose -f docker-compose.yaml \ - --env-file .env.image.${WANWU_ARCH} \ - --env-file .env \ - down kafka - -# --- elastic-setup --- -run-es-setup: - docker-compose -f docker-compose.yaml \ - --env-file .env.image.${WANWU_ARCH} \ - --env-file .env \ - up -d es-setup - -stop-es-setup: - docker-compose -f docker-compose.yaml \ - --env-file .env.image.${WANWU_ARCH} \ - --env-file .env \ - down es-setup - -# --- elastic --- -run-es: - docker-compose -f docker-compose.yaml \ - --env-file .env.image.${WANWU_ARCH} \ - --env-file .env \ - up -d es - -stop-es: - docker-compose -f docker-compose.yaml \ - --env-file .env.image.${WANWU_ARCH} \ - --env-file .env \ - down es \ No newline at end of file diff --git a/async/.gitignore b/async/.gitignore new file mode 100644 index 000000000..20a873070 --- /dev/null +++ b/async/.gitignore @@ -0,0 +1,26 @@ +# If you prefer the allow list template instead of the deny list, see community template: +# https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore +# +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, built with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Dependency directories (remove the comment below to include it) +vendor/ + +# Go workspace file +go.work +go.work.sum + +.idea +.vscode +.DS_Store \ No newline at end of file diff --git a/async/LICENSE b/async/LICENSE new file mode 100644 index 000000000..1775d2ce9 --- /dev/null +++ b/async/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 gromitlee + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/async/README.md b/async/README.md new file mode 100644 index 000000000..6406987e4 --- /dev/null +++ b/async/README.md @@ -0,0 +1,109 @@ +# Go Decentralized Async Task Framework +基于golang实现的去中心化异步任务框架 + +**go-async特点** + +- [x] 基于DB管理异步任务,以package代码库的形式引入使用 +- [x] 适应微服务集群多副本的部署模式,每个服务节点中引入的go-async之间,是去中心化的关系 +- [x] 支持用户自定义不同类型异步任务的过程方法 +- [x] 以多用户多任务队列的方式控制异步任务并发执行,支持用户指定异步任务所属队列,支持用户自定义任务队列与队列调度策略 +- [x] 支持在异步任务排队、执行过程中,用户变更异步任务所属队列,暂停、运行、删除异步任务 +- [x] 支持在异步任务排队、执行过程中,引入go-async的服务节点关闭、重启,基于此外部对于异步任务的状态无感 +- [x] Goroutine Safe & Developer Friendly + +## 安装 + +```go +go get github.com/gromitlee/go-async +``` + +## 说明 + +### 设计理念 + +**去中心化** +1. 去中心化,即意味着go-async可以代码库的形式被引入,可以分散在一个集群的一组服务当中形成一套go-async系统;而非必需以单个或一组服务的形式对外提供使用 +2. 一套go-async系统,依赖同一个DB与DB提供的事务,完成分布式的多任务状态管理;即将中心化、一致性的责任托管给DB,go-async本身是去中心化的 + +**系统边界** +1. 一套go-async系统内共用同一个DB,这是不同go-async系统间的边界,一个集群中可以同时存在多套不相干的go-async系统 +2. 一套go-async系统最少可由一个节点构成,一套系统内的不同节点,相互间是平等关系 +3. 一套go-async系统内任务的类型全局一致 + +**任务运行** +1. 一个任务只有在运行中,才会被加载到一个且只有一个节点的内存中,并且该运行中的任务,只受该节点的管理,直到该任务从内存中移除 +2. 节点会以心跳的形式定时更新其上运行着的任务在DB中的状态,当节点发生异常,无法继续维持任务的心跳时,其他节点会发现并标记这些任务运行失败 +3. 当节点正常关闭时,会暂存其上运行着的任务的上下文,标记这些任务为暂停状态并从内存中移除,其他节点会发现并接管继续运行这些任务,这要求业务层在实现各种异步任务时,需要考虑从保存的上下文中正常恢复任务 +4. 暂不考虑节点管理的运行中的任务,托管在其他节点上运行的实现方案 + +### API + +**[go-async API](api.go)** + +```go +func Init(ctx context.Context, db *gorm.DB, options ...AsyncOption) error +func Stop() + +func RegisterTask(taskTyp uint32, newTask async_task.ITaskFunc) error +func CreateTask(ctx context.Context, user, group string, taskTyp uint32, taskCtx string, autoRun bool) (uint32, error) + +func ChangeTaskGroup(ctx context.Context, taskID uint32, group string) error + +func RunTask(ctx context.Context, taskID uint32) error +func DeleteTask(ctx context.Context, taskID uint32) error +func PauseTask(ctx context.Context, taskID uint32) error + +func GetTask(ctx context.Context, taskID uint32) (*async_task.Task, error) +func GetTasks(ctx context.Context, user, group string, taskTypes []uint32, states []async_task.State, offset, limit int32) ([]*async_task.Task, error) +``` + +- `Init`和`Stop`方法用于用户初始化和关闭go-async系统 +- `RegisterTask`和`CreateTask`方法用于用户注册和创建不同类型的异步任务 +- `RunTask`、`DeleteTak`和`PauseTask`方法用于用户运行/继续运行、删除和暂停异步任务 +- `ChangeTaskGroup`方法用于用户变更异步任务所属队列 +- `GetTask`、`GetTasks`方法用于查询异步任务,用户可见的任务状态有 + - `StateInit`已创建 + - `StatePending`排队中 + - `StateRunning`运行中 + - `StateCanceling`取消中 + - `StatePause`暂停 + - `StateFinished`结束 + - `StateFailed`失败 + +**[go-async task API](pkg/async/async_task/task.go)** + +```go +type IReport interface { + Phase() (RunPhase, bool) + Context() string +} + +type ITask interface { + Running(ctx context.Context, taskCtx string, stop <-chan struct{}) <-chan IReport + Deleting(ctx context.Context, taskCtx string, stop <-chan struct{}) <-chan IReport +} +``` + +- `Running`方法用于用户实现完成异步任务业务逻辑的过程方法 + - `ctx`是异步任务运行上下文,由go-async系统初始化时(`Init`方法)指定,不会被go-async系统本身主动取消 + - `taskCtx`是异步任务执行所需的逻辑上下文,由创建异步任务时(`CreateTask`方法)指定,或被异步任务本身动态更新、上报后,由go-async系统回传 + - `stop`用于接收go-async系统的停止信号,可能来自于用户手动暂停(`PauseTask`方法)、删除(`DeleteTask`方法)或go-async系统关闭,`Running`方法运行中最多只会收到一次停止信号 + - `IReport`用于`Running`方法向go-async系统上报异步任务执行情况与逻辑上下文 + - `Context`方法用于向go-async系统上报当前的逻辑上下文,go-async系统负责存储 + - `Phase`方法用于向go-async系统上报异步任务的执行情况 + - `RunPhaseNormal`表示任务正常执行 + - `RunPhaseFinished`表示任务结束,`bool`表示是否删除对应任务记录,之后`IReport`channel应当被关闭并退出执行 + - `RunPhaseFailed`表示任务失败,`bool`表示是否删除对应任务记录,之后`IReport`channel应当被关闭并退出执行 + - 接收到go-async系统`stop`停止信号,任务根据自身情况上报后退出执行即可;一般是上报`RunPhaseNormal` +- `Deleting`方法用于用户实现删除(`DeleteTask`方法)异步任务后,清理业务逻辑的过程方法,与`Running`方法类似 + +## 示例 + +- 示例1:[向量点积异步任务](examples/task_dot_test.go) +- 示例2:[矩阵相乘异步任务(基于向量点积并行计算)](examples/task_mm_test.go) + +## TODO + +- [x] 开放自定义log组件API +- [x] 开放自定义任务队列组件API(可由用户自定义队列调度策略) +- [ ] 加速任务并发启动 \ No newline at end of file diff --git a/async/api.go b/async/api.go new file mode 100644 index 000000000..de104c78e --- /dev/null +++ b/async/api.go @@ -0,0 +1,284 @@ +package async + +import ( + "context" + "errors" + "fmt" + + "gorm.io/gorm" + + "github.com/UnicomAI/wanwu/async/internal/async" + "github.com/UnicomAI/wanwu/async/internal/async/config" + "github.com/UnicomAI/wanwu/async/internal/db/model" + "github.com/UnicomAI/wanwu/async/internal/db/trans" + "github.com/UnicomAI/wanwu/async/internal/tools" + "github.com/UnicomAI/wanwu/async/pkg/async/async_component" + "github.com/UnicomAI/wanwu/async/pkg/async/async_component/pending" + "github.com/UnicomAI/wanwu/async/pkg/async/async_config" + "github.com/UnicomAI/wanwu/async/pkg/async/async_task" +) + +var _mgr *async.Mgr + +func Init(ctx context.Context, db *gorm.DB, options ...AsyncOption) error { + if _mgr != nil { + return ErrMgrAlreadyInit + } + var err error + + cfg := config.Config{ + Log: tools.DefaultLog(), + RunMaxConcurrency: 5, + RunTaskInterval: 1, + } + for _, opt := range options { + if cfg, err = opt.apply(cfg); err != nil { + return err + } + } + // pendingRun, pendingDel + if cfg.PendingRun == nil { + cfg.PendingRun = pending.NewPendingRunDefault(db, cfg.Log) + } + if cfg.PendingDel == nil { + cfg.PendingDel = pending.NewPendingDelDefault(db, cfg.Log) + } + + if _mgr, err = async.NewMgr(db, cfg); err != nil { + return err + } + return _mgr.Run(ctx) +} + +func Stop() { + if _mgr == nil { + return + } + _mgr.Stop() + _mgr = nil +} + +func RegisterTask(taskTyp uint32, newTask async_task.ITaskFunc) error { + if _mgr == nil { + return ErrMgrNotInit + } + return _mgr.RegisterTask(taskTyp, newTask) +} + +func CreateTask(ctx context.Context, user, group string, taskTyp uint32, taskCtx string, autoRun bool) (uint32, error) { + if _mgr == nil { + return 0, ErrMgrNotInit + } + return _mgr.CreateTask(ctx, user, group, taskTyp, taskCtx, autoRun) +} + +func ChangeTaskGroup(ctx context.Context, taskID uint32, group string) error { + if _mgr == nil { + return ErrMgrNotInit + } + return _mgr.ChangeTaskGroup(ctx, taskID, group) +} + +func RunTask(ctx context.Context, taskID uint32) error { + if _mgr == nil { + return ErrMgrNotInit + } + return _mgr.UserRun(ctx, taskID) +} + +func DeleteTask(ctx context.Context, taskID uint32) error { + if _mgr == nil { + return ErrMgrNotInit + } + return _mgr.UserDelete(ctx, taskID) +} + +func PauseTask(ctx context.Context, taskID uint32) error { + if _mgr == nil { + return ErrMgrNotInit + } + return _mgr.UserPause(ctx, taskID) +} + +func GetTask(ctx context.Context, taskID uint32) (*async_task.Task, error) { + if _mgr == nil { + return nil, ErrMgrNotInit + } + dbTask, err := _mgr.GetTask(ctx, taskID) + if err != nil { + if err == gorm.ErrRecordNotFound { + return nil, ErrTaskNotFound + } + return nil, err + } + return convert(dbTask) +} + +func GetTasks(ctx context.Context, user, group string, taskTypes []uint32, states []async_task.State, offset, limit int32) ([]*async_task.Task, error) { + if _mgr == nil { + return nil, ErrMgrNotInit + } + if len(states) == 0 { + states = append(states, + async_task.StateInit, + async_task.StatePending, + async_task.StateRunning, + async_task.StateCanceling, + async_task.StatePause, + async_task.StateFinished, + async_task.StateFailed) + } + var status []trans.TaskStatus + for _, state := range states { + switch state { + case async_task.StateInit: + status = append(status, trans.TaskStatus{S: trans.TaskStateInit, M: trans.TaskMarkNone}) + case async_task.StatePending: + status = append(status, trans.TaskStatus{S: trans.TaskStatePending, M: trans.TaskMarkRun}) + case async_task.StateRunning: + status = append(status, + trans.TaskStatus{S: trans.TaskStateRunning, M: trans.TaskMarkRun}, + trans.TaskStatus{S: trans.TaskStatePause, M: trans.TaskMarkRun}) + case async_task.StateCanceling: + status = append(status, + trans.TaskStatus{S: trans.TaskStateRunning, M: trans.TaskMarkDelete}, + trans.TaskStatus{S: trans.TaskStateRunning, M: trans.TaskMarkPause}) + case async_task.StatePause: + status = append(status, trans.TaskStatus{S: trans.TaskStatePause, M: trans.TaskMarkPause}) + case async_task.StateFinished: + status = append(status, trans.TaskStatus{S: trans.TaskStateFinished, M: trans.TaskMarkRun}) + case async_task.StateFailed: + status = append(status, trans.TaskStatus{S: trans.TaskStateFailed, M: trans.TaskMarkRun}) + default: + return nil, fmt.Errorf("invalid state (%v)", state) + } + } + dbTasks, err := _mgr.GetTasks(ctx, user, group, taskTypes, status, offset, limit) + if err != nil { + return nil, err + } + var tasks []*async_task.Task + for _, dbTask := range dbTasks { + if t, err := convert(dbTask); err == nil { + tasks = append(tasks, t) + } + } + return tasks, nil +} + +// --- AsyncOption --- + +func WithLogger(logger async_config.Logger) AsyncOption { + return asyncOptionFunc(func(cfg config.Config) (config.Config, error) { + if logger != nil { + cfg.Log = logger + } else { + cfg.Log = tools.EmptyLog() + } + return cfg, nil + }) +} + +func WithRunMaxConcurrency(max int) AsyncOption { + return asyncOptionFunc(func(cfg config.Config) (config.Config, error) { + if max <= 0 { + return cfg, errors.New("invalid run max concurrency") + } + cfg.RunMaxConcurrency = max + return cfg, nil + }) +} + +func WithRunTaskIntervalSecond(interval int) AsyncOption { + return asyncOptionFunc(func(cfg config.Config) (config.Config, error) { + if interval <= 0 { + return cfg, errors.New("invalid run task interval") + } + cfg.RunTaskInterval = interval + return cfg, nil + }) +} + +func WithPendingRunQueue(pendingRun async_component.IQueue) AsyncOption { + return asyncOptionFunc(func(cfg config.Config) (config.Config, error) { + if pendingRun != nil { + cfg.PendingRun = pendingRun + } + return cfg, nil + }) +} + +func WithPendingDelQueue(pendingDel async_component.IQueue) AsyncOption { + return asyncOptionFunc(func(cfg config.Config) (config.Config, error) { + if pendingDel != nil { + cfg.PendingDel = pendingDel + } + return cfg, nil + }) +} + +type AsyncOption interface { + apply(cfg config.Config) (config.Config, error) +} + +type asyncOptionFunc func(cfg config.Config) (config.Config, error) + +func (fn asyncOptionFunc) apply(cfg config.Config) (config.Config, error) { + return fn(cfg) +} + +func convert(dbTask *model.AsyncTask) (*async_task.Task, error) { + if dbTask == nil { + return nil, ErrTaskNotFound + } + var state async_task.State + status := trans.TaskStatus{S: dbTask.State, M: dbTask.Mark} + switch status { + // 任务创建初始状态,不会排队、运行 + case trans.TaskStatus{S: trans.TaskStateInit, M: trans.TaskMarkNone}: + state = async_task.StateInit + // 任务在pending.runQueue中排队,用户可见为排队中 + case trans.TaskStatus{S: trans.TaskStatePending, M: trans.TaskMarkRun}: + state = async_task.StatePending + // 任务在内存中执行run,用户可见为运行中 + case trans.TaskStatus{S: trans.TaskStateRunning, M: trans.TaskMarkRun}: + state = async_task.StateRunning + // 任务在内存中执行run,但用户标记删除,用户可见为取消中 + case trans.TaskStatus{S: trans.TaskStateRunning, M: trans.TaskMarkDelete}: + state = async_task.StateCanceling + // 任务在内存中执行run,用户标记暂停,用户可见为取消中 + case trans.TaskStatus{S: trans.TaskStateRunning, M: trans.TaskMarkPause}: + state = async_task.StateCanceling + // 任务暂停run,但用户可见为运行中 + case trans.TaskStatus{S: trans.TaskStatePause, M: trans.TaskMarkRun}: + state = async_task.StateRunning + // 任务暂停run,用户可见为暂停中 + case trans.TaskStatus{S: trans.TaskStatePause, M: trans.TaskMarkPause}: + state = async_task.StatePause + // 任务执行run结束,用户可见为结束 + case trans.TaskStatus{S: trans.TaskStateFinished, M: trans.TaskMarkRun}: + state = async_task.StateFinished + // 任务执行run失败,用户可见为失败 + case trans.TaskStatus{S: trans.TaskStateFailed, M: trans.TaskMarkRun}: + state = async_task.StateFailed + default: + return nil, ErrTaskNotFound + } + return &async_task.Task{ + ID: dbTask.ID, + User: dbTask.User, + Group: dbTask.Group, + Type: dbTask.Type, + State: state, + CreatedAt: dbTask.CreatedAt, + DoneAt: dbTask.DoneAt, + Ctx: string(dbTask.Ctx), + }, nil +} + +var ( + ErrMgrNotInit = errors.New("async mgr not init") + ErrMgrAlreadyInit = errors.New("async mgr already init") + + ErrTaskNotFound = errors.New("async task not found") +) diff --git a/async/examples/db.go b/async/examples/db.go new file mode 100644 index 000000000..5456b4e5b --- /dev/null +++ b/async/examples/db.go @@ -0,0 +1,89 @@ +package examples + +import ( + "fmt" + "log" + "os" + "time" + + "gorm.io/driver/mysql" + "gorm.io/driver/postgres" + "gorm.io/gorm" + "gorm.io/gorm/logger" +) + +type DBType uint8 + +const ( + dbName = "async_test" + + dbMysql DBType = 1 + dbPostgresql DBType = 2 +) + +var _db *gorm.DB + +func getDB(dbTyp DBType, dbName string) *gorm.DB { + if _db != nil { + return _db + } + dbLog := logger.New( + log.New(os.Stdout, "\r\n", log.LstdFlags), + logger.Config{ + SlowThreshold: time.Millisecond * 100, + Colorful: true, + IgnoreRecordNotFoundError: true, + LogLevel: logger.Error, + }, + ) + var db *gorm.DB + var err error + switch dbTyp { + case dbMysql: + db, err = gorm.Open(mysql.Open(fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8&parseTime=%t&loc=%s", + "root", + "gromit1234", + "localhost:3306", + dbName, + true, + // "Asia/Shanghai", + "Local")), &gorm.Config{ + Logger: dbLog, + }) + if err != nil { + break + } + db.Set("gorm:table_options", "ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin") + err = setPoolParam(db, 64, 64) + case dbPostgresql: + db, err = gorm.Open(postgres.Open(fmt.Sprintf("postgres://%s:%s@%s/%s", + "postgres", + "gromit1234", + "localhost:5432", + dbName, + )), &gorm.Config{ + Logger: dbLog, + }) + if err != nil { + break + } + err = setPoolParam(db, 64, 64) + default: + log.Fatal(fmt.Errorf("unknown db type %v", dbTyp)) + } + if err != nil { + log.Fatal(err) + } + _db = db + return _db +} + +func setPoolParam(db *gorm.DB, maxOpenConn, maxIdleConn int) error { + sqlDB, err := db.DB() + if err != nil { + return err + } + sqlDB.SetMaxOpenConns(maxOpenConn) + sqlDB.SetMaxIdleConns(maxIdleConn) + return nil +} diff --git a/async/examples/init.go b/async/examples/init.go new file mode 100644 index 000000000..1725e7357 --- /dev/null +++ b/async/examples/init.go @@ -0,0 +1,27 @@ +package examples + +import ( + "context" + + async "github.com/UnicomAI/wanwu/async" + "github.com/UnicomAI/wanwu/async/pkg/async/async_task" +) + +func asyncInit(del bool, failRate int, options ...async.AsyncOption) error { + // init + if err := async.Init(context.TODO(), getDB(dbMysql, dbName), options...); err != nil { + return err + } + // register + if err := async.RegisterTask(taskTypeDot, func() async_task.ITask { + return &taskDot{del: del, failRate: failRate} + }); err != nil { + return err + } + if err := async.RegisterTask(taskTypeMM, func() async_task.ITask { + return &taskMM{del: del} + }); err != nil { + return err + } + return nil +} diff --git a/async/examples/report.go b/async/examples/report.go new file mode 100644 index 000000000..e8cb3707e --- /dev/null +++ b/async/examples/report.go @@ -0,0 +1,26 @@ +package examples + +import "github.com/UnicomAI/wanwu/async/pkg/async/async_task" + +// report impl IReport +type report struct { + phase async_task.RunPhase + del bool + ctx string +} + +func (r *report) Phase() (async_task.RunPhase, bool) { + return r.phase, r.del +} + +func (r *report) Context() string { + return r.ctx +} + +func (r *report) clone() *report { + return &report{ + phase: r.phase, + del: r.del, + ctx: r.ctx, + } +} diff --git a/async/examples/task_dot.go b/async/examples/task_dot.go new file mode 100644 index 000000000..febf7e5b9 --- /dev/null +++ b/async/examples/task_dot.go @@ -0,0 +1,117 @@ +package examples + +import ( + "context" + "encoding/json" + "math/rand" + "sync" + "time" + + "github.com/UnicomAI/wanwu/async/internal/tools" + "github.com/UnicomAI/wanwu/async/pkg/async/async_task" +) + +const ( + taskTypeDot uint32 = 1 +) + +// taskDot 向量点积任务 +// taskCtx json格式:{"I":2,"Sum":14,"A":[1,2,3],"B":[4,5,6]} +// A记录第一个向量,B记录第二个向量,I记录已经计算了第几个元素的乘积(I从1开始,0表示还未计算),Sum记录已经计算的乘积累加和 +type taskDot struct { + wg sync.WaitGroup + + del bool // 是否需要自动清理 + + failRate int // 每次tick fail概率(%),小于等于0不会fail + panicRate int // 每次tick panic概率(%),小于等于0不会panic +} + +type dotCtx struct { + I int + Sum int + A []int + B []int +} + +func (dot *dotCtx) String() string { + b, _ := json.Marshal(dot) + return string(b) +} + +func (t *taskDot) Running(ctx context.Context, taskCtx string, stop <-chan struct{}) <-chan async_task.IReport { + reportCh := make(chan async_task.IReport) + t.wg.Add(1) + go func() { + defer tools.PrintPanicStack() + defer t.wg.Wait() + defer t.wg.Done() + defer close(reportCh) + + r := &report{phase: async_task.RunPhaseNormal, del: t.del, ctx: taskCtx} + defer func() { + reportCh <- r.clone() + }() + + // check + dot := &dotCtx{} + if err := json.Unmarshal([]byte(taskCtx), dot); err != nil { + r.phase = async_task.RunPhaseFailed + return + } + if len(dot.A) != len(dot.B) { + r.phase = async_task.RunPhaseFailed + return + } + + rn := rand.New(rand.NewSource(time.Now().UnixNano())) + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + for { + if dot.I >= len(dot.A) { + r.phase = async_task.RunPhaseFinished + return + } + select { + case <-ctx.Done(): + return + case <-stop: + return + case <-ticker.C: + if rn.Intn(100) < t.panicRate { + panic("task panic test") + } + if rn.Intn(100) < t.failRate { + r.phase = async_task.RunPhaseFailed + return + } + dot.Sum = dot.Sum + dot.A[dot.I]*dot.B[dot.I] + dot.I = dot.I + 1 + r.ctx = dot.String() + reportCh <- r.clone() + } + } + }() + + return reportCh +} + +func (t *taskDot) Deleting(ctx context.Context, taskCtx string, stop <-chan struct{}) <-chan async_task.IReport { + reportCh := make(chan async_task.IReport) + t.wg.Add(1) + go func() { + defer tools.PrintPanicStack() + defer t.wg.Wait() + defer t.wg.Done() + defer close(reportCh) + + select { + case <-ctx.Done(): + return + case <-stop: + return + case reportCh <- &report{phase: async_task.RunPhaseFinished, ctx: taskCtx}: + } + }() + return reportCh +} diff --git a/async/examples/task_dot_test.go b/async/examples/task_dot_test.go new file mode 100644 index 000000000..b0ea1a5a8 --- /dev/null +++ b/async/examples/task_dot_test.go @@ -0,0 +1,313 @@ +package examples + +import ( + "context" + "sync" + "testing" + "time" + + async "github.com/UnicomAI/wanwu/async" + "github.com/UnicomAI/wanwu/async/internal/tools" + "github.com/UnicomAI/wanwu/async/pkg/async/async_component/pending" + "github.com/UnicomAI/wanwu/async/pkg/async/async_task" +) + +var options []async.AsyncOption + +func TestTaskDot_All(t *testing.T) { + TestTaskDot_Default(t) + TestTaskDot_WithPendingRunQueue(t) + TestTaskDot_WithPendingDelQueue(t) + TestTaskDot_WithPendingRunAndDelQueue(t) +} + +func TestTaskDot_Default(t *testing.T) { + TestTaskDot_Finished(t) + TestTaskDot_UserDelete(t) + TestTaskDot_FailedAndUserRestart(t) + TestTaskDot_FailedOrUserPauseAndUserRestart(t) + TestTaskDot_FailedAndDelete(t) +} + +func TestTaskDot_WithPendingRunQueue(t *testing.T) { + pendingRun, err := pending.NewPendingRun(getDB(dbMysql, dbName), tools.DefaultLog()) + if err != nil { + t.Fatal(err) + } + options = []async.AsyncOption{ + async.WithPendingRunQueue(pendingRun), + } + + TestTaskDot_Default(t) +} + +func TestTaskDot_WithPendingDelQueue(t *testing.T) { + pendingDel, err := pending.NewPendingDel(getDB(dbMysql, dbName), tools.DefaultLog()) + if err != nil { + t.Fatal(err) + } + options = []async.AsyncOption{ + async.WithPendingDelQueue(pendingDel), + } + + TestTaskDot_Default(t) +} + +func TestTaskDot_WithPendingRunAndDelQueue(t *testing.T) { + pendingRun, err := pending.NewPendingRun(getDB(dbMysql, dbName), tools.DefaultLog()) + if err != nil { + t.Fatal(err) + } + pendingDel, err := pending.NewPendingDel(getDB(dbMysql, dbName), tools.DefaultLog()) + if err != nil { + t.Fatal(err) + } + options = []async.AsyncOption{ + async.WithPendingRunQueue(pendingRun), + async.WithPendingDelQueue(pendingDel), + } + + TestTaskDot_Default(t) +} + +func TestTaskDot_Finished(t *testing.T) { + // init + if err := asyncInit(false, 0, options...); err != nil { + t.Fatal(err) + } + var wg sync.WaitGroup + // create & run + for i := 0; i < 6; i++ { + wg.Add(1) + go func() { + defer wg.Done() + taskCtx := "{\"A\":[0,1,2,3,4,5,6,7,8,9],\"B\":[0,10,20,30,40,50,60,70,80,90]}" + taskID, err := async.CreateTask(context.TODO(), "", "taskDot", taskTypeDot, taskCtx, true) + if err != nil { + t.Error(err) + return + } + checkTicker := time.NewTicker(time.Millisecond * 123) + defer checkTicker.Stop() + var stop bool + for { + select { + case <-checkTicker.C: + if task, err := async.GetTask(context.TODO(), taskID); err != nil { + t.Error(err) + return + } else if task.State == async_task.StateFinished { + t.Logf("%+v", task) + stop = true + } + } + if stop { + break + } + } + }() + } + // stop + wg.Wait() + async.Stop() +} + +func TestTaskDot_UserDelete(t *testing.T) { + // init + if err := asyncInit(false, 0, options...); err != nil { + t.Fatal(err) + } + var wg sync.WaitGroup + // create & run + for i := 0; i < 6; i++ { + wg.Add(1) + go func() { + defer wg.Done() + taskCtx := "{\"A\":[0,1,2,3,4,5,6,7,8,9],\"B\":[0,10,20,30,40,50,60,70,80,90]}" + taskID, err := async.CreateTask(context.TODO(), "", "taskDot", taskTypeDot, taskCtx, true) + if err != nil { + t.Error(err) + return + } + checkTicker := time.NewTicker(time.Millisecond * 123) + defer checkTicker.Stop() + var stop bool + for { + select { + case <-checkTicker.C: + if task, err := async.GetTask(context.TODO(), taskID); err != nil { + t.Error(err) + return + } else { + switch task.State { + case async_task.StateRunning: + _ = async.DeleteTask(context.TODO(), taskID) + case async_task.StateCanceling: + stop = true + } + } + } + if stop { + break + } + } + time.Sleep(time.Second * 5) + if _, err := async.GetTask(context.TODO(), taskID); err != async.ErrTaskNotFound { + t.Error(err) + return + } + }() + } + // stop + wg.Wait() + async.Stop() +} + +func TestTaskDot_FailedAndUserRestart(t *testing.T) { + // init + if err := asyncInit(false, 30, options...); err != nil { + t.Fatal(err) + } + var wg sync.WaitGroup + // create & run + for i := 0; i < 6; i++ { + wg.Add(1) + go func() { + defer wg.Done() + taskCtx := "{\"A\":[0,1,2,3,4,5,6,7,8,9],\"B\":[0,10,20,30,40,50,60,70,80,90]}" + taskID, err := async.CreateTask(context.TODO(), "", "taskDot", taskTypeDot, taskCtx, true) + if err != nil { + t.Error(err) + return + } + checkTicker := time.NewTicker(time.Millisecond * 123) + defer checkTicker.Stop() + var stop bool + for { + select { + case <-checkTicker.C: + if task, err := async.GetTask(context.TODO(), taskID); err != nil { + t.Error(err) + return + } else { + switch task.State { + case async_task.StateFinished: + t.Logf("%+v", task) + stop = true + case async_task.StateFailed: + if err := async.RunTask(context.TODO(), taskID); err != nil { + t.Error(err) + return + } + } + } + } + if stop { + break + } + } + }() + } + // stop + wg.Wait() + async.Stop() +} + +func TestTaskDot_FailedOrUserPauseAndUserRestart(t *testing.T) { + // init + if err := asyncInit(false, 30, options...); err != nil { + t.Fatal(err) + } + var wg sync.WaitGroup + // create & run + for i := 0; i < 6; i++ { + wg.Add(1) + go func() { + defer wg.Done() + taskCtx := "{\"A\":[0,1,2,3,4,5,6,7,8,9],\"B\":[0,10,20,30,40,50,60,70,80,90]}" + taskID, err := async.CreateTask(context.TODO(), "", "taskDot", taskTypeDot, taskCtx, true) + if err != nil { + t.Error(err) + return + } + checkTicker := time.NewTicker(time.Millisecond * 123) + defer checkTicker.Stop() + pauseTicker := time.NewTicker(time.Second * 5) + defer pauseTicker.Stop() + var stop bool + for { + select { + case <-checkTicker.C: + if task, err := async.GetTask(context.TODO(), taskID); err != nil { + t.Error(err) + return + } else { + switch task.State { + case async_task.StateFinished: + t.Logf("%+v", task) + stop = true + case async_task.StatePause, async_task.StateFailed: + _ = async.RunTask(context.TODO(), taskID) + } + } + case <-pauseTicker.C: + if task, err := async.GetTask(context.TODO(), taskID); err != nil { + t.Error(err) + return + } else if task.State == async_task.StateRunning { + _ = async.PauseTask(context.TODO(), taskID) + } + } + if stop { + break + } + } + }() + } + // stop + wg.Wait() + async.Stop() +} + +func TestTaskDot_FailedAndDelete(t *testing.T) { + // init + if err := asyncInit(true, 50, options...); err != nil { + t.Fatal(err) + } + var wg sync.WaitGroup + // create & run + for i := 0; i < 6; i++ { + wg.Add(1) + go func() { + defer wg.Done() + taskCtx := "{\"A\":[0,1,2,3,4,5,6,7,8,9],\"B\":[0,10,20,30,40,50,60,70,80,90]}" + taskID, err := async.CreateTask(context.TODO(), "", "taskDot", taskTypeDot, taskCtx, true) + if err != nil { + t.Error(err) + return + } + checkTicker := time.NewTicker(time.Millisecond * 123) + defer checkTicker.Stop() + var stop bool + for { + select { + case <-checkTicker.C: + if _, err := async.GetTask(context.TODO(), taskID); err != nil { + if err != async.ErrTaskNotFound { + t.Error(err) + return + } else { + stop = true + } + } + } + if stop { + break + } + } + }() + } + // stop + wg.Wait() + async.Stop() +} diff --git a/async/examples/task_mm.go b/async/examples/task_mm.go new file mode 100644 index 000000000..966c87ec9 --- /dev/null +++ b/async/examples/task_mm.go @@ -0,0 +1,213 @@ +package examples + +import ( + "context" + "encoding/json" + "log" + "sync" + "time" + + async "github.com/UnicomAI/wanwu/async" + async2 "github.com/UnicomAI/wanwu/async/internal/async" + "github.com/UnicomAI/wanwu/async/internal/tools" + "github.com/UnicomAI/wanwu/async/pkg/async/async_task" +) + +const ( + taskTypeMM uint32 = 2 +) + +// taskMM 矩阵相乘并行计算,最后计算结果矩阵的元素总和 +// taskCtx json格式:{"Sum":0,"TaskID":123,"A":[[],[],...,[]],"B":[[],[],...,[]],"TaskIDs":[123,456,...]} +// A记录第一个矩阵,B记录第二个矩阵,TaskID记录已经计算了累加和的任务ID,TaskIDs记录所有任务ID +type taskMM struct { + wg sync.WaitGroup + + del bool // 是否需要自动清理 +} + +type mmCtx struct { + Sum int + TaskID int + A [][]int + B [][]int + TaskIDs []int +} + +func (mm *mmCtx) String() string { + b, _ := json.Marshal(mm) + return string(b) +} + +func (t *taskMM) Running(ctx context.Context, taskCtx string, stop <-chan struct{}) <-chan async_task.IReport { + reportCh := make(chan async_task.IReport) + t.wg.Add(1) + go func() { + defer tools.PrintPanicStack() + defer t.wg.Wait() + defer t.wg.Done() + defer close(reportCh) + + r := &report{phase: async_task.RunPhaseNormal, del: t.del, ctx: taskCtx} + defer func() { + reportCh <- r.clone() + }() + + // check + mm := &mmCtx{} + if err := json.Unmarshal([]byte(taskCtx), mm); err != nil { + r.phase = async_task.RunPhaseFailed + return + } + aH, bW, ok := check(mm.A, mm.B) + if !ok { + r.phase = async_task.RunPhaseFailed + return + } + + // create sub taskDot tasks + createTicker := time.NewTicker(time.Millisecond * 100) + defer createTicker.Stop() + if n := len(mm.TaskIDs); n < aH*bW { + for i := 0; i < aH; i++ { + for j := 0; j < bW; j++ { + select { + case <-ctx.Done(): + return + case <-stop: + return + case <-createTicker.C: + var B []int + for _, b := range mm.B { + B = append(B, b[j]) + } + if i*bW+j == n { + dot := &dotCtx{A: mm.A[i], B: B} + if taskID, err := async.CreateTask(ctx, "", "taskDot", taskTypeDot, dot.String(), true); err != nil { + log.Printf("taskMM create sub task err: %v", err) + r.phase = async_task.RunPhaseFailed + return + } else if err := async.ChangeTaskGroup(ctx, taskID, "taskMM_taskDot"); err != nil { + log.Printf("taskMM change sub task %v group err: %v", taskID, err) + r.phase = async_task.RunPhaseFailed + return + } else { + mm.TaskIDs = append(mm.TaskIDs, int(taskID)) + r.ctx = mm.String() + reportCh <- r.clone() + } + n++ + } + } + } + } + } + + // check sub tasks and sum + checkTicker := time.NewTicker(time.Millisecond * 100) + defer checkTicker.Stop() + for _, taskID := range mm.TaskIDs { + if taskID <= mm.TaskID { + continue + } + var next bool + for { + select { + case <-ctx.Done(): + return + case <-stop: + return + case <-checkTicker.C: + task, err := async.GetTask(ctx, uint32(taskID)) + if err != nil { + if err != async2.ErrMgrAlreadyStop { + log.Printf("taskMM get sub task %v err: %v", taskID, err) + r.phase = async_task.RunPhaseFailed + return + } + continue + } + switch task.State { + case async_task.StateFinished: + dot := &dotCtx{} + if err := json.Unmarshal([]byte(task.Ctx), dot); err != nil { + log.Printf("taskMM unmarshal sub task %v ctx err: %v", taskID, err) + r.phase = async_task.RunPhaseFailed + return + } + mm.Sum = mm.Sum + dot.Sum + mm.TaskID = taskID + r.ctx = mm.String() + reportCh <- r.clone() + next = true + case async_task.StateFailed: + if err := async.RunTask(ctx, uint32(taskID)); err != nil { + if err != async2.ErrMgrAlreadyStop { + log.Printf("taskMM restart sub task %v err: %v", taskID, err) + r.phase = async_task.RunPhaseFailed + return + } + } + default: + } + } + if next { + break + } + } + } + // finished + r.phase = async_task.RunPhaseFinished + }() + return reportCh +} + +func (t *taskMM) Deleting(ctx context.Context, taskCtx string, stop <-chan struct{}) <-chan async_task.IReport { + reportCh := make(chan async_task.IReport) + t.wg.Add(1) + go func() { + defer tools.PrintPanicStack() + defer t.wg.Wait() + defer t.wg.Done() + defer close(reportCh) + + select { + case <-ctx.Done(): + return + case <-stop: + return + case reportCh <- &report{phase: async_task.RunPhaseFinished, ctx: taskCtx}: + } + }() + return reportCh +} + +func check(A, B [][]int) (int, int, bool) { + aH := len(A) + if aH == 0 { + return 0, 0, false + } + aW := len(A[0]) + if aW == 0 { + return 0, 0, false + } + for _, a := range A { + if len(a) != aW { + return 0, 0, false + } + } + bH := len(B) + if bH == 0 || bH != aW { + return 0, 0, false + } + bW := len(B[0]) + if bW == 0 { + return 0, 0, false + } + for _, b := range B { + if len(b) != bW { + return 0, 0, false + } + } + return aH, bW, true +} diff --git a/async/examples/task_mm_test.go b/async/examples/task_mm_test.go new file mode 100644 index 000000000..5243d9f60 --- /dev/null +++ b/async/examples/task_mm_test.go @@ -0,0 +1,139 @@ +package examples + +import ( + "context" + "testing" + "time" + + async "github.com/UnicomAI/wanwu/async" + "github.com/UnicomAI/wanwu/async/internal/tools" + "github.com/UnicomAI/wanwu/async/pkg/async/async_component/pending" + "github.com/UnicomAI/wanwu/async/pkg/async/async_task" +) + +func TestTaskMM_All(t *testing.T) { + TestTaskMM_Default(t) + TestTaskMM_WithPendingRunQueue(t) + TestTaskMM_WithPendingDelQueue(t) + TestTaskMM_WithPendingRunAndDelQueue(t) +} + +func TestTaskMM_Default(t *testing.T) { + TestTaskMM_Finished(t) + TestTaskMM_SysPauseAndRestart(t) +} + +func TestTaskMM_WithPendingRunQueue(t *testing.T) { + pendingRun, err := pending.NewPendingRun(getDB(dbMysql, dbName), tools.DefaultLog()) + if err != nil { + t.Fatal(err) + } + options = []async.AsyncOption{ + async.WithPendingRunQueue(pendingRun), + } + + TestTaskMM_Default(t) +} + +func TestTaskMM_WithPendingDelQueue(t *testing.T) { + pendingDel, err := pending.NewPendingDel(getDB(dbMysql, dbName), tools.DefaultLog()) + if err != nil { + t.Fatal(err) + } + options = []async.AsyncOption{ + async.WithPendingDelQueue(pendingDel), + } + + TestTaskMM_Default(t) +} + +func TestTaskMM_WithPendingRunAndDelQueue(t *testing.T) { + pendingRun, err := pending.NewPendingRun(getDB(dbMysql, dbName), tools.DefaultLog()) + if err != nil { + t.Fatal(err) + } + pendingDel, err := pending.NewPendingDel(getDB(dbMysql, dbName), tools.DefaultLog()) + if err != nil { + t.Fatal(err) + } + options = []async.AsyncOption{ + async.WithPendingRunQueue(pendingRun), + async.WithPendingDelQueue(pendingDel), + } + + TestTaskMM_Default(t) +} + +func TestTaskMM_Finished(t *testing.T) { + // init + if err := asyncInit(false, 30, options...); err != nil { + t.Fatal(err) + } + // create & run + taskCtx := "{\"A\":[[0,1,2,3,4],[5,6,7,8,9],[10,11,12,13,14],[15,16,17,18,19],[20,21,22,23,24]]," + + "\"B\":[[0,1,2],[3,4,5],[6,7,8],[9,10,11],[12,13,14]]}" + taskID, err := async.CreateTask(context.TODO(), "", "taskMM", taskTypeMM, taskCtx, true) + if err != nil { + t.Fatal(err) + } + checkTicker := time.NewTicker(123 * time.Millisecond) + defer checkTicker.Stop() + var stop bool + for { + select { + case <-checkTicker.C: + if task, err := async.GetTask(context.TODO(), taskID); err != nil { + t.Fatal(err) + } else if task.State == async_task.StateFinished { + t.Logf("%+v", task) + stop = true + } + } + if stop { + break + } + } + // stop + async.Stop() +} + +func TestTaskMM_SysPauseAndRestart(t *testing.T) { + // init + if err := asyncInit(false, 0, options...); err != nil { + t.Fatal(err) + } + // create & run + taskCtx := "{\"A\":[[0,1,2,3,4],[5,6,7,8,9],[10,11,12,13,14],[15,16,17,18,19],[20,21,22,23,24]]," + + "\"B\":[[0,1,2],[3,4,5],[6,7,8],[9,10,11],[12,13,14]]}" + taskID, err := async.CreateTask(context.TODO(), "", "taskMM", taskTypeMM, taskCtx, true) + if err != nil { + t.Fatal(err) + } + + stopTicker := time.NewTicker(time.Second * 10) + defer stopTicker.Stop() + checkTicker := time.NewTicker(time.Millisecond * 123) + defer checkTicker.Stop() + var stop bool + for { + select { + case <-stopTicker.C: + async.Stop() + if err := asyncInit(false, 0, options...); err != nil { + t.Fatal(err) + } + case <-checkTicker.C: + if task, err := async.GetTask(context.TODO(), taskID); err != nil { + t.Fatal(err) + } else if task.State == async_task.StateFinished { + t.Logf("%+v", task) + stop = true + } + } + if stop { + break + } + } + // stop + async.Stop() +} diff --git a/async/internal/async/config/config.go b/async/internal/async/config/config.go new file mode 100644 index 000000000..cf6621a06 --- /dev/null +++ b/async/internal/async/config/config.go @@ -0,0 +1,23 @@ +package config + +import ( + "github.com/UnicomAI/wanwu/async/pkg/async/async_component" + "github.com/UnicomAI/wanwu/async/pkg/async/async_config" +) + +type Config struct { + Log async_config.Logger + PendingRun async_component.IQueue + PendingDel async_component.IQueue + RunMaxConcurrency int + RunTaskInterval int // second +} + +const ( + TaskHeartbeatInterval int = 30 // second + RunCheckInterval int = 1 // second + DeleteMaxConcurrency int = 5 + DeleteTaskInterval int = 3 // second + CleanTimeout int = 120 // second + CleanInterval int = 60 // second +) diff --git a/async/internal/async/deleting/module.go b/async/internal/async/deleting/module.go new file mode 100644 index 000000000..a67aadcb3 --- /dev/null +++ b/async/internal/async/deleting/module.go @@ -0,0 +1,247 @@ +package deleting + +import ( + "context" + "errors" + "sync" + "time" + + "github.com/UnicomAI/wanwu/async/internal/async/config" + "github.com/UnicomAI/wanwu/async/internal/async/task" + "github.com/UnicomAI/wanwu/async/internal/db/client" + "github.com/UnicomAI/wanwu/async/internal/db/trans" + "github.com/UnicomAI/wanwu/async/internal/tools" + "github.com/UnicomAI/wanwu/async/pkg/async/async_config" + "github.com/UnicomAI/wanwu/async/pkg/async/async_task" +) + +// IModule goroutine safe +type IModule interface { + Run(ctx context.Context) error + Stop() + + Need() <-chan struct{} + DeleteTask(ctx context.Context, task *task.Task) +} + +type delMod struct { + log async_config.Logger + client client.ITaskClient + + maxConcurrency int + concurrency chan struct{} + + needInterval int // second + need chan struct{} + + heartbeatInterval int // second + + mutex sync.Mutex + tasks sync.Map // taskID -> task + stopped bool + stop chan struct{} + + wg sync.WaitGroup +} + +func NewModule(c client.ITaskClient, cfg config.Config) IModule { + return &delMod{ + log: cfg.Log, + client: c, + maxConcurrency: config.DeleteMaxConcurrency, + concurrency: make(chan struct{}, config.DeleteMaxConcurrency), + needInterval: config.DeleteTaskInterval, + need: make(chan struct{}, 1), + heartbeatInterval: config.TaskHeartbeatInterval, + stop: make(chan struct{}, 1), + } +} + +func (m *delMod) Run(ctx context.Context) error { + m.mutex.Lock() + if m.stopped { + defer m.mutex.Unlock() + return errors.New("async deleting module already stop") + } + m.wg.Add(1) + m.mutex.Unlock() + + go func() { + defer tools.PrintPanicStack() + defer m.wg.Done() + + m.log.Infof("async deleting module run") + needTicker := time.NewTicker(time.Duration(m.needInterval) * time.Second) + defer needTicker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-m.stop: + return + case <-needTicker.C: + if len(m.concurrency) >= m.maxConcurrency { + continue + } + select { + case m.need <- struct{}{}: + default: // do nothing + } + } + } + }() + return nil +} + +func (m *delMod) Stop() { + m.mutex.Lock() + // check stop + if m.stopped { + defer m.mutex.Unlock() + m.log.Errorf("async deleting module already stop") + return + } + // stop + m.stopped = true + m.stop <- struct{}{} + m.tasks.Range(func(_, t interface{}) bool { + if err := t.(*task.Task).SendStop(); err != nil { + m.log.Errorf("async deleting module send stop err: %v", err) + } + return true + }) + m.mutex.Unlock() + // wait + m.wg.Wait() + m.log.Infof("async deleting module stop") +} + +func (m *delMod) Need() <-chan struct{} { + return m.need +} + +func (m *delMod) DeleteTask(ctx context.Context, task *task.Task) { + m.mutex.Lock() + // check stop + if m.stopped { + defer m.mutex.Unlock() + m.log.Errorf("async deleting module delete task %v err: async deleting module already stop", task.ID()) + return + } + // check concurrency + select { + case m.concurrency <- struct{}{}: + default: + defer m.mutex.Unlock() + m.log.Errorf("async deleting module delete task %v err: max concurrency", task.ID()) + return + } + // add task + if _, ok := m.tasks.LoadOrStore(task.ID(), task); ok { + defer m.mutex.Unlock() + m.log.Errorf("async deleting module delete task %v err: already exist", task.ID()) + return + } + m.wg.Add(1) + m.mutex.Unlock() + + go func() { + defer tools.PrintPanicStack() + defer m.wg.Done() + defer func() { + m.mutex.Lock() + defer m.mutex.Unlock() + m.tasks.Delete(task.ID()) + }() + defer func() { + <-m.concurrency + }() + + phase := async_task.RunPhaseNormal + reportCh, err := task.Deleting(ctx) + if err != nil { + m.log.Errorf("async deleting module delete task %v err: %v", task.ID(), err) + return + } + m.log.Debugf("async task %v start deleting", task.ID()) + + ticker := time.NewTicker(time.Duration(m.heartbeatInterval) * time.Second) + defer ticker.Stop() + var stopped bool + for { + select { + case <-ticker.C: + if err := m.client.UpdateHeartbeat(ctx, task.ID()); err != nil { + m.log.Errorf("async task %v deleting heartbeat err: %v", task.ID(), err) + } else { + m.log.Debugf("async task %v deleting heartbeat", task.ID()) + } + case report, ok := <-reportCh: + // check stop + if !ok { + stopped = true + break + } + // check report + if report == nil { + m.log.Errorf("async task %v deleting report nil", task.ID()) + continue + } + currentPhase, needDelete := report.Phase() + eventCtx := report.Context() + m.log.Debugf("async task %v deleting report phase %v event ctx %v", task.ID(), currentPhase, eventCtx) + // update context + if err := m.client.UpdateContext(ctx, task.ID(), eventCtx); err != nil { + m.log.Errorf("async task %v deleting report phase %v event ctx %v err: %v", task.ID(), currentPhase, eventCtx, err) + } + // update state + if currentPhase == async_task.RunPhaseFinished || currentPhase == async_task.RunPhaseFailed { + phase = currentPhase + // delete + if needDelete { + if err := m.client.Delete(ctx, task.ID()); err != nil { + m.log.Errorf("async task %v deleting report phase %v delete err: %v", task.ID(), phase, eventCtx) + } else { + m.log.Debugf("async task %v deleting report phase %v delete", task.ID(), phase) + } + return + } + // update state + var event trans.TaskEvent + switch phase { + case async_task.RunPhaseFinished: + event = trans.EventTaskFinished + case async_task.RunPhaseFailed: + event = trans.EventTaskFailed + } + if err := m.client.TransStatus(ctx, task.ID(), event); err != nil { + m.log.Errorf("async task %v deleting report phase %v trans event %v err: %v", task.ID(), phase, event, err) + } + } + } + if stopped { + break + } + } + + switch phase { + case async_task.RunPhaseNormal: + if task.CheckStop() { + // sys pause + if err := m.client.TransStatus(ctx, task.ID(), trans.EventSysPause); err != nil { + m.log.Errorf("async task %v puase deleting phase normal err: %v", task.ID(), err) + } else { + m.log.Debugf("async task %v pause deleting phase normal", task.ID()) + } + } else { + // task panic, auto failed later + m.log.Errorf("async task %v stop deleting phase normal maybe panic", task.ID()) + } + case async_task.RunPhaseFinished: + m.log.Debugf("async task %v stop deleting phase finished", task.ID()) + case async_task.RunPhaseFailed: + m.log.Errorf("async task %v stop deleting phase failed", task.ID()) + default: + } + }() +} diff --git a/async/internal/async/fixing/cleaner.go b/async/internal/async/fixing/cleaner.go new file mode 100644 index 000000000..e549d46c0 --- /dev/null +++ b/async/internal/async/fixing/cleaner.go @@ -0,0 +1,96 @@ +package fixing + +import ( + "context" + "errors" + "sync" + "time" + + "github.com/UnicomAI/wanwu/async/internal/async/config" + "github.com/UnicomAI/wanwu/async/internal/db/client" + "github.com/UnicomAI/wanwu/async/internal/tools" + "github.com/UnicomAI/wanwu/async/pkg/async/async_config" +) + +// IClean goroutine safe +type IClean interface { + Run(ctx context.Context) error + Stop() +} + +type cleaner struct { + log async_config.Logger + client client.ITaskClient + + cleanTimeout int // second + + cleanInterval int // second + + mutex sync.Mutex + stopped bool + stop chan struct{} + + wg sync.WaitGroup +} + +func NewClean(c client.ITaskClient, cfg config.Config) IClean { + return &cleaner{ + log: cfg.Log, + client: c, + cleanTimeout: config.CleanTimeout, + cleanInterval: config.CleanInterval, + stop: make(chan struct{}, 1), + } +} + +func (c *cleaner) Run(ctx context.Context) error { + c.mutex.Lock() + if c.stopped { + defer c.mutex.Unlock() + return errors.New("async cleaner already stop") + } + c.wg.Add(1) + c.mutex.Unlock() + + go func() { + defer tools.PrintPanicStack() + defer c.wg.Done() + + c.log.Infof("async cleaner run") + // clean once when start + if err := c.client.Clean(ctx, c.cleanTimeout); err != nil { + c.log.Errorf("async cleaner err: %v", err) + } + ticker := time.NewTicker(time.Duration(c.cleanInterval) * time.Second) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-c.stop: + return + case <-ticker.C: + if err := c.client.Clean(ctx, c.cleanTimeout); err != nil { + c.log.Errorf("async cleaner err: %v", err) + } + } + } + }() + return nil +} + +func (c *cleaner) Stop() { + c.mutex.Lock() + // check stop + if c.stopped { + defer c.mutex.Unlock() + c.log.Errorf("async cleaner already stop") + return + } + // stop + c.stopped = true + c.stop <- struct{}{} + // wait + c.wg.Wait() + c.log.Infof("async cleaner stop") +} diff --git a/async/internal/async/mgr.go b/async/internal/async/mgr.go new file mode 100644 index 000000000..4c6749027 --- /dev/null +++ b/async/internal/async/mgr.go @@ -0,0 +1,238 @@ +package async + +import ( + "context" + "errors" + "fmt" + "sync" + + "gorm.io/gorm" + + "github.com/UnicomAI/wanwu/async/internal/async/config" + "github.com/UnicomAI/wanwu/async/internal/async/deleting" + "github.com/UnicomAI/wanwu/async/internal/async/fixing" + "github.com/UnicomAI/wanwu/async/internal/async/running" + "github.com/UnicomAI/wanwu/async/internal/async/task" + "github.com/UnicomAI/wanwu/async/internal/db/client" + "github.com/UnicomAI/wanwu/async/internal/db/model" + "github.com/UnicomAI/wanwu/async/internal/db/trans" + "github.com/UnicomAI/wanwu/async/internal/tools" + "github.com/UnicomAI/wanwu/async/pkg/async/async_config" + "github.com/UnicomAI/wanwu/async/pkg/async/async_task" +) + +// Mgr goroutine safe +type Mgr struct { + taskTypes sync.Map // 任务注册 taskType -> newTask func + + log async_config.Logger + client client.ITaskClient + + runningMod running.IModule + deletingMod deleting.IModule + cleaner fixing.IClean + + mutex sync.Mutex + stopped bool + stop chan struct{} + + wg sync.WaitGroup +} + +func NewMgr(db *gorm.DB, cfg config.Config) (*Mgr, error) { + c, err := client.NewClient(db, cfg) + if err != nil { + return nil, err + } + return &Mgr{ + log: cfg.Log, + client: c, + + runningMod: running.NewModule(c, cfg), + deletingMod: deleting.NewModule(c, cfg), + cleaner: fixing.NewClean(c, cfg), + + stop: make(chan struct{}, 1), + }, nil +} + +func (m *Mgr) Run(ctx context.Context) error { + m.mutex.Lock() + if m.stopped { + defer m.mutex.Unlock() + return ErrMgrAlreadyStop + } + m.wg.Add(1) + m.mutex.Unlock() + + go func() { + defer tools.PrintPanicStack() + defer m.wg.Done() + defer m.cleaner.Stop() + defer m.runningMod.Stop() + defer m.deletingMod.Stop() + + m.log.Infof("async mgr run") + // run components + if err := m.cleaner.Run(ctx); err != nil { + m.log.Errorf("async mgr run clear err: %v", err) + return + } + if err := m.runningMod.Run(ctx); err != nil { + m.log.Errorf("async mgr run running module err: %v", err) + return + } + if err := m.deletingMod.Run(ctx); err != nil { + m.log.Errorf("async mgr run deleting module err: %v", err) + return + } + + for { + select { + case <-ctx.Done(): + return + case <-m.stop: + return + case _, ok := <-m.runningMod.Need(): + if !ok { + m.log.Errorf("async mgr running module stop") + return + } + if dbTask, err := m.client.SelectOneRun(ctx, m.getTaskTypes()); err != nil { + m.log.Errorf("async mgr select pending run task err: %v", err) + } else if dbTask != nil { + if task, err := m.newTask(dbTask); err != nil { + m.log.Errorf("async mgr run task err: %v", err) + } else { + m.runningMod.RunTask(ctx, task) + } + } + case _, ok := <-m.deletingMod.Need(): + if !ok { + m.log.Errorf("async mgr deleting module stop") + return + } + if dbTask, err := m.client.SelectOneDelete(ctx, m.getTaskTypes()); err != nil { + m.log.Errorf("async mgr select pending del task err: %v", err) + } else if dbTask != nil { + if task, err := m.newTask(dbTask); err != nil { + m.log.Errorf("async mgr delete task err: %v", err) + } else { + m.deletingMod.DeleteTask(ctx, task) + } + + } + } + } + }() + return nil +} + +func (m *Mgr) Stop() { + m.mutex.Lock() + // check stop + if m.stopped { + defer m.mutex.Unlock() + m.log.Errorf(ErrMgrAlreadyStop.Error()) + return + } + // stop + m.stopped = true + m.stop <- struct{}{} + m.mutex.Unlock() + // wait + m.wg.Wait() + m.log.Infof("async mgr stop") +} + +func (m *Mgr) RegisterTask(taskTyp uint32, newTask async_task.ITaskFunc) error { + if m.checkStop() { + return ErrMgrAlreadyStop + } + if newTask == nil { + return fmt.Errorf("taskTyp %v newTask nil", taskTyp) + } + if _, ok := m.taskTypes.LoadOrStore(taskTyp, newTask); ok { + return fmt.Errorf("taskTyp %v already registered", taskTyp) + } + return nil +} + +func (m *Mgr) CreateTask(ctx context.Context, user, group string, taskTyp uint32, taskCtx string, autoRun bool) (uint32, error) { + if m.checkStop() { + return 0, ErrMgrAlreadyStop + } + if _, ok := m.taskTypes.Load(taskTyp); !ok { + return 0, fmt.Errorf("taskTyp %v not registered", taskTyp) + } + return m.client.CreateTask(ctx, user, group, taskTyp, taskCtx, autoRun) +} + +func (m *Mgr) ChangeTaskGroup(ctx context.Context, taskID uint32, group string) error { + if m.checkStop() { + return ErrMgrAlreadyStop + } + return m.client.ChangeTaskGroup(ctx, taskID, group) +} + +func (m *Mgr) UserRun(ctx context.Context, taskID uint32) error { + if m.checkStop() { + return ErrMgrAlreadyStop + } + return m.client.TransStatus(ctx, taskID, trans.EventUserRun) +} + +func (m *Mgr) UserDelete(ctx context.Context, taskID uint32) error { + if m.checkStop() { + return ErrMgrAlreadyStop + } + return m.client.TransStatus(ctx, taskID, trans.EventUserDelete) +} + +func (m *Mgr) UserPause(ctx context.Context, taskID uint32) error { + if m.checkStop() { + return ErrMgrAlreadyStop + } + return m.client.TransStatus(ctx, taskID, trans.EventUserPause) +} + +func (m *Mgr) GetTask(ctx context.Context, taskID uint32) (*model.AsyncTask, error) { + if m.checkStop() { + return nil, ErrMgrAlreadyStop + } + return m.client.GetTask(ctx, taskID) +} + +func (m *Mgr) GetTasks(ctx context.Context, user, group string, taskTypes []uint32, status []trans.TaskStatus, offset, limit int32) ([]*model.AsyncTask, error) { + if m.checkStop() { + return nil, ErrMgrAlreadyStop + } + return m.client.GetTasks(ctx, user, group, taskTypes, status, offset, limit) +} + +func (m *Mgr) checkStop() bool { + m.mutex.Lock() + defer m.mutex.Unlock() + return m.stopped +} + +func (m *Mgr) getTaskTypes() []uint32 { + var taskTypes []uint32 + m.taskTypes.Range(func(taskType, _ interface{}) bool { + taskTypes = append(taskTypes, taskType.(uint32)) + return true + }) + return taskTypes +} + +func (m *Mgr) newTask(dbTask *model.AsyncTask) (*task.Task, error) { + newTask, ok := m.taskTypes.Load(dbTask.Type) + if !ok { + return nil, fmt.Errorf("async task %v type %v not registered", dbTask.ID, dbTask.Type) + } + return task.NewTask(dbTask, newTask.(async_task.ITaskFunc)(), m.log), nil +} + +var ( + ErrMgrAlreadyStop = errors.New("async mgr already stop") +) diff --git a/async/internal/async/running/module.go b/async/internal/async/running/module.go new file mode 100644 index 000000000..80dc8322e --- /dev/null +++ b/async/internal/async/running/module.go @@ -0,0 +1,266 @@ +package running + +import ( + "context" + "errors" + "sync" + "time" + + "github.com/UnicomAI/wanwu/async/internal/async/config" + "github.com/UnicomAI/wanwu/async/internal/async/task" + "github.com/UnicomAI/wanwu/async/internal/db/client" + "github.com/UnicomAI/wanwu/async/internal/db/trans" + "github.com/UnicomAI/wanwu/async/internal/tools" + "github.com/UnicomAI/wanwu/async/pkg/async/async_config" + "github.com/UnicomAI/wanwu/async/pkg/async/async_task" +) + +// IModule goroutine safe +type IModule interface { + Run(ctx context.Context) error + Stop() + + Need() <-chan struct{} + RunTask(ctx context.Context, task *task.Task) +} + +type runMod struct { + log async_config.Logger + client client.ITaskClient + + maxConcurrency int + concurrency chan struct{} + + needInterval int // second + need chan struct{} + + checkInterval int // second + + heartbeatInterval int // second + + mutex sync.Mutex + tasks sync.Map // taskID -> task + stopped bool + stop chan struct{} + + wg sync.WaitGroup +} + +func NewModule(c client.ITaskClient, cfg config.Config) IModule { + return &runMod{ + log: cfg.Log, + client: c, + maxConcurrency: cfg.RunMaxConcurrency, + concurrency: make(chan struct{}, cfg.RunMaxConcurrency), + needInterval: cfg.RunTaskInterval, + need: make(chan struct{}, 1), + checkInterval: config.RunCheckInterval, + heartbeatInterval: config.TaskHeartbeatInterval, + stop: make(chan struct{}, 1), + } +} + +func (m *runMod) Run(ctx context.Context) error { + m.mutex.Lock() + if m.stopped { + defer m.mutex.Unlock() + return errors.New("async running module already stop") + } + m.wg.Add(1) + m.mutex.Unlock() + + go func() { + defer tools.PrintPanicStack() + defer m.wg.Done() + + m.log.Infof("async running module run") + needTicker := time.NewTicker(time.Duration(m.needInterval) * time.Second) + defer needTicker.Stop() + checkTicker := time.NewTicker(time.Duration(m.checkInterval) * time.Second) + defer checkTicker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-m.stop: + return + case <-needTicker.C: + if len(m.concurrency) >= m.maxConcurrency { + continue + } + select { + case m.need <- struct{}{}: + default: // do nothing + } + case <-checkTicker.C: + m.tasks.Range(func(taskID, t interface{}) bool { + if ok, err := m.client.CheckStop(ctx, taskID.(uint32)); err != nil { + m.log.Errorf("async running module check task %v stop err: %v", taskID.(uint32), err) + } else if ok { + if err := t.(*task.Task).SendStop(); err != nil { + m.log.Errorf("async running module send stop err: %v", err) + } + } + return true + }) + } + } + }() + return nil +} + +func (m *runMod) Stop() { + m.mutex.Lock() + // check stop + if m.stopped { + defer m.mutex.Unlock() + m.log.Errorf("async running module already stop") + return + } + // stop + m.stopped = true + m.stop <- struct{}{} + m.tasks.Range(func(_, t interface{}) bool { + if err := t.(*task.Task).SendStop(); err != nil { + m.log.Errorf("async running module send stop err: %v", err) + } + return true + }) + m.mutex.Unlock() + // wait + m.wg.Wait() + m.log.Infof("async running module stop") +} + +func (m *runMod) Need() <-chan struct{} { + return m.need +} + +func (m *runMod) RunTask(ctx context.Context, task *task.Task) { + m.mutex.Lock() + // check stop + if m.stopped { + defer m.mutex.Unlock() + m.log.Errorf("async running module run task %v err: async running module already stop", task.ID()) + return + } + // check concurrency + select { + case m.concurrency <- struct{}{}: + default: + defer m.mutex.Unlock() + m.log.Errorf("async running module run task %v err: max concurrency", task.ID()) + return + } + // add task + if _, ok := m.tasks.LoadOrStore(task.ID(), task); ok { + defer m.mutex.Unlock() + m.log.Errorf("async running module run task %v err: already exist", task.ID()) + return + } + m.wg.Add(1) + m.mutex.Unlock() + + go func() { + defer tools.PrintPanicStack() + defer m.wg.Done() + defer func() { + m.mutex.Lock() + defer m.mutex.Unlock() + m.tasks.Delete(task.ID()) + }() + defer func() { + <-m.concurrency + }() + + phase := async_task.RunPhaseNormal + reportCh, err := task.Running(ctx) + if err != nil { + m.log.Errorf("async running module run task %v err: %v", task.ID(), err) + return + } + m.log.Debugf("async task %v start running, initCtx: %v", task.ID(), task.InitCtx()) + + ticker := time.NewTicker(time.Duration(m.heartbeatInterval) * time.Second) + defer ticker.Stop() + var stopped bool + for { + select { + case <-ticker.C: + if err := m.client.UpdateHeartbeat(ctx, task.ID()); err != nil { + m.log.Errorf("async task %v running heartbeat err: %v", task.ID(), err) + } else { + m.log.Debugf("async task %v running heartbeat", task.ID()) + } + case report, ok := <-reportCh: + // check stop + if !ok { + stopped = true + break + } + // check report + if report == nil { + m.log.Errorf("async task %v running report nil", task.ID()) + continue + } + currentPhase, needDelete := report.Phase() + eventCtx := report.Context() + m.log.Debugf("async task %v running report phase %v event ctx %v", task.ID(), currentPhase, eventCtx) + // update context + if err := m.client.UpdateContext(ctx, task.ID(), eventCtx); err != nil { + m.log.Errorf("async task %v running report phase %v event ctx %v update err: %v", task.ID(), currentPhase, eventCtx, err) + } + // update state + if phase != async_task.RunPhaseNormal { + m.log.Errorf("async task %v running report phase %v err: current phase %v not normal", task.ID(), currentPhase, phase) + } + if currentPhase == async_task.RunPhaseFinished || currentPhase == async_task.RunPhaseFailed { + phase = currentPhase + // delete + if needDelete { + if err := m.client.Delete(ctx, task.ID()); err != nil { + m.log.Errorf("async task %v running report phase %v delete err: %v", task.ID(), phase, err) + } else { + m.log.Debugf("async task %v running report phase %v delete", task.ID(), phase) + } + return + } + // update state + var event trans.TaskEvent + switch phase { + case async_task.RunPhaseFinished: + event = trans.EventTaskFinished + case async_task.RunPhaseFailed: + event = trans.EventTaskFailed + } + if err := m.client.TransStatus(ctx, task.ID(), event); err != nil { + m.log.Errorf("async task %v running report phase %v trans event %v err: %v", task.ID(), phase, event, err) + } + } + } + if stopped { + break + } + } + + switch phase { + case async_task.RunPhaseNormal: + if task.CheckStop() { + // sys pause + if err := m.client.TransStatus(ctx, task.ID(), trans.EventSysPause); err != nil { + m.log.Errorf("async task %v puase running phase normal err: %v", task.ID(), err) + } else { + m.log.Debugf("async task %v pause running phase normal", task.ID()) + } + } else { + // task panic, auto failed later + m.log.Errorf("async task %v stop running phase normal maybe panic", task.ID()) + } + case async_task.RunPhaseFinished: + m.log.Debugf("async task %v stop running phase finished", task.ID()) + case async_task.RunPhaseFailed: + m.log.Errorf("async task %v stop running phase failed", task.ID()) + default: + } + }() +} diff --git a/async/internal/async/task/task.go b/async/internal/async/task/task.go new file mode 100644 index 000000000..541028e69 --- /dev/null +++ b/async/internal/async/task/task.go @@ -0,0 +1,74 @@ +package task + +import ( + "context" + "fmt" + "sync" + + "github.com/UnicomAI/wanwu/async/internal/db/model" + "github.com/UnicomAI/wanwu/async/pkg/async/async_config" + "github.com/UnicomAI/wanwu/async/pkg/async/async_task" +) + +// Task 运行中的任务代理 goroutine safe +type Task struct { + taskID uint32 + initCtx string + + task async_task.ITask + log async_config.Logger + + mutex sync.Mutex + stop chan struct{} + stopped bool +} + +func NewTask(dbTask *model.AsyncTask, asyncTask async_task.ITask, log async_config.Logger) *Task { + return &Task{ + taskID: dbTask.ID, + initCtx: string(dbTask.Ctx), + task: asyncTask, + log: log, + stop: make(chan struct{}, 1), + } +} + +func (t *Task) ID() uint32 { + return t.taskID +} + +func (t *Task) InitCtx() string { + return t.initCtx +} + +func (t *Task) Running(ctx context.Context) (<-chan async_task.IReport, error) { + if t.CheckStop() { + return nil, fmt.Errorf("async task %v already send stop event %v", t.taskID, t.stopped) + } + return t.task.Running(ctx, t.initCtx, t.stop), nil +} + +func (t *Task) Deleting(ctx context.Context) (<-chan async_task.IReport, error) { + if t.CheckStop() { + return nil, fmt.Errorf("async task %v already send stop", t.taskID) + } + return t.task.Deleting(ctx, t.initCtx, t.stop), nil +} + +func (t *Task) CheckStop() bool { + t.mutex.Lock() + defer t.mutex.Unlock() + return t.stopped +} + +func (t *Task) SendStop() error { + t.mutex.Lock() + defer t.mutex.Unlock() + if t.stopped { + return fmt.Errorf("async task %v already send stop", t.taskID) + } + t.stopped = true + t.stop <- struct{}{} + t.log.Debugf("async task %v send stop", t.taskID) + return nil +} diff --git a/async/internal/db/client/client.go b/async/internal/db/client/client.go new file mode 100644 index 000000000..080fe28fc --- /dev/null +++ b/async/internal/db/client/client.go @@ -0,0 +1,377 @@ +package client + +import ( + "context" + "fmt" + "time" + + "github.com/UnicomAI/wanwu/pkg/db" + "gorm.io/gorm" + "gorm.io/gorm/clause" + + "github.com/UnicomAI/wanwu/async/internal/async/config" + "github.com/UnicomAI/wanwu/async/internal/db/model" + "github.com/UnicomAI/wanwu/async/internal/db/trans" + "github.com/UnicomAI/wanwu/async/pkg/async/async_component" + "github.com/UnicomAI/wanwu/async/pkg/async/async_config" +) + +// ITaskClient goroutine safe +type ITaskClient interface { + CreateTask(ctx context.Context, user, group string, taskTyp uint32, taskCtx string, autoRun bool) (uint32, error) + GetTask(ctx context.Context, taskID uint32) (*model.AsyncTask, error) + GetTasks(ctx context.Context, user, group string, taskTypes []uint32, status []trans.TaskStatus, offset, limit int32) ([]*model.AsyncTask, error) + + ChangeTaskGroup(ctx context.Context, taskID uint32, group string) error + + SelectOneRun(ctx context.Context, taskTypes []uint32) (*model.AsyncTask, error) + SelectOneDelete(ctx context.Context, taskTypes []uint32) (*model.AsyncTask, error) + + TransStatus(ctx context.Context, taskID uint32, event trans.TaskEvent) error + + CheckStop(ctx context.Context, taskID uint32) (bool, error) + UpdateHeartbeat(ctx context.Context, taskID uint32) error + UpdateContext(ctx context.Context, taskID uint32, taskCtx string) error + Delete(ctx context.Context, taskID uint32) error + + Clean(ctx context.Context, timeout int) error +} + +type client struct { + log async_config.Logger + db *gorm.DB + + pendingRun async_component.IQueue + pendingDel async_component.IQueue +} + +func NewClient(db *gorm.DB, cfg config.Config) (ITaskClient, error) { + if err := db.AutoMigrate( + model.AsyncTask{}, + ); err != nil { + return nil, err + } + return &client{ + log: cfg.Log, + db: db, + pendingRun: cfg.PendingRun, + pendingDel: cfg.PendingDel, + }, nil +} + +func (c *client) CreateTask(ctx context.Context, user, group string, taskTyp uint32, taskCtx string, autoRun bool) (uint32, error) { + dbTask := &model.AsyncTask{ + User: user, + Group: group, + Type: taskTyp, + State: trans.TaskStateInit, + Mark: trans.TaskMarkNone, + Ctx: db.LongText(taskCtx), + } + if err := c.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + if err := tx.Create(dbTask).Error; err != nil { + return err + } + if autoRun { + return c.transStatus(ctx, tx, dbTask.ID, trans.EventUserRun) + } + return nil + }); err != nil { + return 0, err + } + return dbTask.ID, nil +} + +func (c *client) GetTask(ctx context.Context, taskID uint32) (*model.AsyncTask, error) { + dbTask := &model.AsyncTask{} + if err := c.db.WithContext(ctx).Where("id = ?", taskID).First(dbTask).Error; err != nil { + return nil, err + } + return dbTask, nil +} + +func (c *client) GetTasks(ctx context.Context, user, group string, taskTypes []uint32, status []trans.TaskStatus, offset, limit int32) ([]*model.AsyncTask, error) { + db := c.db.WithContext(ctx) + if user != "" { + db = db.Where("user = ?", user) + } + if group != "" { + db = db.Where("`group` = ?", group) + } + if len(taskTypes) > 0 { + db = db.Where("type IN ?", taskTypes) + } + if len(status) > 0 { + var query string + var args []interface{} + for i, s := range status { + if i == 0 { + query = query + "(state = ? AND mark = ?)" + } else { + query = query + " OR (state = ? AND mark = ?)" + } + args = append(args, s.S, s.M) + } + db = db.Where(query, args...) + } + if offset < 0 { + offset = 0 + } + if limit < 0 { + limit = -1 + } + var dbTasks []*model.AsyncTask + if err := db.Offset(int(offset)).Limit(int(limit)).Order("id desc").Find(&dbTasks).Error; err != nil { + return nil, err + } + return dbTasks, nil +} + +func (c *client) ChangeTaskGroup(ctx context.Context, taskID uint32, group string) error { + return c.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + dbTask := &model.AsyncTask{} + if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Where("id = ?", taskID).First(dbTask).Error; err != nil { + return err + } + if dbTask.Group == group { + return nil + } + var isPending bool // 用于标记是否需要记录错误日志;pending queue是外部组件,执行错误事务中断可能引起内外数据不一致,需要人工介入 + if dbTask.State == trans.TaskStatePending && dbTask.Mark == trans.TaskMarkRun { + isPending = true + if err := c.pendingRun.DelTask(ctx, taskID); err != nil { + c.log.Errorf("async task %v change group %v -> %v out pending run queue err: %v", taskID, dbTask.Group, group, err) + return err + } + if err := c.pendingRun.AddTask(ctx, dbTask.User, group, taskID, dbTask.Type); err != nil { + c.log.Errorf("async task %v change group %v -> %v in pending run queue err: %v", taskID, dbTask.Group, group, err) + return err + } + } + if dbTask.State == trans.TaskStatePending && dbTask.Mark == trans.TaskMarkDelete { + isPending = true + if err := c.pendingDel.DelTask(ctx, taskID); err != nil { + c.log.Errorf("async task %v change group %v -> %v out pending del queue err: %v", taskID, dbTask.Group, group, err) + return err + } + if err := c.pendingDel.AddTask(ctx, dbTask.User, group, taskID, dbTask.Type); err != nil { + c.log.Errorf("async task %v change group %v -> %v in pending del queue err: %v", taskID, dbTask.Group, group, err) + return err + } + } + if err := tx.Model(&model.AsyncTask{}).Where("id = ?", taskID).Updates(map[string]interface{}{ + "group": group, + }).Error; err != nil { + if isPending { + c.log.Errorf("async task %v change group %v -> %v err: %v", taskID, dbTask.Group, group, err) + } + return err + } + return nil + }) +} + +func (c *client) SelectOneRun(ctx context.Context, taskTypes []uint32) (*model.AsyncTask, error) { + var dbTask *model.AsyncTask + if err := c.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + var taskID uint32 + // 1. 优先查找 TaskStatus{S: TaskStatePause, M: TaskMarkRun} 的任务 + var dbTasks []*model.AsyncTask + if err := tx.Where("state = ? AND mark = ?", trans.TaskStatePause, trans.TaskMarkRun). + Where("type IN ?", taskTypes).Order("updated_at").Limit(1).Find(&dbTasks).Error; err != nil { + return err + } else if len(dbTasks) > 0 { + taskID = dbTasks[0].ID + } + // 2. 再查找pending.runQueue中的任务 + if taskID == 0 { + if pendingID, err := c.pendingRun.PopOne(ctx, taskTypes); err != nil { + return err + } else if pendingID == 0 { + return nil + } else { + taskID = pendingID + } + } + + if err := c.transStatus(ctx, tx, taskID, trans.EventSysExecute); err != nil { + return err + } + dbTask = &model.AsyncTask{} + return tx.Where("id = ?", taskID).First(dbTask).Error + }); err != nil { + return nil, err + } + return dbTask, nil +} + +func (c *client) SelectOneDelete(ctx context.Context, taskTypes []uint32) (*model.AsyncTask, error) { + var dbTask *model.AsyncTask + if err := c.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + var taskID uint32 + // 1. 优先查找 TaskStatus{S: TaskStatePause, M: TaskMarkDelete} 的任务 + var dbTasks []*model.AsyncTask + if err := tx.Where("state = ? AND mark = ?", trans.TaskStatePause, trans.TaskMarkDelete). + Where("type IN ?", taskTypes).Order("updated_at").Limit(1).Find(&dbTasks).Error; err != nil { + return err + } else if len(dbTasks) > 0 { + taskID = dbTasks[0].ID + } + // 2. 再查找pending.deleteQueue中的任务 + if taskID == 0 { + if pendingID, err := c.pendingDel.PopOne(ctx, taskTypes); err != nil { + return err + } else if pendingID == 0 { + return nil + } else { + taskID = pendingID + } + } + + if err := c.transStatus(ctx, tx, taskID, trans.EventSysExecute); err != nil { + return err + } + dbTask = &model.AsyncTask{} + return tx.Where("id = ?", taskID).First(dbTask).Error + }); err != nil { + return nil, err + } + return dbTask, nil +} + +func (c *client) TransStatus(ctx context.Context, taskID uint32, event trans.TaskEvent) error { + return c.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + return c.transStatus(ctx, tx, taskID, event) + }) +} + +func (c *client) CheckStop(ctx context.Context, taskID uint32) (bool, error) { + dbTask := &model.AsyncTask{} + if err := c.db.WithContext(ctx).Where("id = ?", taskID).First(dbTask).Error; err != nil { + return false, err + } + return dbTask.State == trans.TaskStateRunning && (dbTask.Mark == trans.TaskMarkDelete || dbTask.Mark == trans.TaskMarkPause), nil +} + +func (c *client) UpdateHeartbeat(ctx context.Context, taskID uint32) error { + return c.db.WithContext(ctx).Model(&model.AsyncTask{}).Where("id = ?", taskID).Updates(map[string]interface{}{ + "updated_at": time.Now().UnixMilli(), + }).Error +} + +func (c *client) UpdateContext(ctx context.Context, taskID uint32, taskCtx string) error { + return c.db.WithContext(ctx).Model(&model.AsyncTask{}).Where("id = ?", taskID).Updates(map[string]interface{}{ + "ctx": taskCtx, + }).Error +} + +func (c *client) Delete(ctx context.Context, taskID uint32) error { + return c.db.WithContext(ctx).Unscoped().Where("id = ?", taskID).Delete(&model.AsyncTask{}).Error +} + +func (c *client) Clean(ctx context.Context, timeout int) error { + return c.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + var dbTasks []*model.AsyncTask + if err := tx.Where("state IN ?", []trans.TaskState{trans.TaskStateRunning, trans.TaskStateDeleting}).Find(&dbTasks).Error; err != nil { + return err + } + if len(dbTasks) == 0 { + return nil + } + updateLimit := time.Now().Add(-time.Duration(timeout) * time.Second).UnixMilli() + for _, task := range dbTasks { + if task.UpdatedAt < updateLimit { + if err := tx.Model(&model.AsyncTask{}).Where("id = ?", task.ID).Updates(map[string]interface{}{ + "state": trans.TaskStateFailed, + }).Error; err != nil { + return err + } + if task.State == trans.TaskStateRunning { + c.log.Errorf("async task %v running but cleaned, last updated at %v", task.ID, task.UpdatedAt) + } else { + c.log.Errorf("async task %v deleting but cleaned, last updated at %v", task.ID, task.UpdatedAt) + } + } + } + return nil + }) +} + +func (c *client) transStatus(ctx context.Context, tx *gorm.DB, taskID uint32, event trans.TaskEvent) error { + dbTask := &model.AsyncTask{} + if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Where("id = ?", taskID).First(dbTask).Error; err != nil { + return err + } + // check transfer + state, del, err := trans.CheckTransfer(trans.TaskStatus{S: dbTask.State, M: dbTask.Mark}, event) + if err != nil { + return fmt.Errorf("async task %v %v", taskID, err.Error()) + } + // check out pending.runQueue + var isPending bool // 用于标记是否需要记录错误日志;pending queue是外部组件,执行错误事务中断可能引起内外数据不一致,需要人工介入 + if dbTask.State == trans.TaskStatePending && dbTask.Mark == trans.TaskMarkRun && event != trans.EventSysExecute { + if event != trans.EventSysExecute { // event是系统执行时,任务已经出队列了 + isPending = true + if err = c.pendingRun.DelTask(ctx, taskID); err != nil { + c.log.Errorf("async task %v {state %v, mark %v} event %v transfer {state %v, mark %v} out pending run queue err: %v", + taskID, dbTask.State, dbTask.Mark, event, state.S, state.M, err) + return err + } + } + } + // check in pending.runQueue + if state.S == trans.TaskStatePending && state.M == trans.TaskMarkRun { + isPending = true + if err = c.pendingRun.AddTask(ctx, dbTask.User, dbTask.Group, taskID, dbTask.Type); err != nil { + c.log.Errorf("async task %v {state %v, mark %v} event %v transfer {state %v, mark %v} in pending run queue err: %v", + taskID, dbTask.State, dbTask.Mark, event, state.S, state.M, err) + return err + } + } + // check out pending.delQueue + if dbTask.State == trans.TaskStatePending && dbTask.Mark == trans.TaskMarkDelete && event != trans.EventSysExecute { + if event != trans.EventSysExecute { // event是系统执行时,任务已经出队列了 + isPending = true + if err = c.pendingDel.DelTask(ctx, taskID); err != nil { + c.log.Errorf("async task %v {state %v, mark %v} event %v transfer {state %v, mark %v} out pending del queue err: %v", + taskID, dbTask.State, dbTask.Mark, event, state.S, state.M, err) + return err + } + } + } + // check in pending.delQueue + if state.S == trans.TaskStatePending && state.M == trans.TaskMarkDelete { + isPending = true + if err = c.pendingDel.AddTask(ctx, dbTask.User, dbTask.Group, taskID, dbTask.Type); err != nil { + c.log.Errorf("async task %v {state %v, mark %v} event %v transfer {state %v, mark %v} in pending del queue err: %v", + taskID, dbTask.State, dbTask.Mark, event, state.S, state.M, err) + return err + } + } + // del or update + if del { + if err = tx.Unscoped().Where("id = ?", taskID).Delete(&model.AsyncTask{}).Error; err != nil { + if isPending { + c.log.Errorf("async task %v {state %v, mark %v} event %v transfer {state %v, mark %v} del err: %v", + taskID, dbTask.State, dbTask.Mark, event, state.S, state.M, err) + } + return err + } + return nil + } + updates := map[string]interface{}{ + "state": state.S, + "mark": state.M, + } + if state.S == trans.TaskStateFinished || state.S == trans.TaskStateFailed { + updates["done_at"] = time.Now().UnixMilli() + } + if err = tx.Model(&model.AsyncTask{}).Where("id = ?", taskID).Updates(updates).Error; err != nil { + if isPending { + c.log.Errorf("async task %v {state %v, mark %v} event %v transfer {state %v, mark %v} update err: %v", + taskID, dbTask.State, dbTask.Mark, event, state.S, state.M, err) + } + return err + } + return nil +} diff --git a/async/internal/db/model/task.go b/async/internal/db/model/task.go new file mode 100644 index 000000000..3942c34d2 --- /dev/null +++ b/async/internal/db/model/task.go @@ -0,0 +1,28 @@ +package model + +import ( + "github.com/UnicomAI/wanwu/async/internal/db/trans" + "github.com/UnicomAI/wanwu/pkg/db" +) + +// AsyncTask 异步任务 DB Model +// target变化不会导致state变化 +type AsyncTask struct { + ID uint32 `gorm:"primary_key"` + CreatedAt int64 `gorm:"autoCreateTime:milli;not null"` + UpdatedAt int64 `gorm:"autoUpdateTime:milli;not null"` + // 用户 + User string `gorm:"index:idx_async_task_user"` + // 任务组 + Group string `gorm:"index:idx_async_task_group"` + // 任务类型 + Type uint32 `gorm:"index:idx_async_task_type;not null"` + // 状态 + State trans.TaskState `gorm:"index:idx_async_task_state;not null"` + // 标记 + Mark trans.TaskMark `gorm:"index:idx_async_task_mark;not null"` + // 结束时间戳(finished/failed) + DoneAt int64 `gorm:"not null"` + // 序列化上下文 + Ctx db.LongText `gorm:"not null"` +} diff --git a/async/internal/db/trans/trans.go b/async/internal/db/trans/trans.go new file mode 100644 index 000000000..9bec470bd --- /dev/null +++ b/async/internal/db/trans/trans.go @@ -0,0 +1,224 @@ +package trans + +import ( + "fmt" +) + +// TaskState 任务状态 +type TaskState int32 + +const ( + TaskStateInit TaskState = 0 // 已创建 + TaskStatePending TaskState = 1 // 排队中 + TaskStateRunning TaskState = 2 // 运行中 + TaskStateDeleting TaskState = 3 // 删除中 + TaskStatePause TaskState = 4 // 暂停 + TaskStateFinished TaskState = 5 // 结束 + TaskStateFailed TaskState = 6 // 失败 +) + +// TaskMark 任务标记,任务期望的操作 +type TaskMark int32 + +const ( + TaskMarkNone TaskMark = 0 // 无标记 + TaskMarkRun TaskMark = 1 // 标记需要开始 + TaskMarkDelete TaskMark = 2 // 标记需要删除 + TaskMarkPause TaskMark = 3 // 标记需要暂停 +) + +// TaskEvent 任务事件 +type TaskEvent int32 + +const ( + EventNone TaskEvent = 0 +) + +// 用户事件 +const ( + EventUserRun TaskEvent = 1 // 用户运行任务 + EventUserDelete TaskEvent = 2 // 用户删除任务 + EventUserPause TaskEvent = 3 // 用户暂停任务 +) + +// 系统事件 +const ( + EventSysExecute TaskEvent = 4 // 系统执行任务 + EventSysPause TaskEvent = 6 // 系统暂停任务 +) + +// 任务事件 +const ( + EventTaskFinished TaskEvent = 7 // 任务结束 + EventTaskFailed TaskEvent = 8 // 任务失败 +) + +// TaskStatus 任务{状态, 标记}组合 +type TaskStatus struct { + S TaskState + M TaskMark +} + +func CheckTransfer(taskStatus TaskStatus, event TaskEvent) (TaskStatus, bool, error) { + switch taskStatus { + + // 任务创建初始状态,不会排队、运行 + case TaskStatus{S: TaskStateInit, M: TaskMarkNone}: + switch event { + case EventUserRun: + // 进入pending.runQueue + return TaskStatus{S: TaskStatePending, M: TaskMarkRun}, false, nil + case EventUserDelete: + // 直接删除任务记录 + return TaskStatus{}, true, nil + case EventUserPause: + // 直接暂停 + return TaskStatus{S: TaskStatePause, M: TaskMarkPause}, false, nil + } + + // 任务在pending.runQueue中排队,用户可见为排队中 + case TaskStatus{S: TaskStatePending, M: TaskMarkRun}: + switch event { + case EventUserDelete: + // 离开pending.runQueue,进入pending.deleteQueue + return TaskStatus{S: TaskStatePending, M: TaskMarkDelete}, false, nil + case EventUserPause: + // 离开pending.runQueue,直接暂停 + return TaskStatus{S: TaskStatePause, M: TaskMarkPause}, false, nil + case EventSysExecute: + // 离开pending.runQueue,进入内存执行run + return TaskStatus{S: TaskStateRunning, M: TaskMarkRun}, false, nil + default: + } + + // 任务在pending.deleteQueue中排队,用户不可见 + case TaskStatus{S: TaskStatePending, M: TaskMarkDelete}: + switch event { + case EventSysExecute: + // 离开pending.deleteQueue,进入内存执行delete + return TaskStatus{S: TaskStateDeleting, M: TaskMarkDelete}, false, nil + default: + } + + // 任务在内存中执行run,用户可见为运行中 + case TaskStatus{S: TaskStateRunning, M: TaskMarkRun}: + switch event { + case EventUserDelete: + // 用户标记删除 + return TaskStatus{S: TaskStateRunning, M: TaskMarkDelete}, false, nil + case EventUserPause: + // 用户标记暂停 + return TaskStatus{S: TaskStateRunning, M: TaskMarkPause}, false, nil + case EventSysPause: + // 停止执行,离开内存 + return TaskStatus{S: TaskStatePause, M: TaskMarkRun}, false, nil + case EventTaskFinished: + // 执行结束,离开内存 + return TaskStatus{S: TaskStateFinished, M: TaskMarkRun}, false, nil + case EventTaskFailed: + // 执行失败,离开内存 + return TaskStatus{S: TaskStateFailed, M: TaskMarkRun}, false, nil + default: + } + + // 任务在内存中执行run,但用户标记删除,用户可见为取消中 + case TaskStatus{S: TaskStateRunning, M: TaskMarkDelete}: + switch event { + case EventSysPause, EventTaskFinished, EventTaskFailed: + // 停止执行/执行结束,离开内存,进入pending.deleteQueue + return TaskStatus{S: TaskStatePending, M: TaskMarkDelete}, false, nil + default: + } + + // 任务在内存中执行run,用户标记暂停,用户可见为取消中 + case TaskStatus{S: TaskStateRunning, M: TaskMarkPause}: + switch event { + case EventSysPause: + // 停止执行,离开内存 + return TaskStatus{S: TaskStatePause, M: TaskMarkPause}, false, nil + case EventTaskFinished: + // 执行结束,离开内存 + return TaskStatus{S: TaskStateFinished, M: TaskMarkRun}, false, nil + case EventTaskFailed: + // 执行失败,离开内存 + return TaskStatus{S: TaskStateFailed, M: TaskMarkRun}, false, nil + default: + } + + // 任务在内存中执行delete,用户不可见 + case TaskStatus{S: TaskStateDeleting, M: TaskMarkDelete}: + switch event { + case EventSysPause: + // 停止执行,离开内存 + return TaskStatus{S: TaskStatePause, M: TaskMarkDelete}, false, nil + case EventTaskFinished: + // 执行结束,直接删除任务记录 + return TaskStatus{}, true, nil + case EventTaskFailed: + // 执行结束,离开内存 + return TaskStatus{S: TaskStateFailed, M: TaskMarkDelete}, false, nil + } + + // 任务暂停run,但用户可见为运行中 + case TaskStatus{S: TaskStatePause, M: TaskMarkRun}: + switch event { + case EventUserPause: + // 直接暂停 + return TaskStatus{S: TaskStatePause, M: TaskMarkPause}, false, nil + case EventUserDelete: + // 进入pending.deleteQueue + return TaskStatus{S: TaskStatePending, M: TaskMarkDelete}, false, nil + case EventSysExecute: + // 进入内存执行run + return TaskStatus{S: TaskStateRunning, M: TaskMarkRun}, false, nil + default: + } + + // 任务暂停delete,用户不可见 + case TaskStatus{S: TaskStatePause, M: TaskMarkDelete}: + switch event { + case EventSysExecute: + // 进入内存执行delete + return TaskStatus{S: TaskStateDeleting, M: TaskMarkDelete}, false, nil + default: + } + + // 任务暂停run,用户可见为暂停中 + case TaskStatus{S: TaskStatePause, M: TaskMarkPause}: + switch event { + case EventUserRun: + // 进入pending.runQueue + return TaskStatus{S: TaskStatePending, M: TaskMarkRun}, false, nil + case EventUserDelete: + // 进入pending.deleteQueue + return TaskStatus{S: TaskStatePending, M: TaskMarkDelete}, false, nil + default: + } + + // 任务执行run结束,用户可见为结束 + case TaskStatus{S: TaskStateFinished, M: TaskMarkRun}: + switch event { + case EventUserDelete: + // 进入pending.deleteQueue + return TaskStatus{S: TaskStatePending, M: TaskMarkDelete}, false, nil + default: + } + + // 任务执行run失败,用户可见为失败 + case TaskStatus{S: TaskStateFailed, M: TaskMarkRun}: + switch event { + case EventUserRun: + // 失败后重启,进入pending.runQueue + return TaskStatus{S: TaskStatePending, M: TaskMarkRun}, false, nil + case EventUserDelete: + // 进入pending.deleteQueue + return TaskStatus{S: TaskStatePending, M: TaskMarkDelete}, false, nil + default: + } + + default: + } + + return TaskStatus{}, false, fmt.Errorf("{state %v, mark %v} event %v transfer invalid", + taskStatus.S, taskStatus.M, event) +} diff --git a/async/internal/tools/log.go b/async/internal/tools/log.go new file mode 100644 index 000000000..500371573 --- /dev/null +++ b/async/internal/tools/log.go @@ -0,0 +1,49 @@ +package tools + +import ( + "log" + + "github.com/UnicomAI/wanwu/async/pkg/async/async_config" +) + +func DefaultLog() async_config.Logger { + return &defaultLog{} +} + +func EmptyLog() async_config.Logger { return &emptyLog{} } + +// --- default log --- + +type defaultLog struct{} + +func (l *defaultLog) Debugf(fmt string, i ...interface{}) { + log.Printf("[ASYNC][DEBUG] "+fmt, i...) +} + +func (l *defaultLog) Infof(fmt string, i ...interface{}) { + log.Printf("[ASYNC][INFO] "+fmt, i...) +} + +func (l *defaultLog) Warnf(fmt string, i ...interface{}) { + log.Printf("[ASYNC][WARN] "+fmt, i...) +} + +func (l *defaultLog) Errorf(fmt string, i ...interface{}) { + log.Printf("[ASYNC][ERROR] "+fmt, i...) +} + +// --- empty log --- + +type emptyLog struct{} + +func (l *emptyLog) Debugf(fmt string, i ...interface{}) { +} + +func (l *emptyLog) Infof(fmt string, i ...interface{}) { +} + +func (l *emptyLog) Warnf(fmt string, i ...interface{}) { +} + +func (l *emptyLog) Errorf(fmt string, i ...interface{}) { +} diff --git a/async/internal/tools/recover.go b/async/internal/tools/recover.go new file mode 100644 index 000000000..c1f1369b9 --- /dev/null +++ b/async/internal/tools/recover.go @@ -0,0 +1,22 @@ +package tools + +import ( + "log" + "runtime" + "strings" +) + +var ( + panicLogLen = 2048 +) + +// PrintPanicStack recover并打印堆栈 +// 用法:defer tools.PrintPanicStack(),注意 defer func() { tools.PrintPanicStack() } 是无效的 +func PrintPanicStack() { + if r := recover(); r != nil { + buf := make([]byte, panicLogLen) + l := runtime.Stack(buf, false) + str := strings.ReplaceAll(string(buf[:l]), "\n", " ") + log.Printf("[PANIC] %v: %s", r, str) + } +} diff --git a/async/pkg/async/async_component/pending/del_default.go b/async/pkg/async/async_component/pending/del_default.go new file mode 100644 index 000000000..8bafffb48 --- /dev/null +++ b/async/pkg/async/async_component/pending/del_default.go @@ -0,0 +1,45 @@ +package pending + +import ( + "context" + + "gorm.io/gorm" + + "github.com/UnicomAI/wanwu/async/internal/db/model" + "github.com/UnicomAI/wanwu/async/internal/db/trans" + "github.com/UnicomAI/wanwu/async/pkg/async/async_component" + "github.com/UnicomAI/wanwu/async/pkg/async/async_config" +) + +func NewPendingDelDefault(db *gorm.DB, log async_config.Logger) async_component.IQueue { + return &pendingDelDefault{ + log: log, + db: db, + } +} + +type pendingDelDefault struct { + log async_config.Logger + db *gorm.DB +} + +func (d *pendingDelDefault) AddTask(ctx context.Context, user, group string, taskID, taskType uint32) error { + return nil +} + +func (d *pendingDelDefault) DelTask(ctx context.Context, taskID uint32) error { + return nil +} + +func (d *pendingDelDefault) PopOne(ctx context.Context, taskTypes []uint32) (uint32, error) { + var dbTasks []*model.AsyncTask + if err := d.db.WithContext(ctx).Where("state = ? AND mark = ?", trans.TaskStatePending, trans.TaskMarkDelete). + Where("type IN ?", taskTypes). + Order("updated_at").Limit(1).Find(&dbTasks).Error; err != nil { + return 0, err + } else if len(dbTasks) == 0 { + return 0, nil + } else { + return dbTasks[0].ID, nil + } +} diff --git a/async/pkg/async/async_component/pending/del_queue.go b/async/pkg/async/async_component/pending/del_queue.go new file mode 100644 index 000000000..c00e447b6 --- /dev/null +++ b/async/pkg/async/async_component/pending/del_queue.go @@ -0,0 +1,83 @@ +package pending + +import ( + "context" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + + "github.com/UnicomAI/wanwu/async/pkg/async/async_component" + "github.com/UnicomAI/wanwu/async/pkg/async/async_config" +) + +func NewPendingDel(db *gorm.DB, log async_config.Logger) (async_component.IQueue, error) { + if err := db.AutoMigrate( + asyncPendingDelTask{}, + ); err != nil { + return nil, err + } + return &pendingDel{ + log: log, + db: db, + }, nil +} + +type pendingDel struct { + log async_config.Logger + db *gorm.DB +} + +type asyncPendingDelTask struct { + ID uint32 `gorm:"primary_key"` + CreatedAt int64 `gorm:"autoCreateTime:milli;not null"` + UpdatedAt int64 `gorm:"autoUpdateTime:milli;not null"` + // 用户 + User string `gorm:"index:idx_async_pending_del_task_user"` + // 任务组 + Group string `gorm:"index:idx_async_pending_del_task_group"` + // 任务ID + TaskID uint32 `gorm:"index:idx_async_pending_del_task_task_id;not null"` + // 任务类型 + Type uint32 `gorm:"index:idx_async_pending_del_task_type;not null"` +} + +func (d *pendingDel) AddTask(ctx context.Context, user, group string, taskID, taskType uint32) error { + return d.db.WithContext(ctx).Create(&asyncPendingDelTask{ + User: user, + Group: group, + TaskID: taskID, + Type: taskType, + }).Error +} + +func (d *pendingDel) DelTask(ctx context.Context, taskID uint32) error { + return d.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + // select + dbTask := &asyncPendingDelTask{} + if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Where("task_id = ?", taskID).First(dbTask).Error; err != nil { + return err + } + // delete + return tx.Unscoped().Where("task_id = ?", dbTask.TaskID).Delete(&asyncPendingDelTask{}).Error + }) +} + +func (d *pendingDel) PopOne(ctx context.Context, taskTypes []uint32) (uint32, error) { + var taskID uint32 + if err := d.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + var dbTasks []*asyncPendingDelTask + if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}). + Where("type IN ?", taskTypes). + Order("id").Limit(1).Find(&dbTasks).Error; err != nil { + return err + } else if len(dbTasks) == 0 { + return nil + } else { + taskID = dbTasks[0].TaskID + } + return tx.Unscoped().Where("task_id = ?", taskID).Delete(&asyncPendingDelTask{}).Error + }); err != nil { + return 0, err + } + return taskID, nil +} diff --git a/async/pkg/async/async_component/pending/run_default.go b/async/pkg/async/async_component/pending/run_default.go new file mode 100644 index 000000000..d318c999c --- /dev/null +++ b/async/pkg/async/async_component/pending/run_default.go @@ -0,0 +1,45 @@ +package pending + +import ( + "context" + + "gorm.io/gorm" + + "github.com/UnicomAI/wanwu/async/internal/db/model" + "github.com/UnicomAI/wanwu/async/internal/db/trans" + "github.com/UnicomAI/wanwu/async/pkg/async/async_component" + "github.com/UnicomAI/wanwu/async/pkg/async/async_config" +) + +func NewPendingRunDefault(db *gorm.DB, log async_config.Logger) async_component.IQueue { + return &pendingRunDefault{ + log: log, + db: db, + } +} + +type pendingRunDefault struct { + log async_config.Logger + db *gorm.DB +} + +func (r *pendingRunDefault) AddTask(ctx context.Context, user, group string, taskID, taskType uint32) error { + return nil +} + +func (r *pendingRunDefault) DelTask(ctx context.Context, taskID uint32) error { + return nil +} + +func (r *pendingRunDefault) PopOne(ctx context.Context, taskTypes []uint32) (uint32, error) { + var dbTasks []*model.AsyncTask + if err := r.db.WithContext(ctx).Where("state = ? AND mark = ?", trans.TaskStatePending, trans.TaskMarkRun). + Where("type IN ?", taskTypes). + Order("updated_at").Limit(1).Find(&dbTasks).Error; err != nil { + return 0, err + } else if len(dbTasks) == 0 { + return 0, nil + } else { + return dbTasks[0].ID, nil + } +} diff --git a/async/pkg/async/async_component/pending/run_queue.go b/async/pkg/async/async_component/pending/run_queue.go new file mode 100644 index 000000000..16c7a73fe --- /dev/null +++ b/async/pkg/async/async_component/pending/run_queue.go @@ -0,0 +1,83 @@ +package pending + +import ( + "context" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + + "github.com/UnicomAI/wanwu/async/pkg/async/async_component" + "github.com/UnicomAI/wanwu/async/pkg/async/async_config" +) + +func NewPendingRun(db *gorm.DB, log async_config.Logger) (async_component.IQueue, error) { + if err := db.AutoMigrate( + asyncPendingRunTask{}, + ); err != nil { + return nil, err + } + return &pendingRun{ + log: log, + db: db, + }, nil +} + +type pendingRun struct { + log async_config.Logger + db *gorm.DB +} + +type asyncPendingRunTask struct { + ID uint32 `gorm:"primary_key"` + CreatedAt int64 `gorm:"autoCreateTime:milli;not null"` + UpdatedAt int64 `gorm:"autoUpdateTime:milli;not null"` + // 用户 + User string `gorm:"index:idx_async_pending_run_task_user"` + // 任务组 + Group string `gorm:"index:idx_async_pending_run_task_group"` + // 任务ID + TaskID uint32 `gorm:"index:idx_async_pending_run_task_task_id;not null"` + // 任务类型 + Type uint32 `gorm:"index:idx_async_pending_run_task_type;not null"` +} + +func (r *pendingRun) AddTask(ctx context.Context, user, group string, taskID, taskType uint32) error { + return r.db.WithContext(ctx).Create(&asyncPendingRunTask{ + User: user, + Group: group, + TaskID: taskID, + Type: taskType, + }).Error +} + +func (r *pendingRun) DelTask(ctx context.Context, taskID uint32) error { + return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + // select + dbTask := &asyncPendingRunTask{} + if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Where("task_id = ?", taskID).First(dbTask).Error; err != nil { + return err + } + // delete + return tx.Unscoped().Where("task_id = ?", dbTask.TaskID).Delete(&asyncPendingRunTask{}).Error + }) +} + +func (r *pendingRun) PopOne(ctx context.Context, taskTypes []uint32) (uint32, error) { + var taskID uint32 + if err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + var dbTasks []*asyncPendingRunTask + if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}). + Where("type IN ?", taskTypes). + Order("id").Limit(1).Find(&dbTasks).Error; err != nil { + return err + } else if len(dbTasks) == 0 { + return nil + } else { + taskID = dbTasks[0].TaskID + } + return tx.Unscoped().Where("task_id = ?", taskID).Delete(&asyncPendingRunTask{}).Error + }); err != nil { + return 0, err + } + return taskID, nil +} diff --git a/async/pkg/async/async_component/pending_queue.go b/async/pkg/async/async_component/pending_queue.go new file mode 100644 index 000000000..d063ada90 --- /dev/null +++ b/async/pkg/async/async_component/pending_queue.go @@ -0,0 +1,12 @@ +package async_component + +import "context" + +type IQueue interface { + // AddTask 向队列中添加一个任务 + AddTask(ctx context.Context, user, group string, taskID, taskType uint32) error + // DelTask 从队列中删除一个任务(直接删除,只有正确删除指定任务才不返回error) + DelTask(ctx context.Context, taskID uint32) error + // PopOne 从队列中获取一个任务(出队列) + PopOne(ctx context.Context, taskTypes []uint32) (uint32, error) +} diff --git a/async/pkg/async/async_config/log.go b/async/pkg/async/async_config/log.go new file mode 100644 index 000000000..97a1cc5f4 --- /dev/null +++ b/async/pkg/async/async_config/log.go @@ -0,0 +1,8 @@ +package async_config + +type Logger interface { + Debugf(fmt string, i ...interface{}) + Infof(fmt string, i ...interface{}) + Warnf(fmt string, i ...interface{}) + Errorf(fmt string, i ...interface{}) +} diff --git a/async/pkg/async/async_task/task.go b/async/pkg/async/async_task/task.go new file mode 100644 index 000000000..45e675efc --- /dev/null +++ b/async/pkg/async/async_task/task.go @@ -0,0 +1,57 @@ +package async_task + +import ( + "context" +) + +// RunPhase 任务运行阶段 +type RunPhase int32 + +const ( + RunPhaseFailed RunPhase = -1 // 失败 + RunPhaseNormal RunPhase = 1 // 正常 + RunPhaseFinished RunPhase = 2 // 完成 +) + +// IReport 任务上报interface,注意goroutine safe +type IReport interface { + // Phase 上报运行阶段,bool表示finished/failed是否删除任务记录 + Phase() (RunPhase, bool) + // Context 上报上下文 + Context() string +} + +// ITask 异步任务interface +type ITask interface { + // Running 任务执行,任务接收到stop消息后,须主动Report,之后退出即可 + Running(ctx context.Context, taskCtx string, stop <-chan struct{}) <-chan IReport + // Deleting 删除任务 + Deleting(ctx context.Context, taskCtx string, stop <-chan struct{}) <-chan IReport +} + +// ITaskFunc 异步任务初始化方法 +type ITaskFunc func() ITask + +// State 任务状态 +type State int32 + +const ( + StateInit State = 0 // 已创建 + StatePending State = 1 // 排队中 + StateRunning State = 2 // 运行中 + StateCanceling State = 3 // 取消中 + StatePause State = 4 // 暂停 + StateFinished State = 5 // 结束 + StateFailed State = 6 // 失败 +) + +type Task struct { + ID uint32 + User string + Group string + Type uint32 + State State + CreatedAt int64 + DoneAt int64 + Ctx string +} diff --git a/cmd/assistant-service/main.go b/cmd/assistant-service/main.go index 485314c87..20cec2b99 100644 --- a/cmd/assistant-service/main.go +++ b/cmd/assistant-service/main.go @@ -13,7 +13,6 @@ import ( "github.com/UnicomAI/wanwu/internal/assistant-service/config" "github.com/UnicomAI/wanwu/internal/assistant-service/server/grpc" "github.com/UnicomAI/wanwu/pkg/db" - "github.com/UnicomAI/wanwu/pkg/es" "github.com/UnicomAI/wanwu/pkg/log" "github.com/UnicomAI/wanwu/pkg/minio" mp "github.com/UnicomAI/wanwu/pkg/model-provider" @@ -56,14 +55,6 @@ func main() { log.Fatalf("init redis err: %v", err) } - if err := es.InitAssistant(ctx, config.Cfg().ES); err != nil { - log.Fatalf("init es err: %v", err) - } - - if err := es.InitESIndexTemplate(ctx); err != nil { - log.Fatalf("init es index template err: %v", err) - } - if err := minio.InitAssistant(ctx, minio.Config{ Endpoint: config.Cfg().Minio.EndPoint, User: config.Cfg().Minio.User, @@ -98,7 +89,6 @@ func main() { <-sc s.Stop() redis.StopSys() - es.StopAssistant() } func versionPrint() { diff --git a/configs/microservice/knowledge-service/configs/config.yaml b/configs/microservice/knowledge-service/configs/config.yaml index fadd5d528..3678930a9 100644 --- a/configs/microservice/knowledge-service/configs/config.yaml +++ b/configs/microservice/knowledge-service/configs/config.yaml @@ -83,16 +83,36 @@ minio: knowledge-export-dir: knowledge qa-export-dir: qa -kafka: - addr: kafka-wanwu:9092 - user: "admin" - password: +topic: topic: "doc-rag" knowledge-graph-topic: "rag-knowledge-graph" url-analysis-topic: "url-batch-a-prod" url-import-topic: "url-batch-i-prod" + +kafka: + enabled: false + addr: kafka-wanwu:9092 + user: "admin" + password: default-partition-num: 3 +redis: + enabled: false + mode: standalone + addr: + - redis-wanwu:6379 + password: + db: 0 + pool-size: 10 + min-idle-conns: 3 + max-retries: 3 + dial-timeout: 3 + read-timeout: 3 + write-timeout: 3 + idle-timeout: 10 + stream-max-len: 1000 + stream-approx-max-len: true + rag-server: endpoint: 'http://rag-wanwu:8681' proxy-point: 'http://bff-service:6668' diff --git a/docker-compose.yaml b/docker-compose.yaml index b35c7ae2d..94a503b06 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -26,21 +26,6 @@ services: retries: 99 start_period: 10s - mysql-setup: - depends_on: - mysql: - condition: service_healthy - restart: on-failure - image: ${WANWU_MYSQL_IMAGE} - container_name: ${WANWU_MYSQL_HOST}-setup - networks: - - ${WANWU_DOCKER_NETWORK} - volumes: - - ./configs/middleware/mysql/initdb.d:/docker-entrypoint-initdb.d - entrypoint: "bash -c" - command: - - "exec mysql --host mysql -u root -p'${WANWU_MYSQL_PASSWORD}' < /docker-entrypoint-initdb.d/init.sql" - redis: restart: always image: ${WANWU_REDIS_IMAGE} @@ -81,158 +66,6 @@ services: retries: 99 start_period: 10s - kafka: - restart: always - image: ${WANWU_KAFKA_IMAGE} - container_name: ${WANWU_KAFKA_HOST} - networks: - - ${WANWU_DOCKER_NETWORK} - # ports: - # - 9092:9092 - volumes: - - ./configs/middleware/kafka/configs/kafka-log4j.properties:/opt/bitnami/kafka/config/log4j.properties - - ${WANWU_PROJECT_DIR}/kafka/logs:/opt/bitnami/kafka/logs - - wanwu_kafka_data:/bitnami/kafka/data - environment: - # kraft - KAFKA_CFG_NODE_ID: 1 - KAFKA_CFG_PROCESS_ROLES: controller,broker - KAFKA_CFG_CONTROLLER_QUORUM_VOTERS: 1@${WANWU_KAFKA_HOST}:9093 - # listener - KAFKA_CFG_LISTENER_SECURITY_PROTOCOL_MAP: BROKER:SASL_PLAINTEXT,CONTROLLER:SASL_PLAINTEXT - KAFKA_CFG_LISTENERS: BROKER://0.0.0.0:9092,CONTROLLER://0.0.0.0:9093 - KAFKA_CFG_ADVERTISED_LISTENERS: BROKER://${WANWU_KAFKA_HOST}:9092,CONTROLLER://${WANWU_KAFKA_HOST}:9093 - KAFKA_CLIENT_USERS: ${WANWU_KAFKA_USER} - KAFKA_CLIENT_PASSWORD: ${WANWU_KAFKA_PASSWORD} - # broker - KAFKA_CFG_SASL_MECHANISM_INTER_BROKER_PROTOCOL: PLAIN - KAFKA_CFG_INTER_BROKER_LISTENER_NAME: BROKER - KAFKA_CFG_BROKER_ID: 1 - KAFKA_INTER_BROKER_USER: ${WANWU_KAFKA_USER} - KAFKA_INTER_BROKER_PASSWORD: ${WANWU_KAFKA_PASSWORD} - # controller - KAFKA_CFG_SASL_MECHANISM_CONTROLLER_PROTOCOL: PLAIN - KAFKA_CFG_CONTROLLER_LISTENER_NAMES: CONTROLLER - KAFKA_CONTROLLER_USER: ${WANWU_KAFKA_USER} - KAFKA_CONTROLLER_PASSWORD: ${WANWU_KAFKA_PASSWORD} - # opt - KAFKA_CFG_OFFSETS_TOPIC_REPLICATION_FACTOR: 1 - KAFKA_CFG_GROUP_INITIAL_REBALANCE_DELAY_MS: 0 - KAFKA_CFG_ALLOW_EVERYONE_IF_NO_ACL_FOUND: "false" - KAFKA_CFG_HEAP_OPTS: -Xmx1G -Xms1G - healthcheck: - test: ["CMD-SHELL", "timeout 1 bash -c 'cat < /dev/null > /dev/tcp/localhost/9092' || exit 1"] - interval: 15s - timeout: 5s - retries: 99 - start_period: 10s - - es-setup: - restart: on-failure - image: ${WANWU_ELASTIC_IMAGE} - container_name: ${WANWU_ELASTIC_HOST}-setup - networks: - - ${WANWU_DOCKER_NETWORK} - volumes: - - wanwu_es_certs:/usr/share/elasticsearch/config/certs - user: "0" - command: > - bash -c ' - if [ x${WANWU_ELASTIC_PASSWORD} == x ]; then - echo "Set the WANWU_ELASTIC_PASSWORD environment variable in the .env file"; - exit 1; - elif [ x${WANWU_KIBANA_PASSWORD} == x ]; then - echo "Set the WANWU_KIBANA_PASSWORD environment variable in the .env file"; - exit 1; - fi; - if [ ! -f config/certs/ca.zip ]; then - echo "Creating CA"; - bin/elasticsearch-certutil ca --silent --pem -out config/certs/ca.zip; - unzip config/certs/ca.zip -d config/certs; - fi; - if [ ! -f config/certs/certs.zip ]; then - echo "Creating certs"; - echo -ne \ - "instances:\n"\ - " - name: es\n"\ - " dns:\n"\ - " - es\n"\ - " - localhost\n"\ - " ip:\n"\ - " - 127.0.0.1\n"\ - > config/certs/instances.yml; - bin/elasticsearch-certutil cert --silent --pem -out config/certs/certs.zip --in config/certs/instances.yml --ca-cert config/certs/ca/ca.crt --ca-key config/certs/ca/ca.key; - unzip config/certs/certs.zip -d config/certs; - fi; - echo "Setting file permissions" - chown -R root:root config/certs; - find . -type d -exec chmod 750 \{\} \;; - find . -type f -exec chmod 640 \{\} \;; - echo "Waiting for Elasticsearch availability"; - until curl -s --cacert config/certs/ca/ca.crt https://es:9200 | grep -q "missing authentication credentials"; do sleep 30; done; - echo "Setting kibana user ${WANWU_KIBANA_USERNAME} password"; - until curl -s -X POST --cacert config/certs/ca/ca.crt -u "elastic:${WANWU_ELASTIC_PASSWORD}" -H "Content-Type: application/json" https://es:9200/_security/user/${WANWU_KIBANA_USERNAME}/_password -d "{\"password\":\"${WANWU_KIBANA_PASSWORD}\"}" | grep -q "^{}"; do sleep 10; done; - echo "All done!"; - ' - healthcheck: - test: ["CMD-SHELL", "[ -f config/certs/es/es.crt ]"] - interval: 15s - timeout: 5s - retries: 99 - start_period: 10s - - es: - restart: always - image: ${WANWU_ELASTIC_IMAGE} - container_name: ${WANWU_ELASTIC_HOST} - networks: - - ${WANWU_DOCKER_NETWORK} - # ports: - # - 9200:9200 - volumes: - - ./configs/middleware/elastic/plugins:/usr/share/elasticsearch/plugins - - wanwu_es_certs:/usr/share/elasticsearch/config/certs - - wanwu_es_data:/usr/share/elasticsearch/data - - wanwu_es_logs:/usr/share/elasticsearch/logs - - wanwu_es_tmp:/tmp - environment: - - node.name=es - - cluster.name=wanwu-es-cluster - - cluster.initial_master_nodes=es - - discovery.seed_hosts=es - - bootstrap.memory_lock=true - - xpack.security.enabled=true - - xpack.security.http.ssl.enabled=true - - xpack.security.http.ssl.key=certs/es/es.key - - xpack.security.http.ssl.certificate=certs/es/es.crt - - xpack.security.http.ssl.certificate_authorities=certs/ca/ca.crt - - xpack.security.transport.ssl.enabled=true - - xpack.security.transport.ssl.key=certs/es/es.key - - xpack.security.transport.ssl.certificate=certs/es/es.crt - - xpack.security.transport.ssl.certificate_authorities=certs/ca/ca.crt - - xpack.security.transport.ssl.verification_mode=certificate - - xpack.license.self_generated.type=basic - - ELASTIC_PASSWORD=${WANWU_ELASTIC_PASSWORD} - - "ES_JAVA_OPTS=-Xms2g -Xmx2g" - mem_limit: 4294967296 # 4GB - ulimits: - memlock: - soft: -1 - hard: -1 - nofile: - soft: 65536 - hard: 65536 - healthcheck: - test: - [ - "CMD-SHELL", - "curl -s --cacert config/certs/ca/ca.crt https://localhost:9200 | grep -q 'missing authentication credentials'", - ] - interval: 15s - timeout: 5s - retries: 99 - start_period: 15s - # ------ microservice ------ bff-service: @@ -406,8 +239,6 @@ services: condition: service_healthy minio: condition: service_healthy - kafka: - condition: service_healthy restart: always image: ${WANWU_BACKEND_IMAGE} container_name: knowledge-service @@ -427,9 +258,8 @@ services: DB_OCEANBASE_USER: ${WANWU_OCEAN_BASE_USER}@${WANWU_OCEAN_BASE_TENANT_NAME} DB_OCEANBASE_ADDRESS: ${WANWU_OCEAN_BASE_HOST}:${WANWU_OCEAN_BASE_PORT} DB_OCEANBASE_PASSWORD: ${WANWU_OCEAN_BASE_PASSWORD} - KAFKA_ADDR: ${WANWU_KAFKA_ADDRESS} - KAFKA_USER: ${WANWU_KAFKA_USER} - KAFKA_PASSWORD: ${WANWU_KAFKA_PASSWORD} + REDIS_ENABLED: ${WANWU_REDIS_ENABLED} + REDIS_PASSWORD: ${WANWU_REDIS_PASSWORD} MINIO_ENDPOINT: ${WANWU_MINIO_ENDPOINT} MINIO_USER: ${WANWU_MINIO_USER} MINIO_PASSWORD: ${WANWU_MINIO_PASSWORD} @@ -484,8 +314,6 @@ services: condition: service_healthy redis: condition: service_healthy - es: - condition: service_healthy restart: always image: ${WANWU_BACKEND_IMAGE} container_name: assistant-service @@ -511,9 +339,6 @@ services: MINIO_USER: ${WANWU_MINIO_USER} MINIO_PASSWORD: ${WANWU_MINIO_PASSWORD} MINIO_DOWNLOAD_URL: ${WANWU_WEB_BASE_URL}/minio/download/api/ - ES_ADDRESS: ${WANWU_ELASTIC_ADDRESS} - ES_USERNAME: ${WANWU_ELASTIC_USER} - ES_PASSWORD: ${WANWU_ELASTIC_PASSWORD} working_dir: /app command: ./bin/assistant-service healthcheck: @@ -716,114 +541,39 @@ services: working_dir: /app command: ./bin/workflow-wanwu - rag: + agent: depends_on: minio: condition: service_healthy - kafka: - condition: service_healthy - es: - condition: service_healthy restart: always - image: ${WANWU_RAG_IMAGE} - container_name: rag-wanwu + image: ${WANWU_AGENT_IMAGE} + container_name: agent-wanwu networks: - ${WANWU_DOCKER_NETWORK} # ports: - # - 8613:8613 - # - 8681:8681 - # - 10891:10891 - # - 15000:15000 + # - 7258:7258 + # - 1991:1991 + # - 1992:1992 + # - 15001:15001 + # - 15002:15002 + # - 15003:15003 volumes: - - ${WANWU_PROJECT_DIR}/rag-wanwu/rag_core/logs:/model_extend/rag_core/logs - - ${WANWU_PROJECT_DIR}/rag-wanwu/rag_core/graph/logs:/model_extend/rag_core/graph/logs - - ${WANWU_PROJECT_DIR}/rag-wanwu/rag_core/graph/data:/model_extend/rag_core/graph/data - - ${WANWU_PROJECT_DIR}/rag-wanwu/rag_es_server_unify/logs:/model_extend/rag_es_server_unify/logs + - ${WANWU_PROJECT_DIR}/agent-wanwu/agent_open_source/logs:/agent/agent_open_source/logs + - ${WANWU_PROJECT_DIR}/agent-wanwu/agent_open_source/minio/logs:/agent/agent_open_source/minio/logs + - ${WANWU_PROJECT_DIR}/agent-wanwu/agent_open_source/minio/file:/agent/agent_open_source/minio/file + - ${WANWU_PROJECT_DIR}/agent-wanwu/agent_open_source/agent_plugin/logs:/agent/agent_open_source/agent_plugin/logs environment: MINIO_ADDRESS: ${WANWU_MINIO_ENDPOINT} MINIO_ACCESS_KEY: ${WANWU_MINIO_USER} MINIO_SECRET_KEY: ${WANWU_MINIO_PASSWORD} - KAFKA_BOOTSTRAP_SERVERS: ${WANWU_KAFKA_ADDRESS} - KAFKA_SASL_PLAIN_USERNAME: ${WANWU_KAFKA_USER} - KAFKA_SASL_PLAIN_PASSWORD: ${WANWU_KAFKA_PASSWORD} - ES_HOSTS: https://${WANWU_ELASTIC_ADDRESS} - ES_USER: ${WANWU_ELASTIC_USER} - ES_PASSWORD: ${WANWU_ELASTIC_PASSWORD} - REDIS_HOST: ${WANWU_REDIS_HOST} - REDIS_PORT: ${WANWU_REDIS_PORT} - REDIS_PASSWD: ${WANWU_REDIS_PASSWORD} - # 用于rag调用外部API的env - GET_KB_ID_URL: http://bff-service:6668/v1/api/category/info - KAFKA_MQ_REL_URL: http://bff-service:6668/v1/api/docstatus - KAFKA_MQ_KB_STATUS_URL: http://bff-service:6668/v1/api/knowledge/status - KAFKA_DOC_STATUS_INIT_URL: http://bff-service:6668/v1/api/doc_status_init - MODEL_PROVIDER_URL: http://bff-service:6668 - REPLACE_MINIO_DOWNLOAD_URL: ${WANWU_WEB_BASE_URL}/minio/download/api/ - - # agent: - # depends_on: - # minio: - # condition: service_healthy - # restart: always - # image: ${WANWU_AGENT_IMAGE} - # container_name: agent-wanwu - # networks: - # - ${WANWU_DOCKER_NETWORK} - # # ports: - # # - 7258:7258 - # # - 1991:1991 - # # - 1992:1992 - # # - 15001:15001 - # # - 15002:15002 - # # - 15003:15003 - # volumes: - # - ${WANWU_PROJECT_DIR}/agent-wanwu/agent_open_source/logs:/agent/agent_open_source/logs - # - ${WANWU_PROJECT_DIR}/agent-wanwu/agent_open_source/minio/logs:/agent/agent_open_source/minio/logs - # - ${WANWU_PROJECT_DIR}/agent-wanwu/agent_open_source/minio/file:/agent/agent_open_source/minio/file - # - ${WANWU_PROJECT_DIR}/agent-wanwu/agent_open_source/agent_plugin/logs:/agent/agent_open_source/agent_plugin/logs - # environment: - # MINIO_ADDRESS: ${WANWU_MINIO_ENDPOINT} - # MINIO_ACCESS_KEY: ${WANWU_MINIO_USER} - # MINIO_SECRET_KEY: ${WANWU_MINIO_PASSWORD} - # # 用于agent并发处理 - # WORKERS: 4 - # THREADS: 8 - # # 用于agent调用外部接口的API前缀 - # URL_MODEL: http://bff-service - # URL_MINIO: http://bff-service - # URL_RAG: http://rag-wanwu - # URL_RAG_STREAM: http://bff-service:6668/callback/v1/rag/knowledge/stream/search - - # agentscope: - # depends_on: - # mysql: - # condition: service_healthy - # redis: - # condition: service_healthy - # restart: always - # image: ${WANWU_AGENTSCOPE_IMAGE} - # container_name: agentscope-wanwu - # networks: - # - ${WANWU_DOCKER_NETWORK} - # ports: - # - 6672:6672 - # volumes: - # - ./configs/microservice/agentscope/configs/config.yaml:/agentscope/src/agentscope/aibigmodel_workflow/config.yaml - # - ${WANWU_PROJECT_DIR}/agentscope/logs:/agentscope/logs - # environment: - # DB_NAME: ${WANWU_DB_NAME} - # DATABASE_HOST: ${WANWU_MYSQL_HOST} - # DATABASE_PORT: ${WANWU_MYSQL_PORT} - # DATABASE_PASSWORD: ${WANWU_MYSQL_PASSWORD} - # REDIS_HOST: ${WANWU_REDIS_HOST} - # REDIS_PORT: ${WANWU_REDIS_PORT} - # REDIS_PASSWORD: ${WANWU_REDIS_PASSWORD} - # healthcheck: - # test: ["CMD", "python", "-c", "import socket; s = socket.socket(socket.AF_INET, socket.SOCK_STREAM); s.settimeout(5); result = s.connect_ex(('localhost', 6672)); s.close(); exit(0 if result == 0 else 1)"] - # interval: 15s - # timeout: 5s - # retries: 99 - # start_period: 10s + # 用于agent并发处理 + WORKERS: 4 + THREADS: 8 + # 用于agent调用外部接口的API前缀 + URL_MODEL: http://bff-service + URL_MINIO: http://bff-service + URL_RAG: http://rag-wanwu + URL_RAG_STREAM: http://bff-service:6668/callback/v1/rag/knowledge/stream/search # ------ nginx & frontend ------ @@ -846,15 +596,9 @@ services: networks: wanwu-net: - external: true volumes: wanwu_mysql_data: wanwu_redis_data: wanwu_minio_data: - wanwu_kafka_data: - wanwu_es_data: - wanwu_es_certs: - wanwu_es_logs: - wanwu_es_tmp: - \ No newline at end of file + diff --git a/go.mod b/go.mod index 539f3926f..511ec9ee1 100644 --- a/go.mod +++ b/go.mod @@ -14,7 +14,6 @@ require ( github.com/dgrijalva/jwt-go v3.2.0+incompatible github.com/disintegration/imaging v1.6.2 github.com/eino-contrib/jsonschema v1.0.2 - github.com/elastic/go-elasticsearch/v8 v8.18.0 github.com/getkin/kin-openapi v0.118.0 github.com/gin-gonic/gin v1.10.1 github.com/go-ego/riot v0.0.0-20201013133145-f4c30acb3704 @@ -26,7 +25,6 @@ require ( github.com/gorilla/websocket v1.5.3 github.com/gromitlee/access v1.1.0 github.com/gromitlee/depend/v2 v2.0.0 - github.com/gromitlee/go-async v1.0.2 github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 github.com/hashicorp/go-version v1.8.0 github.com/mark3labs/mcp-go v0.43.0 diff --git a/internal/app-service/client/model/app_url.go b/internal/app-service/client/model/app_url.go index 92ffc0ea7..af12000cd 100644 --- a/internal/app-service/client/model/app_url.go +++ b/internal/app-service/client/model/app_url.go @@ -8,14 +8,14 @@ type AppUrl struct { CreatedAt int64 `gorm:"autoCreateTime:milli;comment:创建时间"` ExpiredAt int64 `gorm:"column:expired_at;comment:配置结束时间戳"` Copyright string `gorm:"column:copyright;type:text;comment:版权声明内容"` - CopyrightEnable bool `gorm:"column:copyright_enable;type:tinyint;comment:是否启用版权声明"` + CopyrightEnable bool `gorm:"column:copyright_enable;comment:是否启用版权声明"` PrivacyPolicy string `gorm:"column:privacy_policy;type:text;comment:隐私政策内容"` - PrivacyPolicyEnable bool `gorm:"column:privacy_policy_enable;type:tinyint;comment:是否启用隐私政策"` + PrivacyPolicyEnable bool `gorm:"column:privacy_policy_enable;comment:是否启用隐私政策"` Disclaimer string `gorm:"column:disclaimer;type:text;comment:免责声明内容"` - DisclaimerEnable bool `gorm:"column:disclaimer_enable;type:tinyint;comment:是否启用免责声明"` + DisclaimerEnable bool `gorm:"column:disclaimer_enable;comment:是否启用免责声明"` Suffix string `gorm:"column:suffix;type:varchar(255);comment:应用Url;index:idx_app_url_suffix"` UserId string `gorm:"column:user_id;index:idx_assistant_url_user_id;comment:用户Id;index:idx_app_url_user_id"` OrgId string `gorm:"column:org_id;index:idx_assistant_url_org_id;comment:组织Id;index:idx_app_url_org_id"` - Status bool `gorm:"column:status;type:tinyint;default:true;comment:应用Url开关;index:idx_app_url_status"` + Status bool `gorm:"column:status;default:true;comment:应用Url开关;index:idx_app_url_status"` Description string `gorm:"column:description;type:text;comment:app描述"` } diff --git a/internal/assistant-service/client/client.go b/internal/assistant-service/client/client.go index 26e77aa5b..8e996ef11 100644 --- a/internal/assistant-service/client/client.go +++ b/internal/assistant-service/client/client.go @@ -60,6 +60,10 @@ type IClient interface { GetConversationList(ctx context.Context, assistantID, userID, orgID string, offset, limit int32) ([]*model.Conversation, int64, *err_code.Status) DeleteConversationByAssistantID(ctx context.Context, assistantID, userID, orgID string) *err_code.Status + //============ConversationDetails============= + CreateConversationDetails(ctx context.Context, details *model.ConversationDetails) *err_code.Status + GetConversationDetailsList(ctx context.Context, conversationID, userID, orgID string, offset, limit int32) ([]*model.ConversationDetails, int64, *err_code.Status) + //================CustomPrompt================ CreateCustomPrompt(ctx context.Context, avatarPath, name, desc, prompt, userId, orgID string) (string, *err_code.Status) DeleteCustomPrompt(ctx context.Context, customPromptID uint32) *err_code.Status diff --git a/internal/assistant-service/client/model/assistant.go b/internal/assistant-service/client/model/assistant.go index 50448c8bd..93bb4818e 100644 --- a/internal/assistant-service/client/model/assistant.go +++ b/internal/assistant-service/client/model/assistant.go @@ -1,22 +1,24 @@ package model +import "github.com/UnicomAI/wanwu/pkg/db" + type Assistant struct { - ID uint32 `gorm:"primarykey;column:id;comment:智能体Id"` - UUID string `gorm:"column:uuid;type:varchar(255);uniqueIndex:idx_unique_uuid;comment:智能体uuid"` - AvatarPath string `gorm:"column:avatar_path;comment:智能体头像"` - Name string `gorm:"column:name;comment:智能体名称"` - Desc string `gorm:"column:desc;comment:智能体介绍"` - Instructions string `gorm:"column:instructions;comment:系统提示词"` - Prologue string `gorm:"column:prologue;comment:开场白"` - RecommendQuestion string `gorm:"column:recommend_question;comment:推荐问题列表"` - ModelConfig string `gorm:"column:model_config;type:longtext;comment:模型配置"` - RerankConfig string `gorm:"column:rerank_config;type:longtext;comment:rerank模型配置"` - KnowledgebaseConfig string `gorm:"column:knowledgebase_config;type:longtext;comment:知识库配置"` - SafetyConfig string `gorm:"column:safety_config;type:longtext;comment:安全配置"` - VisionConfig string `gorm:"column:vision_config;type:longtext;comment:视觉配置"` - Scope int `gorm:"column:scope;type:tinyint;comment:智能体可见范围"` - UserId string `gorm:"column:user_id;index:idx_assistant_user_id;comment:用户id"` - OrgId string `gorm:"column:org_id;index:idx_assistant_org_id;comment:组织id"` - CreatedAt int64 `gorm:"autoCreateTime:milli;comment:创建时间"` - UpdatedAt int64 `gorm:"autoUpdateTime:milli;comment:更新时间"` + ID uint32 `gorm:"primarykey;column:id;comment:智能体Id"` + UUID string `gorm:"column:uuid;type:varchar(255);uniqueIndex:idx_unique_uuid;comment:智能体uuid"` + AvatarPath string `gorm:"column:avatar_path;comment:智能体头像"` + Name string `gorm:"column:name;comment:智能体名称"` + Desc string `gorm:"column:desc;comment:智能体介绍"` + Instructions string `gorm:"column:instructions;comment:系统提示词"` + Prologue string `gorm:"column:prologue;comment:开场白"` + RecommendQuestion string `gorm:"column:recommend_question;comment:推荐问题列表"` + ModelConfig db.LongText `gorm:"column:model_config;comment:模型配置"` + RerankConfig db.LongText `gorm:"column:rerank_config;comment:rerank模型配置"` + KnowledgebaseConfig db.LongText `gorm:"column:knowledgebase_config;comment:知识库配置"` + SafetyConfig db.LongText `gorm:"column:safety_config;comment:安全配置"` + VisionConfig db.LongText `gorm:"column:vision_config;comment:视觉配置"` + Scope int `gorm:"column:scope;comment:智能体可见范围"` + UserId string `gorm:"column:user_id;index:idx_assistant_user_id;comment:用户id"` + OrgId string `gorm:"column:org_id;index:idx_assistant_org_id;comment:组织id"` + CreatedAt int64 `gorm:"autoCreateTime:milli;comment:创建时间"` + UpdatedAt int64 `gorm:"autoUpdateTime:milli;comment:更新时间"` } diff --git a/internal/assistant-service/client/model/conversation_details.go b/internal/assistant-service/client/model/conversation_details.go index aaed690da..457c04595 100644 --- a/internal/assistant-service/client/model/conversation_details.go +++ b/internal/assistant-service/client/model/conversation_details.go @@ -1,26 +1,26 @@ package model type FileInfo struct { - FileName string `json:"fileName"` - FileSize int64 `json:"fileSize"` - FileUrl string `json:"fileUrl"` + FileName string `gorm:"type:varchar(500)" json:"fileName"` + FileSize int64 `gorm:"type:bigint" json:"fileSize"` + FileUrl string `gorm:"type:text" json:"fileUrl"` } type ConversationDetails struct { - Id string `json:"id"` - AssistantId string `json:"assistantId"` - ConversationId string `json:"conversationId"` - Prompt string `json:"prompt"` - SysPrompt string `json:"sysPrompt"` - Response string `json:"response"` - SearchList string `json:"searchList"` - QaType int32 `json:"qaType"` - FileUrl string `json:"requestFileUrls"` - FileSize int64 `json:"fileSize"` - FileName string `json:"fileName"` - FileInfo []FileInfo `json:"fileInfo"` - UserId string `json:"userId"` - OrgId string `json:"orgId"` - CreatedAt int64 `json:"createdAt"` - UpdatedAt int64 `json:"updatedAt"` + ID uint32 `gorm:"primarykey;column:id" json:"id"` + AssistantId uint32 `gorm:"column:assistant_id;comment:智能体id" json:"assistantId"` + ConversationId string `gorm:"type:varchar(255);index" json:"conversationId"` + Prompt string `gorm:"type:text" json:"prompt"` + SysPrompt string `gorm:"type:text" json:"sysPrompt"` + Response string `gorm:"type:text" json:"response"` + SearchList string `gorm:"type:text" json:"searchList"` + QaType int32 `gorm:"type:int" json:"qaType"` + FileUrl string `gorm:"type:text" json:"requestFileUrls"` + FileSize int64 `gorm:"type:bigint" json:"fileSize"` + FileName string `gorm:"type:varchar(500)" json:"fileName"` + FileInfo []FileInfo `gorm:"type:json" json:"fileInfo"` + UserId string `gorm:"type:varchar(255);index" json:"userId"` + OrgId string `gorm:"type:varchar(255);index" json:"orgId"` + CreatedAt int64 `gorm:"autoCreateTime:milli" json:"createdAt"` + UpdatedAt int64 `gorm:"autoUpdateTime:milli" json:"updatedAt"` } diff --git a/internal/assistant-service/client/orm/client.go b/internal/assistant-service/client/orm/client.go index 69a15fcc0..a8352ffeb 100644 --- a/internal/assistant-service/client/orm/client.go +++ b/internal/assistant-service/client/orm/client.go @@ -24,6 +24,7 @@ func NewClient(db *gorm.DB) (*Client, error) { model.AssistantMCP{}, model.AssistantTool{}, model.CustomPrompt{}, + model.ConversationDetails{}, model.AssistantSnapshot{}, ); err != nil { return nil, err diff --git a/internal/assistant-service/client/orm/conversation_details.go b/internal/assistant-service/client/orm/conversation_details.go new file mode 100644 index 000000000..9dd57e391 --- /dev/null +++ b/internal/assistant-service/client/orm/conversation_details.go @@ -0,0 +1,44 @@ +package orm + +import ( + "context" + + err_code "github.com/UnicomAI/wanwu/api/proto/err-code" + "github.com/UnicomAI/wanwu/internal/assistant-service/client/model" + "github.com/UnicomAI/wanwu/internal/assistant-service/client/orm/sqlopt" + "gorm.io/gorm" +) + +func (c *Client) CreateConversationDetails(ctx context.Context, details *model.ConversationDetails) *err_code.Status { + if details.ID != 0 { + return toErrStatus("assistant_conversation_details_create", "create conversation details but id not 0") + } + return c.transaction(ctx, func(tx *gorm.DB) *err_code.Status { + if err := tx.Create(details).Error; err != nil { + return toErrStatus("assistant_conversation_details_create", err.Error()) + } + return nil + }) +} + +func (c *Client) GetConversationDetailsList(ctx context.Context, conversationID, userID, orgID string, offset, limit int32) ([]*model.ConversationDetails, int64, *err_code.Status) { + var conversations []*model.ConversationDetails + var count int64 + return conversations, count, c.transaction(ctx, func(tx *gorm.DB) *err_code.Status { + query := sqlopt.DataPerm(userID, orgID).Apply(tx.Model(&model.ConversationDetails{})) + + if conversationID != "" { + query = query.Where("conversation_id = ?", conversationID) + } + + if err := query.Count(&count).Error; err != nil { + return toErrStatus("assistant_conversations_get_list", err.Error()) + } + + if err := query.Offset(int(offset)).Limit(int(limit)).Order("created_at DESC").Find(&conversations).Error; err != nil { + return toErrStatus("assistant_conversations_get_list", err.Error()) + } + + return nil + }) +} diff --git a/internal/assistant-service/config/config.go b/internal/assistant-service/config/config.go index d528ba325..ff807a17c 100644 --- a/internal/assistant-service/config/config.go +++ b/internal/assistant-service/config/config.go @@ -2,7 +2,6 @@ package config import ( "github.com/UnicomAI/wanwu/pkg/db" - "github.com/UnicomAI/wanwu/pkg/es" "github.com/UnicomAI/wanwu/pkg/log" "github.com/UnicomAI/wanwu/pkg/redis" "github.com/UnicomAI/wanwu/pkg/util" @@ -17,7 +16,6 @@ type Config struct { Log LogConfig `json:"log" mapstructure:"log"` DB db.Config `json:"db" mapstructure:"db"` Redis redis.Config `json:"redis" mapstructure:"redis"` - ES es.Config `json:"es" mapstructure:"es"` Assistant Assistant `json:"assistant" mapstructure:"assistant"` Minio *MinioConfig `mapstructure:"minio" json:"minio"` Knowledge Knowledge `mapstructure:"knowledge" json:"knowledge" yaml:"knowledge"` diff --git a/internal/assistant-service/server/grpc/assistant/assistant.go b/internal/assistant-service/server/grpc/assistant/assistant.go index 49db7e2e1..2b7a49f58 100644 --- a/internal/assistant-service/server/grpc/assistant/assistant.go +++ b/internal/assistant-service/server/grpc/assistant/assistant.go @@ -10,6 +10,7 @@ import ( errs "github.com/UnicomAI/wanwu/api/proto/err-code" "github.com/UnicomAI/wanwu/internal/assistant-service/client/model" "github.com/UnicomAI/wanwu/internal/assistant-service/config" + "github.com/UnicomAI/wanwu/pkg/db" "github.com/UnicomAI/wanwu/pkg/log" "github.com/UnicomAI/wanwu/pkg/util" "google.golang.org/protobuf/types/known/emptypb" @@ -142,7 +143,7 @@ func (s *Service) AssistantConfigUpdate(ctx context.Context, req *assistant_serv Args: []string{err.Error()}, }) } - existingAssistant.ModelConfig = string(modelConfigBytes) + existingAssistant.ModelConfig = db.LongText(modelConfigBytes) } // 处理rerankConfig,转换成json字符串之后再更新 @@ -162,7 +163,7 @@ func (s *Service) AssistantConfigUpdate(ctx context.Context, req *assistant_serv Args: []string{err.Error()}, }) } - existingAssistant.RerankConfig = string(rerankConfigBytes) + existingAssistant.RerankConfig = db.LongText(rerankConfigBytes) } } @@ -175,7 +176,7 @@ func (s *Service) AssistantConfigUpdate(ctx context.Context, req *assistant_serv Args: []string{err.Error()}, }) } - existingAssistant.KnowledgebaseConfig = string(knowledgeBaseConfigBytes) + existingAssistant.KnowledgebaseConfig = db.LongText(knowledgeBaseConfigBytes) log.Debugf("knowConfig = %s", existingAssistant.KnowledgebaseConfig) } @@ -188,7 +189,7 @@ func (s *Service) AssistantConfigUpdate(ctx context.Context, req *assistant_serv Args: []string{err.Error()}, }) } - existingAssistant.SafetyConfig = string(safetyConfigBytes) + existingAssistant.SafetyConfig = db.LongText(safetyConfigBytes) } // 处理visionConfig,转换成json字符串之后再更新 @@ -200,7 +201,7 @@ func (s *Service) AssistantConfigUpdate(ctx context.Context, req *assistant_serv Args: []string{err.Error()}, }) } - existingAssistant.VisionConfig = string(visionConfigBytes) + existingAssistant.VisionConfig = db.LongText(visionConfigBytes) } // 调用client方法更新智能体 diff --git a/internal/assistant-service/server/grpc/assistant/conversation.go b/internal/assistant-service/server/grpc/assistant/conversation.go index f74313838..78d05cc0d 100644 --- a/internal/assistant-service/server/grpc/assistant/conversation.go +++ b/internal/assistant-service/server/grpc/assistant/conversation.go @@ -26,7 +26,6 @@ import ( "github.com/UnicomAI/wanwu/internal/assistant-service/client/model" "github.com/UnicomAI/wanwu/internal/assistant-service/config" "github.com/UnicomAI/wanwu/pkg/constant" - "github.com/UnicomAI/wanwu/pkg/es" grpc_util "github.com/UnicomAI/wanwu/pkg/grpc-util" http_client "github.com/UnicomAI/wanwu/pkg/http-client" "github.com/UnicomAI/wanwu/pkg/log" @@ -119,37 +118,20 @@ func (s *Service) GetConversationList(ctx context.Context, req *assistant_servic func (s *Service) GetConversationDetailList(ctx context.Context, req *assistant_service.GetConversationDetailListReq) (*assistant_service.GetConversationDetailListResp, error) { // 计算分页参数 from := (req.PageNo - 1) * req.PageSize - size := int(req.PageSize) - // 组装查询条件 - fieldConditions := map[string]interface{}{ - "conversationId": req.ConversationId, - "userId.keyword": req.Identity.UserId, - "orgId.keyword": req.Identity.OrgId, - } - - // 使用通配符查询所有对话详情索引 - indexPattern := "conversation_detail_infos_*" - - // 从ES查询数据 - documents, total, err := es.Assistant().SearchByFields(ctx, indexPattern, fieldConditions, int(from), size) + // 从数据库查询数据 + details, total, err := s.cli.GetConversationDetailsList(ctx, req.ConversationId, req.Identity.UserId, req.Identity.OrgId, from, req.PageSize) if err != nil { - log.Errorf("从ES查询对话详情失败,conversationId: %s, userId: %s, error: %v", req.ConversationId, req.Identity.UserId, err) + log.Errorf("从数据库查询对话详情失败,conversationId: %s, userId: %s, error: %v", req.ConversationId, req.Identity.UserId, err) return nil, fmt.Errorf("查询对话详情失败: %v", err) } // 转换查询结果为响应格式 var conversationDetails []*assistant_service.ConversionDetailInfo - for _, doc := range documents { - var detail model.ConversationDetails - if err := json.Unmarshal(doc, &detail); err != nil { - log.Warnf("解析ES文档失败: %v", err) - continue - } - + for _, detail := range details { conversationDetails = append(conversationDetails, &assistant_service.ConversionDetailInfo{ - Id: detail.Id, - AssistantId: detail.AssistantId, + Id: strconv.Itoa(int(detail.ID)), + AssistantId: strconv.Itoa(int(detail.AssistantId)), ConversationId: detail.ConversationId, Prompt: detail.Prompt, SysPrompt: detail.SysPrompt, @@ -165,7 +147,7 @@ func (s *Service) GetConversationDetailList(ctx context.Context, req *assistant_ }) } - log.Infof("成功从ES查询对话详情,conversationId: %s, userId: %s, 总数: %d, 返回: %d", + log.Infof("成功从数据库查询对话详情,conversationId: %s, userId: %s, 总数: %d, 返回: %d", req.ConversationId, req.Identity.UserId, total, len(conversationDetails)) return &assistant_service.GetConversationDetailListResp{ @@ -207,7 +189,7 @@ func (s *Service) AssistantConversionStream(req *assistant_service.AssistantConv terminationMessage = fullResponse.String() + "\n本次回答已被终止" } - saveConversation(ctx, req, terminationMessage, searchList) + s.saveConversation(ctx, req, terminationMessage, searchList) log.Infof("因上下文取消保存终止消息,assistantId: %s, conversationId: %s", req.AssistantId, req.ConversationId) } }() @@ -227,7 +209,7 @@ func (s *Service) AssistantConversionStream(req *assistant_service.AssistantConv if status != nil { log.Errorf("Assistant服务获取智能体信息失败,assistantId: %s, error: %v", req.AssistantId, status) SSEError(stream, "智能体信息获取失败") - saveConversation(ctx, req, "智能体信息获取失败", "") + s.saveConversation(ctx, req, "智能体信息获取失败", "") return errStatus(errs.Code_AssistantConversationErr, status) } } else { @@ -235,13 +217,13 @@ func (s *Service) AssistantConversionStream(req *assistant_service.AssistantConv if status != nil { log.Errorf("Assistant服务获取智能体快照失败,assistantId: %s, error: %v", req.AssistantId, status) SSEError(stream, "智能体快照获取失败") - saveConversation(ctx, req, "智能体快照获取失败", "") + s.saveConversation(ctx, req, "智能体快照获取失败", "") return errStatus(errs.Code_AssistantConversationErr, status) } if err := jsonToStruct(assistantSnapshot.AssistantInfo, &assistant); err != nil { SSEError(stream, "智能体信息获取失败") - saveConversation(ctx, req, "智能体信息获取失败", "") + s.saveConversation(ctx, req, "智能体信息获取失败", "") return errStatus(errs.Code_AssistantErr, toErrStatus("assistant_snapshot", err.Error())) } } @@ -254,7 +236,7 @@ func (s *Service) AssistantConversionStream(req *assistant_service.AssistantConv if assistantConfig.SseUrl == "" { log.Errorf("Assistant服务SSE URL配置为空,assistantId: %s", req.AssistantId) SSEError(stream, "智能体SSE URL配置错误") - saveConversation(ctx, req, "智能体SSE URL配置错误", "") + s.saveConversation(ctx, req, "智能体SSE URL配置错误", "") return grpc_util.ErrorStatusWithKey(errs.Code_AssistantConversationErr, "assistant_conversation", "SSE URL配置错误") } @@ -278,28 +260,28 @@ func (s *Service) AssistantConversionStream(req *assistant_service.AssistantConv _, err = s.setModelConfigParams(sseReq, assistant) if err != nil { SSEError(stream, "智能体模型配置解析失败") - saveConversation(ctx, req, "智能体模型配置解析失败", "") + s.saveConversation(ctx, req, "智能体模型配置解析失败", "") return grpc_util.ErrorStatusWithKey(errs.Code_AssistantConversationErr, "assistant_conversation", "模型配置解析失败") } // 知识库参数配置 if err = s.setKnowledgebaseParams(ctx, sseReq, req, assistant); err != nil { SSEError(stream, "智能体知识库配置解析失败") - saveConversation(ctx, req, "智能体知识库配置解析失败", "") + s.saveConversation(ctx, req, "智能体知识库配置解析失败", "") return grpc_util.ErrorStatusWithKey(errs.Code_AssistantConversationErr, "assistant_conversation", "知识库配置解析失败") } // plugin参数配置 if err := s.setToolAndWorkflowParams(ctx, sseReq, req.AssistantId, req.Identity, req.Draft, assistantSnapshot); err != nil { SSEError(stream, "智能体plugin配置错误") - saveConversation(ctx, req, "智能体plugin配置错误", "") + s.saveConversation(ctx, req, "智能体plugin配置错误", "") return grpc_util.ErrorStatusWithKey(errs.Code_AssistantConversationErr, "assistant_conversation", "plugin配置错误") } // MCP 信息参数配置 if err = s.setMCPParams(ctx, sseReq, assistant, req.Draft, assistantSnapshot); err != nil { SSEError(stream, "智能体MCP配置解析失败") - saveConversation(ctx, req, "智能体MCP配置解析失败", "") + s.saveConversation(ctx, req, "智能体MCP配置解析失败", "") return grpc_util.ErrorStatusWithKey(errs.Code_AssistantConversationErr, "assistant_conversation", "MCP配置解析失败") } @@ -314,13 +296,13 @@ func (s *Service) AssistantConversionStream(req *assistant_service.AssistantConv if err != nil { log.Errorf("Assistant服务序列化请求体失败,assistantId: %s, error: %v", req.AssistantId, err) SSEError(stream, "请求参数错误") - saveConversation(ctx, req, "请求参数错误", "") + s.saveConversation(ctx, req, "请求参数错误", "") return grpc_util.ErrorStatusWithKey(errs.Code_AssistantConversationErr, "assistant_conversation", "请求参数错误") } if err = json.Unmarshal(reqBytes, &requestBody); err != nil { log.Errorf("Assistant服务反序列化请求体到map失败,assistantId: %s, error: %v", req.AssistantId, err) SSEError(stream, "请求参数错误") - saveConversation(ctx, req, "请求参数错误", "") + s.saveConversation(ctx, req, "请求参数错误", "") return grpc_util.ErrorStatusWithKey(errs.Code_AssistantConversationErr, "assistant_conversation", "请求参数错误") } @@ -333,7 +315,7 @@ func (s *Service) AssistantConversionStream(req *assistant_service.AssistantConv if err != nil { log.Errorf("Assistant服务序列化最终请求体失败,assistantId: %s, error: %v", req.AssistantId, err) SSEError(stream, "请求参数错误") - saveConversation(ctx, req, "请求参数错误", "") + s.saveConversation(ctx, req, "请求参数错误", "") return grpc_util.ErrorStatusWithKey(errs.Code_AssistantConversationErr, "assistant_conversation", "请求参数错误") } @@ -351,7 +333,7 @@ func (s *Service) AssistantConversionStream(req *assistant_service.AssistantConv log.Errorf("Assistant服务调用智能体能力接口失败,assistantId: %s, uuid: %s, error: %v", req.AssistantId, id, err) if ctx.Err() == nil { //非上下文被取消 SSEError(stream, "agent服务异常") - saveConversation(ctx, req, "agent服务异常", "") + s.saveConversation(ctx, req, "agent服务异常", "") } return grpc_util.ErrorStatusWithKey(errs.Code_AssistantConversationErr, "assistant_conversation", "agent服务异常") } @@ -362,7 +344,7 @@ func (s *Service) AssistantConversionStream(req *assistant_service.AssistantConv if sseResp.StatusCode > http.StatusBadRequest { log.Errorf("Assistant服务智能体能力接口返回错误状态码,assistantId: %s, statusCode: %d", req.AssistantId, sseResp.StatusCode) SSEError(stream, "agent服务异常") - saveConversation(ctx, req, "agent服务异常", "") + s.saveConversation(ctx, req, "agent服务异常", "") return grpc_util.ErrorStatusWithKey(errs.Code_AssistantConversationErr, "assistant_conversation", "agent服务异常") } @@ -383,7 +365,7 @@ func (s *Service) AssistantConversionStream(req *assistant_service.AssistantConv if !req.Trial { // 只有在上下文未被取消的情况下才保存并标记为已保存 if ctx.Err() == nil { - saveConversation(ctx, req, fullResponse.String(), searchList) + s.saveConversation(ctx, req, fullResponse.String(), searchList) conversationSaved = true // 标记已保存 } // 如果上下文被取消,不设置conversationSaved,让defer函数处理终止消息 @@ -399,7 +381,7 @@ func (s *Service) AssistantConversionStream(req *assistant_service.AssistantConv if hasReadFirstMessage && fullResponse.Len() > 0 { errorMessage = fullResponse.String() + "\n" + errorMessage } - saveConversation(ctx, req, errorMessage, searchList) + s.saveConversation(ctx, req, errorMessage, searchList) conversationSaved = true // 标记已保存,避免defer中重复保存 log.Debugf("Assistant服务保存了中断消息,assistantId: %s, errorMessage: %s", req.AssistantId, errorMessage) } @@ -496,7 +478,7 @@ func (s *Service) AssistantConversionStreamNew(req *assistant_service.AssistantC terminationMessage = fullResponse.String() + "\n本次回答已被终止" } - saveConversation(ctx, req, terminationMessage, searchList) + s.saveConversation(ctx, req, terminationMessage, searchList) log.Infof("因上下文取消保存终止消息,assistantId: %s, conversationId: %s", req.AssistantId, req.ConversationId) } }() @@ -511,7 +493,7 @@ func (s *Service) AssistantConversionStreamNew(req *assistant_service.AssistantC if status != nil { log.Errorf("Assistant服务获取智能体信息失败,assistantId: %s, error: %v", req.AssistantId, status) SSEError(stream, "智能体信息获取失败") - saveConversation(ctx, req, "智能体信息获取失败", "") + s.saveConversation(ctx, req, "智能体信息获取失败", "") return errStatus(errs.Code_AssistantConversationErr, status) } } else { @@ -519,13 +501,13 @@ func (s *Service) AssistantConversionStreamNew(req *assistant_service.AssistantC if status != nil { log.Errorf("Assistant服务获取智能体快照失败,assistantId: %s, error: %v", req.AssistantId, status) SSEError(stream, "智能体快照获取失败") - saveConversation(ctx, req, "智能体快照获取失败", "") + s.saveConversation(ctx, req, "智能体快照获取失败", "") return errStatus(errs.Code_AssistantConversationErr, status) } if err := jsonToStruct(assistantSnapshot.AssistantInfo, &assistant); err != nil { SSEError(stream, "智能体信息获取失败") - saveConversation(ctx, req, "智能体信息获取失败", "") + s.saveConversation(ctx, req, "智能体信息获取失败", "") return errStatus(errs.Code_AssistantErr, toErrStatus("assistant_snapshot", err.Error())) } } @@ -537,7 +519,7 @@ func (s *Service) AssistantConversionStreamNew(req *assistant_service.AssistantC if assistantConfig.SseUrl == "" { log.Errorf("Assistant服务SSE URL配置为空,assistantId: %s", req.AssistantId) SSEError(stream, "智能体SSE URL配置错误") - saveConversation(ctx, req, "智能体SSE URL配置错误", "") + s.saveConversation(ctx, req, "智能体SSE URL配置错误", "") return grpc_util.ErrorStatusWithKey(errs.Code_AssistantConversationErr, "assistant_conversation", "SSE URL配置错误") } @@ -558,28 +540,28 @@ func (s *Service) AssistantConversionStreamNew(req *assistant_service.AssistantC _, err := s.setModelConfigParams(sseReq, assistant) if err != nil { SSEError(stream, "智能体模型配置解析失败") - saveConversation(ctx, req, "智能体模型配置解析失败", "") + s.saveConversation(ctx, req, "智能体模型配置解析失败", "") return grpc_util.ErrorStatusWithKey(errs.Code_AssistantConversationErr, "assistant_conversation", "模型配置解析失败") } // 知识库参数配置 if err = s.setKnowledgebaseParams(ctx, sseReq, req, assistant); err != nil { SSEError(stream, "智能体知识库配置解析失败") - saveConversation(ctx, req, "智能体知识库配置解析失败", "") + s.saveConversation(ctx, req, "智能体知识库配置解析失败", "") return grpc_util.ErrorStatusWithKey(errs.Code_AssistantConversationErr, "assistant_conversation", "知识库配置解析失败") } // plugin参数配置 if err := s.setToolAndWorkflowParamsNew(ctx, sseReq, req.AssistantId, req.Identity, req.Draft, assistantSnapshot); err != nil { SSEError(stream, "智能体plugin配置错误") - saveConversation(ctx, req, "智能体plugin配置错误", "") + s.saveConversation(ctx, req, "智能体plugin配置错误", "") return grpc_util.ErrorStatusWithKey(errs.Code_AssistantConversationErr, "assistant_conversation", "plugin配置错误") } // MCP 信息参数配置 if err = s.setMCPParams(ctx, sseReq, assistant, req.Draft, assistantSnapshot); err != nil { SSEError(stream, "智能体MCP配置解析失败") - saveConversation(ctx, req, "智能体MCP配置解析失败", "") + s.saveConversation(ctx, req, "智能体MCP配置解析失败", "") return grpc_util.ErrorStatusWithKey(errs.Code_AssistantConversationErr, "assistant_conversation", "MCP配置解析失败") } @@ -596,7 +578,7 @@ func (s *Service) AssistantConversionStreamNew(req *assistant_service.AssistantC if err != nil { log.Errorf("Assistant服务序列化请求体失败,assistantId: %s, error: %v", req.AssistantId, err) SSEError(stream, "请求参数错误") - saveConversation(ctx, req, "请求参数错误", "") + s.saveConversation(ctx, req, "请求参数错误", "") return grpc_util.ErrorStatusWithKey(errs.Code_AssistantConversationErr, "assistant_conversation", "请求参数错误") } @@ -614,7 +596,7 @@ func (s *Service) AssistantConversionStreamNew(req *assistant_service.AssistantC log.Errorf("Assistant服务调用智能体能力接口失败,assistantId: %s, uuid: %s, error: %v", req.AssistantId, id, err) if ctx.Err() == nil { //非上下文被取消 SSEError(stream, "agent服务异常") - saveConversation(ctx, req, "agent服务异常", "") + s.saveConversation(ctx, req, "agent服务异常", "") } return grpc_util.ErrorStatusWithKey(errs.Code_AssistantConversationErr, "assistant_conversation", "agent服务异常") } @@ -625,7 +607,7 @@ func (s *Service) AssistantConversionStreamNew(req *assistant_service.AssistantC if sseResp.StatusCode > http.StatusBadRequest { log.Errorf("Assistant服务智能体能力接口返回错误状态码,assistantId: %s, statusCode: %d", req.AssistantId, sseResp.StatusCode) SSEError(stream, "agent服务异常") - saveConversation(ctx, req, "agent服务异常", "") + s.saveConversation(ctx, req, "agent服务异常", "") return grpc_util.ErrorStatusWithKey(errs.Code_AssistantConversationErr, "assistant_conversation", "agent服务异常") } @@ -647,7 +629,7 @@ func (s *Service) AssistantConversionStreamNew(req *assistant_service.AssistantC if !req.Trial { // 只有在上下文未被取消的情况下才保存并标记为已保存 if ctx.Err() == nil { - saveConversation(ctx, req, fullResponse.String(), searchList) + s.saveConversation(ctx, req, fullResponse.String(), searchList) conversationSaved = true // 标记已保存 } // 如果上下文被取消,不设置conversationSaved,让defer函数处理终止消息 @@ -663,7 +645,7 @@ func (s *Service) AssistantConversionStreamNew(req *assistant_service.AssistantC if hasReadFirstMessage && fullResponse.Len() > 0 { errorMessage = fullResponse.String() + "\n" + errorMessage } - saveConversation(ctx, req, errorMessage, searchList) + s.saveConversation(ctx, req, errorMessage, searchList) conversationSaved = true // 标记已保存,避免defer中重复保存 log.Debugf("Assistant服务保存了中断消息,assistantId: %s, errorMessage: %s", req.AssistantId, errorMessage) } @@ -942,26 +924,14 @@ func (s *Service) setMCPParams(ctx context.Context, sseReq *config.AgentSSEReque // 设置历史记录参数 func (s *Service) setHistoryParams(ctx context.Context, sseReq *config.AgentSSERequest, req *assistant_service.AssistantConversionStreamReq) { - fieldConditions := map[string]interface{}{ - "conversationId": req.ConversationId, - "userId": req.Identity.UserId, - "orgId": req.Identity.OrgId, - } - indexPattern := "conversation_detail_infos_*" - - documents, _, err := es.Assistant().SearchByFields(ctx, indexPattern, fieldConditions, 0, 1000) + details, _, err := s.cli.GetConversationDetailsList(ctx, req.ConversationId, req.Identity.UserId, req.Identity.OrgId, 0, 1000) if err != nil { log.Warnf("Assistant服务查询历史聊天记录失败,conversationId: %s, userId: %s, error: %v", req.ConversationId, req.Identity.UserId, err) return } var historyList []config.AssistantConversionHistory - for _, doc := range documents { - var detail model.ConversationDetails - if err := json.Unmarshal(doc, &detail); err != nil { - log.Warnf("Assistant服务解析ES历史聊天记录失败: %v", err) - continue - } + for _, detail := range details { history := config.AssistantConversionHistory{ Query: detail.Prompt, UploadFileUrl: extractFileUrlsFromModel(detail.FileInfo), @@ -996,23 +966,39 @@ func buildRerank(req *assistant_service.AssistantConversionStreamReq, knowledgeb } // 使用独立上下文保存对话的辅助函数 -func saveConversation(originalCtx context.Context, req *assistant_service.AssistantConversionStreamReq, response, searchList string) { +func (s *Service) saveConversation(originalCtx context.Context, req *assistant_service.AssistantConversionStreamReq, response, searchList string) { + var saveCtx context.Context + // 如果原始上下文已取消,创建一个新的独立上下文 if originalCtx.Err() != nil { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + newCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - - if err := saveConversationDetailToES(ctx, req, response, searchList); err != nil { - log.Errorf("保存聊天记录到ES失败,assistantId: %s, conversationId: %s, error: %v", - req.AssistantId, req.ConversationId, err) - } - return + saveCtx = newCtx + } else { + saveCtx = originalCtx } - // 原始上下文未取消时,继续使用它 - if err := saveConversationDetailToES(originalCtx, req, response, searchList); err != nil { - log.Errorf("保存聊天记录到ES失败,assistantId: %s, conversationId: %s, error: %v", - req.AssistantId, req.ConversationId, err) + // 组装ConversationDetails数据 + nowMilli := time.Now().UnixMilli() + conversationDetail := &model.ConversationDetails{ + AssistantId: util.MustU32(req.AssistantId), + ConversationId: req.ConversationId, + Prompt: req.Prompt, + FileInfo: extractFileInfos(req.FileInfo), + Response: response, + SearchList: searchList, + UserId: req.Identity.UserId, + OrgId: req.Identity.OrgId, + CreatedAt: nowMilli, + UpdatedAt: nowMilli, + } + // 写入数据库 + if status := s.cli.CreateConversationDetails(saveCtx, conversationDetail); status != nil { + log.Errorf("保存聊天记录到数据库失败,assistantId: %s, conversationId: %s, error: %v", + req.AssistantId, req.ConversationId, status.TextKey) + } else { + log.Infof("成功保存聊天记录到数据库,assistantId: %s, conversationId: %s", + req.AssistantId, req.ConversationId) } } @@ -1493,38 +1479,6 @@ func HttpRequestLlmStream(ctx context.Context, url, userId, xuid string, body io return response, err } -// saveConversationDetailToES 保存聊天记录到ES -func saveConversationDetailToES(ctx context.Context, req *assistant_service.AssistantConversionStreamReq, response, searchList string) error { - // 根据当前时间生成索引名称,格式为conversation_detail_infos_YYYYMM - now := time.Now() - indexName := fmt.Sprintf("conversation_detail_infos_%d%02d", now.Year(), now.Month()) - - // 组装ConversationDetails数据 - nowMilli := now.UnixMilli() - conversationDetail := &model.ConversationDetails{ - Id: uuid.New().String(), - AssistantId: req.AssistantId, - ConversationId: req.ConversationId, - Prompt: req.Prompt, - FileInfo: extractFileInfos(req.FileInfo), - Response: response, - SearchList: searchList, - UserId: req.Identity.UserId, - OrgId: req.Identity.OrgId, - CreatedAt: nowMilli, - UpdatedAt: nowMilli, - } - - // 写入ES - if err := es.Assistant().IndexDocument(ctx, indexName, conversationDetail); err != nil { - return fmt.Errorf("写入ES失败: %v", err) - } - - log.Infof("成功保存聊天记录到ES,索引: %s, assistantId: %s, conversationId: %s", - indexName, req.AssistantId, req.ConversationId) - return nil -} - // ConversationDeleteByAssistantId 根据智能体ID删除对话 func (s *Service) ConversationDeleteByAssistantId(ctx context.Context, req *assistant_service.ConversationDeleteByAssistantIdReq) (*emptypb.Empty, error) { if status := s.cli.DeleteConversationByAssistantID(ctx, req.AssistantId, req.Identity.UserId, req.Identity.OrgId); status != nil { diff --git a/internal/knowledge-service/client/model/doc_segment_import_task.go b/internal/knowledge-service/client/model/doc_segment_import_task.go index efe996921..068e82258 100644 --- a/internal/knowledge-service/client/model/doc_segment_import_task.go +++ b/internal/knowledge-service/client/model/doc_segment_import_task.go @@ -1,5 +1,7 @@ package model +import "github.com/UnicomAI/wanwu/pkg/db" + const ( DocSegmentImportInit = 0 //任务待处理 DocSegmentImportImporting = 1 //文档分段导入中 @@ -26,18 +28,18 @@ type ChildChunkConfig struct { } type DocSegmentImportTask struct { - Id uint32 `gorm:"column:id;primary_key;type:bigint(20) auto_increment;not null;comment:'id';" json:"id"` - ImportId string `gorm:"uniqueIndex:idx_unique_import_id;column:import_id;type:varchar(64)" json:"importId"` // Business Primary Key - DocId string `gorm:"column:doc_id;type:varchar(64);not null;index:idx_doc_id" json:"docId"` - Status int `gorm:"column:status;type:tinyint(1);not null;comment:'0-任务待处理;1-任务导入中 ;2-任务完成;3-任务失败'" json:"status"` - SuccessCount int `gorm:"column:success_count;type:bigint(20);default:0;comment:'成功数量'" json:"successCount"` - TotalCount int `gorm:"column:total_count;type:bigint(20);default:0;comment:'导入数量,当在导入过程中出现重启,则total为0'" json:"totalCount"` - ErrorMsg string `gorm:"column:error_msg;type:longtext;not null;comment:'解析的错误信息'" json:"errorMsg"` - ImportParams string `gorm:"column:import_params;type:text;not null;comment:'导入信息'" json:"importParams"` - CreatedAt int64 `gorm:"column:create_at;type:bigint(20);not null;autoCreateTime:milli" json:"createAt"` // Create Time - UpdatedAt int64 `gorm:"column:update_at;type:bigint(20);not null;autoUpdateTime:milli" json:"updateAt"` // Update Time - UserId string `gorm:"column:user_id;type:varchar(64);not null;default:'';" json:"userId"` - OrgId string `gorm:"column:org_id;type:varchar(64);not null;default:''" json:"orgId"` + Id uint32 `gorm:"column:id;primary_key;type:bigint auto_increment;not null;comment:'id';" json:"id"` + ImportId string `gorm:"uniqueIndex:idx_unique_import_id;column:import_id;type:varchar(64)" json:"importId"` // Business Primary Key + DocId string `gorm:"column:doc_id;type:varchar(64);not null;index:idx_doc_id" json:"docId"` + Status int `gorm:"column:status;not null;comment:'0-任务待处理;1-任务导入中 ;2-任务完成;3-任务失败'" json:"status"` + SuccessCount int `gorm:"column:success_count;type:bigint;default:0;comment:'成功数量'" json:"successCount"` + TotalCount int `gorm:"column:total_count;type:bigint;default:0;comment:'导入数量,当在导入过程中出现重启,则total为0'" json:"totalCount"` + ErrorMsg db.LongText `gorm:"column:error_msg;not null;comment:'解析的错误信息'" json:"errorMsg"` + ImportParams string `gorm:"column:import_params;type:text;not null;comment:'导入信息'" json:"importParams"` + CreatedAt int64 `gorm:"column:create_at;type:bigint;not null;autoCreateTime:milli" json:"createAt"` // Create Time + UpdatedAt int64 `gorm:"column:update_at;type:bigint;not null;autoUpdateTime:milli" json:"updateAt"` // Update Time + UserId string `gorm:"column:user_id;type:varchar(64);not null;default:'';" json:"userId"` + OrgId string `gorm:"column:org_id;type:varchar(64);not null;default:''" json:"orgId"` } func (DocSegmentImportTask) TableName() string { diff --git a/internal/knowledge-service/client/model/knowledge.go b/internal/knowledge-service/client/model/knowledge.go index de37e6958..386be93f1 100644 --- a/internal/knowledge-service/client/model/knowledge.go +++ b/internal/knowledge-service/client/model/knowledge.go @@ -1,5 +1,7 @@ package model +import "github.com/UnicomAI/wanwu/pkg/db" + type ReportStatus int const ( @@ -15,25 +17,25 @@ const ( ) type KnowledgeBase struct { - Id uint32 `gorm:"column:id;primary_key;type:bigint(20) auto_increment;not null;comment:'id';" json:"id"` // Primary Key + Id uint32 `gorm:"column:id;primary_key;type:bigint auto_increment;not null;comment:'id';" json:"id"` // Primary Key KnowledgeId string `gorm:"uniqueIndex:idx_unique_knowledge_id;column:knowledge_id;type:varchar(64)" json:"knowledgeId"` // Business Primary Key Name string `gorm:"column:name;index:idx_user_id_name,priority:2;type:varchar(256);not null;default:''" json:"name"` RagName string `gorm:"column:rag_name;type:varchar(256);not null;default:''" json:"ragName"` - Category int `gorm:"column:category;index:idx_category;type:tinyint(4);not null;default:0;comment:'0-知识库,1-问答库';" json:"category"` + Category int `gorm:"column:category;index:idx_category;not null;default:0;comment:'0-知识库,1-问答库';" json:"category"` Description string `gorm:"column:description;type:text;comment:'知识库描述';" json:"description"` - DocCount int `gorm:"column:doc_count;type:int(11);not null;default:0;comment:'文档数量';" json:"docCount"` - ShareCount int `gorm:"column:share_count;type:int(11);not null;default:0;comment:'文档共享数量';" json:"shareCount"` - DocSize int64 `gorm:"column:doc_size;type:bigint(20);not null;default:0;comment:'文档大小单位:字节';" json:"docSize"` - EmbeddingModel string `gorm:"column:embedding_model;type:longtext;not null;comment:'embedding模型信息';" json:"embeddingModel"` - KnowledgeGraphSwitch int `gorm:"column:knowledge_graph_switch;type:tinyint(1);not null;default:0;comment:'知识图谱开关,方便查询过滤,0:关闭,1:开启';" json:"knowledgeGraphSwitch"` - KnowledgeGraph string `gorm:"column:knowledge_graph;type:longtext;not null;comment:'知识图谱配置';" json:"knowledgeGraph"` - ReportCreateCount int `gorm:"column:report_create_count;type:int(11);not null;default:0;comment:'社区报告生成数量'" json:"reportCreateCount"` - ReportStatus ReportStatus `gorm:"column:report_status;type:int(11);not null;comment:'0-待处理, 120- 生成成功, 130-生成中,121-社区报告加载图谱失败,122-生成社区报告失败,123-社区报告持久化存储失败,预留120~140';" json:"reportStatus"` - CreatedAt int64 `gorm:"column:create_at;autoCreateTime:milli;type:bigint(20);not null;" json:"createAt"` // Create Time - UpdatedAt int64 `gorm:"column:update_at;autoUpdateTime:milli;type:bigint(20);not null;" json:"updateAt"` // Update Time + DocCount int `gorm:"column:doc_count;type:int;not null;default:0;comment:'文档数量';" json:"docCount"` + ShareCount int `gorm:"column:share_count;type:int;not null;default:0;comment:'文档共享数量';" json:"shareCount"` + DocSize int64 `gorm:"column:doc_size;type:bigint;not null;default:0;comment:'文档大小单位:字节';" json:"docSize"` + EmbeddingModel db.LongText `gorm:"column:embedding_model;not null;comment:'embedding模型信息';" json:"embeddingModel"` + KnowledgeGraphSwitch int `gorm:"column:knowledge_graph_switch;not null;default:0;comment:'知识图谱开关,方便查询过滤,0:关闭,1:开启';" json:"knowledgeGraphSwitch"` + KnowledgeGraph db.LongText `gorm:"column:knowledge_graph;not null;comment:'知识图谱配置';" json:"knowledgeGraph"` + ReportCreateCount int `gorm:"column:report_create_count;type:int;not null;default:0;comment:'社区报告生成数量'" json:"reportCreateCount"` + ReportStatus ReportStatus `gorm:"column:report_status;type:int;not null;comment:'0-待处理, 120- 生成成功, 130-生成中,121-社区报告加载图谱失败,122-生成社区报告失败,123-社区报告持久化存储失败,预留120~140';" json:"reportStatus"` + CreatedAt int64 `gorm:"column:create_at;autoCreateTime:milli;type:bigint;not null;" json:"createAt"` // Create Time + UpdatedAt int64 `gorm:"column:update_at;autoUpdateTime:milli;type:bigint;not null;" json:"updateAt"` // Update Time UserId string `gorm:"column:user_id;index:idx_user_id_name,priority:1;type:varchar(64);not null;default:'';" json:"userId"` OrgId string `gorm:"column:org_id;type:varchar(64);not null;default:'';" json:"orgId"` - Deleted int `gorm:"column:deleted;type:tinyint(1);not null;default:0;comment:'是否逻辑删除';" json:"deleted"` + Deleted int `gorm:"column:deleted;not null;default:0;comment:'是否逻辑删除';" json:"deleted"` } func (KnowledgeBase) TableName() string { diff --git a/internal/knowledge-service/client/model/knowledge_doc.go b/internal/knowledge-service/client/model/knowledge_doc.go index d22023212..c4864bd49 100644 --- a/internal/knowledge-service/client/model/knowledge_doc.go +++ b/internal/knowledge-service/client/model/knowledge_doc.go @@ -1,5 +1,7 @@ package model +import "github.com/UnicomAI/wanwu/pkg/db" + type GraphStatus int const ( @@ -20,8 +22,8 @@ const ( ) type KnowledgeDoc struct { - Id uint32 `json:"id" gorm:"primary_key;type:bigint(20) auto_increment;not null;comment:'id';"` // Primary Key - DocId string `gorm:"uniqueIndex:idx_unique_doc_id;column:doc_id;type:varchar(64)" json:"docId"` // Business Primary Key + Id uint32 `json:"id" gorm:"primary_key;type:bigint auto_increment;not null;comment:'id';"` // Primary Key + DocId string `gorm:"uniqueIndex:idx_unique_doc_id;column:doc_id;type:varchar(64)" json:"docId"` // Business Primary Key ImportTaskId string `gorm:"column:batch_id;type:varchar(64);not null;default:'';comment:'导入的任务id'" json:"importTaskId"` KnowledgeId string `gorm:"column:knowledge_id;index:idx_user_id_knowledge_id_name,priority:2;index:idx_user_id_knowledge_id_tag,priority:2;type:varchar(64);not null;default:''" json:"knowledgeId"` FilePathMd5 string `gorm:"column:file_path_md5;type:varchar(64);not null;default:'';comment:'文件的md5值'" json:"filePathMd5"` @@ -29,15 +31,15 @@ type KnowledgeDoc struct { DirFilePath string `gorm:"column:dir_file_path;type:text;not null;comment:'文件在文件夹中的相对目录'" json:"dirFilePath"` Name string `gorm:"column:name;index:idx_user_id_knowledge_id_name,priority:3;type:varchar(256);not null;default:''" json:"name"` FileType string `gorm:"column:file_type;type:varchar(20);not null;default:''" json:"fileType"` - FileSize int64 `gorm:"column:file_size;type:bigint(20);COMMENT:'文件大小,单位byte'" json:"fileSize"` - Status int `gorm:"column:status;type:tinyint(1);not null;comment:'0-待处理, 1- 处理完成, 2-正在审核中(目前没有),3-正在解析中,4-审核未通过(目前没有),5-解析失败';" json:"status"` - GraphStatus GraphStatus `gorm:"column:graph_status;type:int(11);not null;comment:'0-待处理, 100- 生成成功, 101-生成图谱获取chunk文本失败,102-提取图谱失败,103-图谱持久化存储失败,预留100~120';" json:"graphStatus"` - ErrorMsg string `gorm:"column:error_msg;type:longtext;not null;comment:'解析的错误信息'" json:"errorMsg"` - CreatedAt int64 `gorm:"column:create_at;type:bigint(20);autoCreateTime:milli;not null;" json:"createAt"` // Create Time - UpdatedAt int64 `gorm:"column:update_at;type:bigint(20);autoUpdateTime:milli;not null;" json:"updateAt"` // Update Time + FileSize int64 `gorm:"column:file_size;type:bigint;COMMENT:'文件大小,单位byte'" json:"fileSize"` + Status int `gorm:"column:status;not null;comment:'0-待处理, 1- 处理完成, 2-正在审核中(目前没有),3-正在解析中,4-审核未通过(目前没有),5-解析失败';" json:"status"` + GraphStatus GraphStatus `gorm:"column:graph_status;type:int;not null;comment:'0-待处理, 100- 生成成功, 101-生成图谱获取chunk文本失败,102-提取图谱失败,103-图谱持久化存储失败,预留100~120';" json:"graphStatus"` + ErrorMsg db.LongText `gorm:"column:error_msg;not null;comment:'解析的错误信息'" json:"errorMsg"` + CreatedAt int64 `gorm:"column:create_at;type:bigint;autoCreateTime:milli;not null;" json:"createAt"` // Create Time + UpdatedAt int64 `gorm:"column:update_at;type:bigint;autoUpdateTime:milli;not null;" json:"updateAt"` // Update Time UserId string `gorm:"column:user_id;index:idx_user_id_knowledge_id_name,priority:1;index:idx_user_id_knowledge_id_tag,priority:1;type:varchar(64);not null;default:'';" json:"userId"` OrgId string `gorm:"column:org_id;type:varchar(64);not null;default:''" json:"orgId"` - Deleted int `gorm:"column:deleted;type:tinyint(1);not null;default:0;comment:'是否逻辑删除';" json:"deleted"` + Deleted int `gorm:"column:deleted;not null;default:0;comment:'是否逻辑删除';" json:"deleted"` } func (KnowledgeDoc) TableName() string { diff --git a/internal/knowledge-service/client/model/knowledge_doc_meta.go b/internal/knowledge-service/client/model/knowledge_doc_meta.go index 6d5c02c27..1ab1db604 100644 --- a/internal/knowledge-service/client/model/knowledge_doc_meta.go +++ b/internal/knowledge-service/client/model/knowledge_doc_meta.go @@ -7,7 +7,7 @@ const ( ) type KnowledgeDocMeta struct { - Id uint32 `json:"id" gorm:"primary_key;type:bigint(20) auto_increment;not null;comment:'id';"` // Primary Key + Id uint32 `json:"id" gorm:"primary_key;type:bigint auto_increment;not null;comment:'id';"` // Primary Key KnowledgeId string `gorm:"index:idx_knowledge_id;index:idx_knowledge_id_value_main,priority:1;column:knowledge_id;type:varchar(64);not null;default:''" json:"knowledgeId"` MetaId string `gorm:"uniqueIndex:idx_unique_meta_id;column:meta_id;type:varchar(64)" json:"metaId"` DocId string `gorm:"index:idx_doc_id;column:doc_id;type:varchar(64)" json:"docId"` @@ -16,8 +16,8 @@ type KnowledgeDocMeta struct { ValueMain string `gorm:"index:idx_knowledge_id_value_main,priority:2;column:value_main;type:varchar(128);not null;default:'';comment:'替代原value字段'" json:"valueMain"` ValueType string `gorm:"column:value_type;type:varchar(64);not null;default:'string';comment:'string,number,time'" json:"valueType"` Rule string `gorm:"column:rule;type:text;not null" json:"rule"` - CreatedAt int64 `gorm:"column:create_at;type:bigint(20);autoCreateTime:milli;not null;" json:"createAt"` // Create Time - UpdatedAt int64 `gorm:"column:update_at;type:bigint(20);autoUpdateTime:milli;not null;" json:"updateAt"` // Update Time + CreatedAt int64 `gorm:"column:create_at;type:bigint;autoCreateTime:milli;not null;" json:"createAt"` // Create Time + UpdatedAt int64 `gorm:"column:update_at;type:bigint;autoUpdateTime:milli;not null;" json:"updateAt"` // Update Time UserId string `gorm:"column:user_id;index:idx_user_id_knowledge_id_name,priority:1;index:idx_user_id_knowledge_id_tag,priority:1;type:varchar(64);not null;default:'';" json:"userId"` OrgId string `gorm:"column:org_id;type:varchar(64);not null;default:''" json:"orgId"` } diff --git a/internal/knowledge-service/client/model/knowledge_export_task.go b/internal/knowledge-service/client/model/knowledge_export_task.go index c08faa31e..12d9424c1 100644 --- a/internal/knowledge-service/client/model/knowledge_export_task.go +++ b/internal/knowledge-service/client/model/knowledge_export_task.go @@ -1,5 +1,7 @@ package model +import "github.com/UnicomAI/wanwu/pkg/db" + const ( KnowledgeExportInit = 0 //任务待处理 KnowledgeExportExporting = 1 //导出中 @@ -13,20 +15,20 @@ type KnowledgeExportTaskParams struct { } type KnowledgeExportTask struct { - Id uint32 `gorm:"column:id;primary_key;type:bigint(20) auto_increment;not null;comment:'id';" json:"id"` - ExportId string `gorm:"uniqueIndex:idx_unique_export_id;column:export_id;type:varchar(64)" json:"exportId"` // Business Primary Key - KnowledgeId string `gorm:"column:knowledge_id;type:varchar(64);not null;index:idx_knowledge_id" json:"knowledgeId"` - ExportFilePath string `gorm:"column:export_file_path;type:text;not null;comment:'导出文件地址'" json:"exportFilePath"` - ExportFileSize int64 `gorm:"column:export_file_size;type:bigint(20);not null;comment:'导出文件大小'" json:"exportFileSize"` - Status int `gorm:"column:status;type:tinyint(1);not null;comment:'0-任务待处理;1-任务导出中 ;2-任务完成;3-任务失败'" json:"status"` - SuccessCount int `gorm:"column:success_count;type:bigint(20);default:0;comment:'成功数量'" json:"successCount"` - TotalCount int `gorm:"column:total_count;type:bigint(20);default:0;comment:'导出数量,当在导出过程中出现重启,则total为0'" json:"totalCount"` - ErrorMsg string `gorm:"column:error_msg;type:longtext;not null;comment:'导出的错误信息'" json:"errorMsg"` - ExportParams string `gorm:"column:export_params;type:text;not null;comment:'导出信息'" json:"exportParams"` - CreatedAt int64 `gorm:"column:create_at;type:bigint(20);not null;autoCreateTime:milli" json:"createAt"` // Create Time - UpdatedAt int64 `gorm:"column:update_at;type:bigint(20);not null;autoUpdateTime:milli" json:"updateAt"` // Update Time - UserId string `gorm:"column:user_id;type:varchar(64);not null;default:'';" json:"userId"` - OrgId string `gorm:"column:org_id;type:varchar(64);not null;default:''" json:"orgId"` + Id uint32 `gorm:"column:id;primary_key;type:bigint auto_increment;not null;comment:'id';" json:"id"` + ExportId string `gorm:"uniqueIndex:idx_unique_export_id;column:export_id;type:varchar(64)" json:"exportId"` // Business Primary Key + KnowledgeId string `gorm:"column:knowledge_id;type:varchar(64);not null;index:idx_knowledge_id" json:"knowledgeId"` + ExportFilePath string `gorm:"column:export_file_path;type:text;not null;comment:'导出文件地址'" json:"exportFilePath"` + ExportFileSize int64 `gorm:"column:export_file_size;type:bigint;not null;comment:'导出文件大小'" json:"exportFileSize"` + Status int `gorm:"column:status;not null;comment:'0-任务待处理;1-任务导出中 ;2-任务完成;3-任务失败'" json:"status"` + SuccessCount int `gorm:"column:success_count;type:bigint;default:0;comment:'成功数量'" json:"successCount"` + TotalCount int `gorm:"column:total_count;type:bigint;default:0;comment:'导出数量,当在导出过程中出现重启,则total为0'" json:"totalCount"` + ErrorMsg db.LongText `gorm:"column:error_msg;not null;comment:'导出的错误信息'" json:"errorMsg"` + ExportParams string `gorm:"column:export_params;type:text;not null;comment:'导出信息'" json:"exportParams"` + CreatedAt int64 `gorm:"column:create_at;type:bigint;not null;autoCreateTime:milli" json:"createAt"` // Create Time + UpdatedAt int64 `gorm:"column:update_at;type:bigint;not null;autoUpdateTime:milli" json:"updateAt"` // Update Time + UserId string `gorm:"column:user_id;type:varchar(64);not null;default:'';" json:"userId"` + OrgId string `gorm:"column:org_id;type:varchar(64);not null;default:''" json:"orgId"` } func (KnowledgeExportTask) TableName() string { diff --git a/internal/knowledge-service/client/model/knowledge_import_task.go b/internal/knowledge-service/client/model/knowledge_import_task.go index 54243e1aa..d31170cad 100644 --- a/internal/knowledge-service/client/model/knowledge_import_task.go +++ b/internal/knowledge-service/client/model/knowledge_import_task.go @@ -1,5 +1,7 @@ package model +import "github.com/UnicomAI/wanwu/pkg/db" + const ( KnowledgeImportAnalyze = 1 //知识库任务解析中 KnowledgeImportSubmit = 2 //知识库任务已提交 @@ -59,23 +61,23 @@ type DocMetaData struct { } type KnowledgeImportTask struct { - Id uint32 `gorm:"column:id;primary_key;type:bigint(20) auto_increment;not null;comment:'id';" json:"id"` - ImportId string `gorm:"uniqueIndex:idx_unique_import_id;column:import_id;type:varchar(64)" json:"importId"` // Business Primary Key - KnowledgeId string `gorm:"column:knowledge_id;type:varchar(64);not null;index:idx_knowledge_id" json:"knowledgeId"` - ImportType int `gorm:"column:import_type;type:tinyint(1);not null;" json:"importType"` - TaskType int `gorm:"column:task_type;type:tinyint(1);not null;default:0;comment:'0:创建导入,1:配置更新'" json:"taskType"` - Status int `gorm:"column:status;type:tinyint(1);not null;comment:'0-任务待处理;1-任务解析中 ;2-任务提交算法完成;3-任务完成;4-任务失败" json:"status"` - ErrorMsg string `gorm:"column:error_msg;type:longtext;not null;comment:'解析的错误信息'" json:"errorMsg"` - DocInfo string `gorm:"column:doc_info;type:longtext;not null;comment:'文件信息'" json:"docInfo"` - SegmentConfig string `gorm:"column:segment_config;type:text;not null;comment:'分段配置信息'" json:"segmentConfig"` - DocAnalyzer string `gorm:"column:doc_analyzer;type:text;not null;comment:'文档解析配置'" json:"docAnalyzer"` - OcrModelId string `gorm:"column:ocr_model_id;type:varchar(64);not null;default:'';comment:'ocr模型id'" json:"ocrModelId"` - DocPreProcess string `gorm:"column:doc_pre_process;type:text;not null;comment:'文档预处理规则: replace_symbols,delete_links'" json:"docPreProcess"` - MetaData string `gorm:"column:meta_data;type:text;not null;comment:'元数据列表'" json:"metaData"` - CreatedAt int64 `gorm:"column:create_at;type:bigint(20);autoCreateTime:milli;not null;" json:"createAt"` // Create Time - UpdatedAt int64 `gorm:"column:update_at;type:bigint(20);autoUpdateTime:milli;not null;" json:"updateAt"` // Update Time - UserId string `gorm:"column:user_id;type:varchar(64);not null;default:'';" json:"userId"` - OrgId string `gorm:"column:org_id;type:varchar(64);not null;default:''" json:"orgId"` + Id uint32 `gorm:"column:id;primary_key;type:bigint auto_increment;not null;comment:'id';" json:"id"` + ImportId string `gorm:"uniqueIndex:idx_unique_import_id;column:import_id;type:varchar(64)" json:"importId"` // Business Primary Key + KnowledgeId string `gorm:"column:knowledge_id;type:varchar(64);not null;index:idx_knowledge_id" json:"knowledgeId"` + ImportType int `gorm:"column:import_type;not null;" json:"importType"` + TaskType int `gorm:"column:task_type;type:tinyint(1);not null;default:0;comment:'0:创建导入,1:配置更新'" json:"taskType"` + Status int `gorm:"column:status;not null;comment:'0-任务待处理;1-任务解析中 ;2-任务提交算法完成;3-任务完成;4-任务失败" json:"status"` + ErrorMsg db.LongText `gorm:"column:error_msg;not null;comment:'解析的错误信息'" json:"errorMsg"` + DocInfo db.LongText `gorm:"column:doc_info;not null;comment:'文件信息'" json:"docInfo"` + SegmentConfig string `gorm:"column:segment_config;type:text;not null;comment:'分段配置信息'" json:"segmentConfig"` + DocAnalyzer string `gorm:"column:doc_analyzer;type:text;not null;comment:'文档解析配置'" json:"docAnalyzer"` + OcrModelId string `gorm:"column:ocr_model_id;type:varchar(64);not null;default:'';comment:'ocr模型id'" json:"ocrModelId"` + DocPreProcess string `gorm:"column:doc_pre_process;type:text;not null;comment:'文档预处理规则: replace_symbols,delete_links'" json:"docPreProcess"` + MetaData string `gorm:"column:meta_data;type:text;not null;comment:'元数据列表'" json:"metaData"` + CreatedAt int64 `gorm:"column:create_at;type:bigint;autoCreateTime:milli;not null;" json:"createAt"` // Create Time + UpdatedAt int64 `gorm:"column:update_at;type:bigint;autoUpdateTime:milli;not null;" json:"updateAt"` // Update Time + UserId string `gorm:"column:user_id;type:varchar(64);not null;default:'';" json:"userId"` + OrgId string `gorm:"column:org_id;type:varchar(64);not null;default:''" json:"orgId"` } func (KnowledgeImportTask) TableName() string { diff --git a/internal/knowledge-service/client/model/knowledge_keywords.go b/internal/knowledge-service/client/model/knowledge_keywords.go index db0549a5b..702ba02d2 100644 --- a/internal/knowledge-service/client/model/knowledge_keywords.go +++ b/internal/knowledge-service/client/model/knowledge_keywords.go @@ -2,7 +2,7 @@ package model // KnowledgeKeywords 知识库关键词映射表 type KnowledgeKeywords struct { - Id uint32 `json:"id" gorm:"primary_key;type:bigint(20) auto_increment;not null;comment:'id';"` // Primary Key + Id uint32 `json:"id" gorm:"primary_key;type:bigint auto_increment;not null;comment:'id';"` // Primary Key Name string `json:"name" gorm:"column:name;type:varchar(255);comment:专名词"` Alias string `json:"alias" gorm:"column:alias;type:varchar(255);comment:别名"` KnowledgeBaseIds string `json:"knowledgeBaseIds" gorm:"column:knowledge_base_ids;type:text;comment:关联的知识库id;内容格式为:[\"2\",\"3\"]"` diff --git a/internal/knowledge-service/client/model/knowledge_permission.go b/internal/knowledge-service/client/model/knowledge_permission.go index b43f987f3..884c8d82a 100644 --- a/internal/knowledge-service/client/model/knowledge_permission.go +++ b/internal/knowledge-service/client/model/knowledge_permission.go @@ -10,14 +10,14 @@ const ( // KnowledgePermission 业务唯一所以,一个知识库,一个用户,一个组织 只能有一条 type KnowledgePermission struct { - Id uint32 `json:"id" gorm:"primary_key;type:bigint(20) auto_increment;not null;comment:'id';"` // Primary Key + Id uint32 `json:"id" gorm:"primary_key;type:bigint auto_increment;not null;comment:'id';"` // Primary Key PermissionId string `gorm:"column:permission_id;uniqueIndex:idx_unique_permission_id;type:varchar(64);not null;default:''" json:"permissionId"` KnowledgeId string `gorm:"column:knowledge_id;uniqueIndex:idx_knowledge_id_org_user,priority:1;type:varchar(64);not null;default:''" json:"knowledgeId"` GrantUserId string `gorm:"column:grant_user_id;type:varchar(64);not null;default:'';comment:'有权限的用户id';" json:"permissionUserId"` GrantOrgId string `gorm:"column:grant_org_id;type:varchar(64);not null;default:''comment:'有权限的组织id';" json:"permissionOrgId"` - PermissionType int `gorm:"column:permission_type;type:tinyint(1);not null;default:0;comment:'权限类型0:读权限,10:编辑权限 20:授权权限,一个知识库只有一个人有授权权限'" json:"permissionType"` - CreatedAt int64 `gorm:"column:create_at;type:bigint(20);autoCreateTime:milli;not null;" json:"createAt"` // Create Time - UpdatedAt int64 `gorm:"column:update_at;type:bigint(20);autoUpdateTime:milli;not null;" json:"updateAt"` // Update Time + PermissionType int `gorm:"column:permission_type;not null;default:0;comment:'权限类型0:读权限,10:编辑权限 20:授权权限,一个知识库只有一个人有授权权限'" json:"permissionType"` + CreatedAt int64 `gorm:"column:create_at;type:bigint;autoCreateTime:milli;not null;" json:"createAt"` // Create Time + UpdatedAt int64 `gorm:"column:update_at;type:bigint;autoUpdateTime:milli;not null;" json:"updateAt"` // Update Time OrgId string `gorm:"column:org_id;uniqueIndex:idx_knowledge_id_org_user,priority:2;type:varchar(64);not null;default:'';" json:"orgId"` UserId string `gorm:"column:user_id;uniqueIndex:idx_knowledge_id_org_user,priority:3;type:varchar(64);not null;default:'';" json:"userId"` } diff --git a/internal/knowledge-service/client/model/knowledge_permission_record.go b/internal/knowledge-service/client/model/knowledge_permission_record.go index 24ca258e1..cdc585b9b 100644 --- a/internal/knowledge-service/client/model/knowledge_permission_record.go +++ b/internal/knowledge-service/client/model/knowledge_permission_record.go @@ -8,18 +8,18 @@ const ( // KnowledgePermissionRecord 业务唯一所以,一个知识库,一个用户,一个组织 只能有一条 type KnowledgePermissionRecord struct { - Id uint32 `json:"id" gorm:"primary_key;type:bigint(20) auto_increment;not null;comment:'id';"` // Primary Key + Id uint32 `json:"id" gorm:"primary_key;type:bigint auto_increment;not null;comment:'id';"` // Primary Key RecordId string `gorm:"column:record_id;uniqueIndex:idx_record_id;type:varchar(64);not null;default:''" json:"recordId"` KnowledgeId string `gorm:"column:knowledge_id;;type:varchar(64);not null;default:''" json:"knowledgeId"` - Option int `gorm:"column:option;type:tinyint(1);not null;default:0;comment:'操作类型0:添加权限,1:删除权限, 2:修改权限'" json:"option"` + Option int `gorm:"column:option;not null;default:0;comment:'操作类型0:添加权限,1:删除权限, 2:修改权限'" json:"option"` OperatorUserId string `gorm:"column:operator_user_id;type:varchar(64);not null;default:'';comment:'有权限的用户id';" json:"operatorUserId"` OperatorOrgId string `gorm:"column:operator_org_id;type:varchar(64);not null;default:''comment:'有权限的组织id';" json:"operatorOrgId"` - FromPermissionType int `gorm:"column:from_permission_type;type:tinyint(1);not null;default:0;comment:'权限类型-1:无权限,0:读权限,10:编辑权限 20:授权权限,一个知识库只有一个人有授权权限'" json:"fromPermissionType"` - ToPermissionType int `gorm:"column:to_permission_type;type:tinyint(1);not null;default:0;comment:'权限类型-1:无权限,0:读权限,10:编辑权限 20:授权权限,一个知识库只有一个人有授权权限'" json:"toPermissionType"` + FromPermissionType int `gorm:"column:from_permission_type;not null;default:0;comment:'权限类型-1:无权限,0:读权限,10:编辑权限 20:授权权限,一个知识库只有一个人有授权权限'" json:"fromPermissionType"` + ToPermissionType int `gorm:"column:to_permission_type;not null;default:0;comment:'权限类型-1:无权限,0:读权限,10:编辑权限 20:授权权限,一个知识库只有一个人有授权权限'" json:"toPermissionType"` OwnerOrgId string `gorm:"column:owner_org_id;type:varchar(64);not null;default:'';" json:"ownerOrgId"` OwnerUserId string `gorm:"column:owner_user_id;type:varchar(64);not null;default:'';" json:"ownerUserId"` - CreatedAt int64 `gorm:"column:create_at;type:bigint(20);autoCreateTime:milli;not null;" json:"createAt"` // Create Time - UpdatedAt int64 `gorm:"column:update_at;type:bigint(20);autoUpdateTime:milli;not null;" json:"updateAt"` // Update Time + CreatedAt int64 `gorm:"column:create_at;type:bigint;autoCreateTime:milli;not null;" json:"createAt"` // Create Time + UpdatedAt int64 `gorm:"column:update_at;type:bigint;autoUpdateTime:milli;not null;" json:"updateAt"` // Update Time } diff --git a/internal/knowledge-service/client/model/knowledge_qa_pair.go b/internal/knowledge-service/client/model/knowledge_qa_pair.go index 0fd7efb4c..a8855f0d8 100644 --- a/internal/knowledge-service/client/model/knowledge_qa_pair.go +++ b/internal/knowledge-service/client/model/knowledge_qa_pair.go @@ -1,25 +1,27 @@ package model +import "github.com/UnicomAI/wanwu/pkg/db" + const ( QAPairSuccess = 2 // 问答对文件导入成功 ) type KnowledgeQAPair struct { - Id uint32 `json:"id" gorm:"primary_key;type:bigint(20) auto_increment;not null;comment:'id';"` // Primary Key - QAPairId string `gorm:"uniqueIndex:idx_unique_qa_pair_id;column:qa_pair_id;type:varchar(64)" json:"qaPairId"` // Business Primary Key - ImportTaskId string `gorm:"column:import_id;type:varchar(64);not null;default:'';comment:'导入的任务id'" json:"importId"` - KnowledgeId string `gorm:"column:knowledge_id;uniqueIndex:idx_knowledge_id_md5,priority:1;;type:varchar(64);not null;default:''" json:"knowledgeId"` - Question string `gorm:"column:question;type:longtext;not null;comment:'问题'" json:"question"` - Answer string `gorm:"column:answer;type:longtext;not null;comment:'答案'" json:"answer"` - Switch bool `gorm:"column:switch;type:tinyint(1);not null;default:0;comment:'开关'" json:"switch"` - Status int `gorm:"column:status;type:tinyint(1);not null;comment:'0-待处理, 1-导入中,2-导入成功,3-导入失败';" json:"status"` - ErrorMsg string `gorm:"column:error_msg;type:longtext;not null;comment:'导入状态错误信息'" json:"errorMsg"` - QuestionMd5 string `gorm:"column:question_md5;uniqueIndex:idx_knowledge_id_md5,priority:2;type:varchar(64);not null;default:'';comment:'问题的md5值'" json:"questionMd5"` - CreatedAt int64 `gorm:"column:create_at;type:bigint(20);not null;autoCreateTime:milli;" json:"createAt"` // Create Time - UpdatedAt int64 `gorm:"column:update_at;type:bigint(20);not null;autoCreateTime:milli;" json:"updateAt"` // Update Time - UserId string `gorm:"column:user_id;type:varchar(64);not null;default:'';" json:"userId"` - OrgId string `gorm:"column:org_id;type:varchar(64);not null;default:''" json:"orgId"` - Deleted int `gorm:"column:deleted;type:tinyint(1);not null;default:0;comment:'是否逻辑删除';" json:"deleted"` + Id uint32 `json:"id" gorm:"primary_key;type:bigint auto_increment;not null;comment:'id';"` // Primary Key + QAPairId string `gorm:"uniqueIndex:idx_unique_qa_pair_id;column:qa_pair_id;type:varchar(64)" json:"qaPairId"` // Business Primary Key + ImportTaskId string `gorm:"column:import_id;type:varchar(64);not null;default:'';comment:'导入的任务id'" json:"importId"` + KnowledgeId string `gorm:"column:knowledge_id;uniqueIndex:idx_knowledge_id_md5,priority:1;;type:varchar(64);not null;default:''" json:"knowledgeId"` + Question db.LongText `gorm:"column:question;not null;comment:'问题'" json:"question"` + Answer db.LongText `gorm:"column:answer;not null;comment:'答案'" json:"answer"` + Switch bool `gorm:"column:switch;not null;default:0;comment:'开关'" json:"switch"` + Status int `gorm:"column:status;not null;comment:'0-待处理, 1-导入中,2-导入成功,3-导入失败';" json:"status"` + ErrorMsg db.LongText `gorm:"column:error_msg;not null;comment:'导入状态错误信息'" json:"errorMsg"` + QuestionMd5 string `gorm:"column:question_md5;uniqueIndex:idx_knowledge_id_md5,priority:2;type:varchar(64);not null;default:'';comment:'问题的md5值'" json:"questionMd5"` + CreatedAt int64 `gorm:"column:create_at;type:bigint;not null;autoCreateTime:milli;" json:"createAt"` // Create Time + UpdatedAt int64 `gorm:"column:update_at;type:bigint;not null;autoCreateTime:milli;" json:"updateAt"` // Update Time + UserId string `gorm:"column:user_id;type:varchar(64);not null;default:'';" json:"userId"` + OrgId string `gorm:"column:org_id;type:varchar(64);not null;default:''" json:"orgId"` + Deleted int `gorm:"column:deleted;not null;default:0;comment:'是否逻辑删除';" json:"deleted"` } func (KnowledgeQAPair) TableName() string { diff --git a/internal/knowledge-service/client/model/knowledge_qa_pair_import_task.go b/internal/knowledge-service/client/model/knowledge_qa_pair_import_task.go index f820b25fa..096acf3df 100644 --- a/internal/knowledge-service/client/model/knowledge_qa_pair_import_task.go +++ b/internal/knowledge-service/client/model/knowledge_qa_pair_import_task.go @@ -1,5 +1,7 @@ package model +import "github.com/UnicomAI/wanwu/pkg/db" + const ( KnowledgeQAPairImportInit = 0 //任务待处理 KnowledgeQAPairImportImporting = 1 //导入中 @@ -8,18 +10,18 @@ const ( ) type KnowledgeQAPairImportTask struct { - Id uint32 `gorm:"column:id;primary_key;type:bigint(20) auto_increment;not null;comment:'id';" json:"id"` - ImportId string `gorm:"uniqueIndex:idx_unique_import_id;column:import_id;type:varchar(64)" json:"importId"` // Business Primary Key - KnowledgeId string `gorm:"column:knowledge_id;type:varchar(64);not null;index:idx_knowledge_id" json:"knowledgeId"` - DocInfo string `gorm:"column:doc_info;type:longtext;not null;comment:'文件信息'" json:"docInfo"` - Status int `gorm:"column:status;type:tinyint(1);not null;comment:'0-任务待处理;1-任务导入中 ;2-任务完成;3-任务失败'" json:"status"` - SuccessCount int `gorm:"column:success_count;type:bigint(20);default:0;comment:'成功数量'" json:"successCount"` - TotalCount int `gorm:"column:total_count;type:bigint(20);default:0;comment:'导入数量,当在导入过程中出现重启,则total为0'" json:"totalCount"` - ErrorMsg string `gorm:"column:error_msg;type:longtext;not null;comment:'导入的错误信息'" json:"errorMsg"` - CreatedAt int64 `gorm:"column:create_at;type:bigint(20);not null;autoCreateTime:milli" json:"createAt"` // Create Time - UpdatedAt int64 `gorm:"column:update_at;type:bigint(20);not null;autoUpdateTime:milli" json:"updateAt"` // Update Time - UserId string `gorm:"column:user_id;type:varchar(64);not null;default:'';" json:"userId"` - OrgId string `gorm:"column:org_id;type:varchar(64);not null;default:''" json:"orgId"` + Id uint32 `gorm:"column:id;primary_key;type:bigint auto_increment;not null;comment:'id';" json:"id"` + ImportId string `gorm:"uniqueIndex:idx_unique_import_id;column:import_id;type:varchar(64)" json:"importId"` // Business Primary Key + KnowledgeId string `gorm:"column:knowledge_id;type:varchar(64);not null;index:idx_knowledge_id" json:"knowledgeId"` + DocInfo db.LongText `gorm:"column:doc_info;not null;comment:'文件信息'" json:"docInfo"` + Status int `gorm:"column:status;not null;comment:'0-任务待处理;1-任务导入中 ;2-任务完成;3-任务失败'" json:"status"` + SuccessCount int `gorm:"column:success_count;type:bigint;default:0;comment:'成功数量'" json:"successCount"` + TotalCount int `gorm:"column:total_count;type:bigint;default:0;comment:'导入数量,当在导入过程中出现重启,则total为0'" json:"totalCount"` + ErrorMsg db.LongText `gorm:"column:error_msg;not null;comment:'导入的错误信息'" json:"errorMsg"` + CreatedAt int64 `gorm:"column:create_at;type:bigint;not null;autoCreateTime:milli" json:"createAt"` // Create Time + UpdatedAt int64 `gorm:"column:update_at;type:bigint;not null;autoUpdateTime:milli" json:"updateAt"` // Update Time + UserId string `gorm:"column:user_id;type:varchar(64);not null;default:'';" json:"userId"` + OrgId string `gorm:"column:org_id;type:varchar(64);not null;default:''" json:"orgId"` } func (KnowledgeQAPairImportTask) TableName() string { diff --git a/internal/knowledge-service/client/model/knowledge_report_import_task.go b/internal/knowledge-service/client/model/knowledge_report_import_task.go index 297dcc5dc..7d3c5010e 100644 --- a/internal/knowledge-service/client/model/knowledge_report_import_task.go +++ b/internal/knowledge-service/client/model/knowledge_report_import_task.go @@ -1,5 +1,7 @@ package model +import "github.com/UnicomAI/wanwu/pkg/db" + const ( KnowledgeReportImportInit = 0 //任务待处理 KnowledgeReportImportImporting = 1 //文档分段导入中 @@ -12,18 +14,18 @@ type KnowledgeReportImportParams struct { } type KnowledgeReportImportTask struct { - Id uint32 `gorm:"column:id;primary_key;type:bigint(20) auto_increment;not null;comment:'id';" json:"id"` - ImportId string `gorm:"uniqueIndex:idx_unique_import_id;column:import_id;type:varchar(64)" json:"importId"` // Business Primary Key - KnowledgeId string `gorm:"column:knowledge_id;type:varchar(64);not null;index:idx_knowledge_id" json:"knowledgeId"` - Status int `gorm:"column:status;type:tinyint(1);not null;comment:'0-任务待处理;1-任务导入中 ;2-任务完成;3-任务失败'" json:"status"` - SuccessCount int `gorm:"column:success_count;type:bigint(20);default:0;comment:'成功数量'" json:"successCount"` - TotalCount int `gorm:"column:total_count;type:bigint(20);default:0;comment:'导入数量,当在导入过程中出现重启,则total为0'" json:"totalCount"` - ErrorMsg string `gorm:"column:error_msg;type:longtext;not null;comment:'解析的错误信息'" json:"errorMsg"` - ImportParams string `gorm:"column:import_params;type:text;not null;comment:'导入信息'" json:"importParams"` - CreatedAt int64 `gorm:"column:create_at;type:bigint(20);not null;autoCreateTime:milli" json:"createAt"` // Create Time - UpdatedAt int64 `gorm:"column:update_at;type:bigint(20);not null;autoUpdateTime:milli" json:"updateAt"` // Update Time - UserId string `gorm:"column:user_id;type:varchar(64);not null;default:'';" json:"userId"` - OrgId string `gorm:"column:org_id;type:varchar(64);not null;default:''" json:"orgId"` + Id uint32 `gorm:"column:id;primary_key;type:bigint auto_increment;not null;comment:'id';" json:"id"` + ImportId string `gorm:"uniqueIndex:idx_unique_import_id;column:import_id;type:varchar(64)" json:"importId"` // Business Primary Key + KnowledgeId string `gorm:"column:knowledge_id;type:varchar(64);not null;index:idx_knowledge_id" json:"knowledgeId"` + Status int `gorm:"column:status;not null;comment:'0-任务待处理;1-任务导入中 ;2-任务完成;3-任务失败'" json:"status"` + SuccessCount int `gorm:"column:success_count;type:bigint;default:0;comment:'成功数量'" json:"successCount"` + TotalCount int `gorm:"column:total_count;type:bigint;default:0;comment:'导入数量,当在导入过程中出现重启,则total为0'" json:"totalCount"` + ErrorMsg db.LongText `gorm:"column:error_msg;not null;comment:'解析的错误信息'" json:"errorMsg"` + ImportParams string `gorm:"column:import_params;type:text;not null;comment:'导入信息'" json:"importParams"` + CreatedAt int64 `gorm:"column:create_at;type:bigint;not null;autoCreateTime:milli" json:"createAt"` // Create Time + UpdatedAt int64 `gorm:"column:update_at;type:bigint;not null;autoUpdateTime:milli" json:"updateAt"` // Update Time + UserId string `gorm:"column:user_id;type:varchar(64);not null;default:'';" json:"userId"` + OrgId string `gorm:"column:org_id;type:varchar(64);not null;default:''" json:"orgId"` } func (KnowledgeReportImportTask) TableName() string { diff --git a/internal/knowledge-service/client/model/knowledge_splitter.go b/internal/knowledge-service/client/model/knowledge_splitter.go index 13baadc00..1ecff2270 100644 --- a/internal/knowledge-service/client/model/knowledge_splitter.go +++ b/internal/knowledge-service/client/model/knowledge_splitter.go @@ -1,12 +1,12 @@ package model type KnowledgeSplitter struct { - Id uint32 `gorm:"column:id;primary_key;type:bigint(20) auto_increment;not null;comment:'id';" json:"id"` // Primary Key + Id uint32 `gorm:"column:id;primary_key;type:bigint auto_increment;not null;comment:'id';" json:"id"` // Primary Key SplitterId string `gorm:"uniqueIndex:idx_unique_splitter_id;column:splitter_id;type:varchar(64)" json:"splitterId"` // Business Primary Key Name string `gorm:"column:name;index:idx_user_id_name,priority:2;type:varchar(64);not null;default:''" json:"name"` Value string `gorm:"column:value;type:varchar(64);not null;default:''" json:"value"` - CreatedAt int64 `gorm:"column:create_at;type:bigint(20);autoCreateTime:milli;not null;" json:"createAt"` // Create Time - UpdatedAt int64 `gorm:"column:update_at;type:bigint(20);autoUpdateTime:milli;not null;" json:"updateAt"` // Update Time + CreatedAt int64 `gorm:"column:create_at;type:bigint;autoCreateTime:milli;not null;" json:"createAt"` // Create Time + UpdatedAt int64 `gorm:"column:update_at;type:bigint;autoUpdateTime:milli;not null;" json:"updateAt"` // Update Time UserId string `gorm:"column:user_id;index:idx_user_id_name,priority:1;type:varchar(64);not null;default:'';" json:"userId"` OrgId string `gorm:"column:org_id;type:varchar(64);not null;default:'';" json:"orgId"` } diff --git a/internal/knowledge-service/client/model/knowledge_tag.go b/internal/knowledge-service/client/model/knowledge_tag.go index 97a5887a3..baa2bfd3f 100644 --- a/internal/knowledge-service/client/model/knowledge_tag.go +++ b/internal/knowledge-service/client/model/knowledge_tag.go @@ -1,11 +1,11 @@ package model type KnowledgeTag struct { - Id uint32 `gorm:"column:id;primary_key;type:bigint(20) auto_increment;not null;comment:'id';" json:"id"` // Primary Key - TagId string `gorm:"uniqueIndex:idx_unique_tag_id;column:tag_id;type:varchar(64)" json:"tagId"` // Business Primary Key + Id uint32 `gorm:"column:id;primary_key;type:bigint auto_increment;not null;comment:'id';" json:"id"` // Primary Key + TagId string `gorm:"uniqueIndex:idx_unique_tag_id;column:tag_id;type:varchar(64)" json:"tagId"` // Business Primary Key Name string `gorm:"column:name;index:idx_user_id_name,priority:2;type:varchar(64);not null;default:''" json:"name"` - CreatedAt int64 `gorm:"column:create_at;type:bigint(20);autoCreateTime:milli;not null;" json:"createAt"` // Create Time - UpdatedAt int64 `gorm:"column:update_at;type:bigint(20);autoUpdateTime:milli;not null;" json:"updateAt"` // Update Time + CreatedAt int64 `gorm:"column:create_at;type:bigint;autoCreateTime:milli;not null;" json:"createAt"` // Create Time + UpdatedAt int64 `gorm:"column:update_at;type:bigint;autoUpdateTime:milli;not null;" json:"updateAt"` // Update Time UserId string `gorm:"column:user_id;index:idx_user_id_name,priority:1;type:varchar(64);not null;default:'';" json:"userId"` OrgId string `gorm:"column:org_id;type:varchar(64);not null;default:'';" json:"orgId"` } diff --git a/internal/knowledge-service/client/model/knowledge_tag_relation.go b/internal/knowledge-service/client/model/knowledge_tag_relation.go index 04be23d0f..2a699a5ea 100644 --- a/internal/knowledge-service/client/model/knowledge_tag_relation.go +++ b/internal/knowledge-service/client/model/knowledge_tag_relation.go @@ -1,11 +1,11 @@ package model type KnowledgeTagRelation struct { - Id uint32 `gorm:"column:id;primary_key;type:bigint(20) auto_increment;not null;comment:'id';" json:"id"` // Primary Key + Id uint32 `gorm:"column:id;primary_key;type:bigint auto_increment;not null;comment:'id';" json:"id"` // Primary Key TagId string `gorm:"column:tag_id;index:idx_tag_id;type:varchar(64);not null;default:'';" json:"tagId"` // tagId KnowledgeId string `gorm:"column:knowledge_id;index:idx_knowledge_id;type:varchar(64);not null;default:'';" json:"knowledgeId"` // knowledgeId - CreatedAt int64 `gorm:"column:create_at;type:bigint(20);autoCreateTime:milli;not null;" json:"createAt"` // Create Time - UpdatedAt int64 `gorm:"column:update_at;type:bigint(20);autoUpdateTime:milli;not null;" json:"updateAt"` // Update Time + CreatedAt int64 `gorm:"column:create_at;type:bigint;autoCreateTime:milli;not null;" json:"createAt"` // Create Time + UpdatedAt int64 `gorm:"column:update_at;type:bigint;autoUpdateTime:milli;not null;" json:"updateAt"` // Update Time UserId string `gorm:"column:user_id;index:idx_user_id_tag_name,priority:1;type:varchar(64);not null;default:'';" json:"userId"` OrgId string `gorm:"column:org_id;type:varchar(64);not null;default:'';" json:"orgId"` } diff --git a/internal/knowledge-service/client/orm/knowledge.go b/internal/knowledge-service/client/orm/knowledge.go index b6a7068f4..d50d5cb1d 100644 --- a/internal/knowledge-service/client/orm/knowledge.go +++ b/internal/knowledge-service/client/orm/knowledge.go @@ -319,7 +319,7 @@ func CreateKnowledgeReport(ctx context.Context, knowledgeId string) error { return err } //构造知识库图谱 - knowledgeGraph := BuildKnowledgeGraph(knowledge.KnowledgeGraph) + knowledgeGraph := BuildKnowledgeGraph(string(knowledge.KnowledgeGraph)) //2.通知rag生成社区报告 return service.RagCreateKnowledgeReport(ctx, &service.RagImportDocParams{ KnowledgeName: knowledge.RagName, diff --git a/internal/knowledge-service/client/orm/knowledge_doc.go b/internal/knowledge-service/client/orm/knowledge_doc.go index e359a5e32..f24b69ae6 100644 --- a/internal/knowledge-service/client/orm/knowledge_doc.go +++ b/internal/knowledge-service/client/orm/knowledge_doc.go @@ -309,7 +309,7 @@ func CreateKnowledgeDoc(ctx context.Context, doc *model.KnowledgeDoc, importTask return nil } //构造知识库图谱 - knowledgeGraph := BuildKnowledgeGraph(knowledge.KnowledgeGraph) + knowledgeGraph := BuildKnowledgeGraph(string(knowledge.KnowledgeGraph)) //2.rag文档导入 return service.RagImportDoc(ctx, &service.RagImportDocParams{ DocId: doc.DocId, @@ -377,7 +377,7 @@ func ReImportKnowledgeDoc(ctx context.Context, doc *model.KnowledgeDoc, importTa return nil } //构造知识库图谱 - knowledgeGraph := BuildKnowledgeGraph(knowledge.KnowledgeGraph) + knowledgeGraph := BuildKnowledgeGraph(string(knowledge.KnowledgeGraph)) //2.rag文档导入 return service.RagImportDoc(ctx, &service.RagImportDocParams{ DocId: doc.DocId, @@ -681,7 +681,7 @@ func CreateKnowledgeUrlDoc(ctx context.Context, doc *model.KnowledgeDoc, importT //3.rag 文档开始导入操作 var fileName = service.RebuildFileName(doc.DocId, doc.FileType, doc.Name) //构造知识库图谱 - knowledgeGraph := BuildKnowledgeGraph(knowledge.KnowledgeGraph) + knowledgeGraph := BuildKnowledgeGraph(string(knowledge.KnowledgeGraph)) return service.RagImportDoc(ctx, &service.RagImportDocParams{ DocId: doc.DocId, KnowledgeName: knowledge.RagName, diff --git a/internal/knowledge-service/pkg/async-task/async_task_client.go b/internal/knowledge-service/pkg/async-task/async_task_client.go index aee27d7f4..590534ea7 100644 --- a/internal/knowledge-service/pkg/async-task/async_task_client.go +++ b/internal/knowledge-service/pkg/async-task/async_task_client.go @@ -3,10 +3,10 @@ package async_task import ( "context" + async "github.com/UnicomAI/wanwu/async" + "github.com/UnicomAI/wanwu/async/pkg/async/async_component/pending" "github.com/UnicomAI/wanwu/internal/knowledge-service/pkg" "github.com/UnicomAI/wanwu/internal/knowledge-service/pkg/db" - async "github.com/gromitlee/go-async" - "github.com/gromitlee/go-async/pkg/async/async_component/pending" ) var asyncTaskClient = AsyncTaskClient{} diff --git a/internal/knowledge-service/pkg/config/config.go b/internal/knowledge-service/pkg/config/config.go index 582c02b21..6e3ce9191 100644 --- a/internal/knowledge-service/pkg/config/config.go +++ b/internal/knowledge-service/pkg/config/config.go @@ -61,7 +61,9 @@ type Config struct { RpcLog LogConfig `mapstructure:"rpc-log" json:"rpc-log" yaml:"rpc-log"` DB db.Config `json:"db" mapstructure:"db"` Minio *MinioConfig `mapstructure:"minio" json:"minio"` + Topic *TopicConfig `mapstructure:"topic" json:"topic"` Kafka *KafkaConfig `mapstructure:"kafka" json:"kafka"` + Redis *RedisConfig `mapstructure:"redis" json:"redis"` // 新增Redis配置 UsageLimit *UsageLimitConfig `mapstructure:"usage-limit" json:"usageLimit"` RagServer *RagServerConfig `mapstructure:"rag-server" json:"ragServer"` KnowledgeDocConfig *KnowledgeDocConfig `json:"knowledge-doc-config" mapstructure:"knowledge-doc-config"` @@ -96,17 +98,40 @@ type MinioConfig struct { PublicExportBucket string `mapstructure:"public-export-bucket" json:"public-export-bucket"` } -type KafkaConfig struct { - Addr string `mapstructure:"addr" json:"addr"` - User string `mapstructure:"user" json:"user"` - Password string `mapstructure:"password" json:"password"` +type TopicConfig struct { UrlAnalysisTopic string `mapstructure:"url-analysis-topic" json:"url-analysis-topic"` UrlImportTopic string `mapstructure:"url-import-topic" json:"url-import-topic"` Topic string `mapstructure:"topic" json:"topic"` KnowledgeGraphTopic string `mapstructure:"knowledge-graph-topic" json:"knowledge-graph-topic"` +} + +type KafkaConfig struct { + Enabled bool `mapstructure:"enabled" json:"enabled"` + Addr string `mapstructure:"addr" json:"addr"` + User string `mapstructure:"user" json:"user"` + Password string `mapstructure:"password" json:"password"` DefaultPartitionNum int32 `mapstructure:"default-partition-num" json:"defaultPartitionNum"` } +type RedisConfig struct { + Enabled bool `mapstructure:"enabled" json:"enabled"` + Mode string `mapstructure:"mode" json:"mode"` // standalone, sentinel, cluster + Addr []string `mapstructure:"addr" json:"addr"` // 单节点: ["host:port"]; 哨兵/集群: ["host1:port1", "host2:port2", ...] + MasterName string `mapstructure:"master-name" json:"masterName"` // 哨兵模式专用 + Password string `mapstructure:"password" json:"password"` + DB int `mapstructure:"db" json:"db"` // 仅standalone模式有效 + PoolSize int `mapstructure:"pool-size" json:"poolSize"` + MinIdleConns int `mapstructure:"min-idle-conns" json:"minIdleConns"` + MaxRetries int `mapstructure:"max-retries" json:"maxRetries"` + DialTimeout int `mapstructure:"dial-timeout" json:"dialTimeout"` // 秒 + ReadTimeout int `mapstructure:"read-timeout" json:"readTimeout"` // 秒 + WriteTimeout int `mapstructure:"write-timeout" json:"writeTimeout"` // 秒 + IdleTimeout int `mapstructure:"idle-timeout" json:"idleTimeout"` // 秒 + // Stream 配置 + StreamMaxLen int64 `mapstructure:"stream-max-len" json:"streamMaxLen"` // Stream 最大长度 + StreamApproxMaxLen bool `mapstructure:"stream-approx-max-len" json:"streamApproxMaxLen"` // 是否使用近似最大长度 +} + type UsageLimitConfig struct { DocTotal int64 `mapstructure:"doc-total" json:"docTotal"` FileTypes string `mapstructure:"file-types" json:"fileTypes"` diff --git a/internal/knowledge-service/pkg/db/init.go b/internal/knowledge-service/pkg/db/init.go index 01bf15159..96cc60bb9 100644 --- a/internal/knowledge-service/pkg/db/init.go +++ b/internal/knowledge-service/pkg/db/init.go @@ -15,7 +15,6 @@ import ( ) const ( - knowledgeDBName = "knowledge_base_service" docMetaTimestampOld = "1757692799000" //2025-09-12 23:59:59 ) @@ -47,11 +46,6 @@ func (c DataBaseClient) Load() error { log.Errorf("init knowledge_base_service db err: %v", err) return err } - //创建数据库配置 - err = createDB(dbHandle) - if err != nil { - return err - } //注册表配置 err = registerTables(dbHandle) if err != nil { @@ -86,17 +80,6 @@ func (c DataBaseClient) Stop() error { return err } -// 创建db -func createDB(dbClient *gorm.DB) error { - err := dbClient.Exec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci;", knowledgeDBName)).Error - if err != nil { - log.Errorf("MySQL创建数据库%s异常: %v", knowledgeDBName, err) - return err - } - log.Infof("MySQL创建数据库成功: %s", knowledgeDBName) - return nil -} - // 注册表信息 func registerTables(dbClient *gorm.DB) error { err := dbClient.AutoMigrate( diff --git a/internal/knowledge-service/pkg/mq/kafka.go b/internal/knowledge-service/pkg/mq/kafka.go index bbcb0dbd9..f55e71674 100644 --- a/internal/knowledge-service/pkg/mq/kafka.go +++ b/internal/knowledge-service/pkg/mq/kafka.go @@ -28,6 +28,12 @@ func (c Kafka) LoadType() string { } func (c Kafka) Load() error { + if !config.GetConfig().Kafka.Enabled { + log.Infof("Kafka is not enabled, skip init") + return nil + } + + log.Infof("Kafka is enabled, start init") admin, err := initKafkaAdmin() if err != nil { return err @@ -68,7 +74,7 @@ func initKafkaAdmin() (sarama.ClusterAdmin, error) { func initKafka(kafkaAdmin sarama.ClusterAdmin) (sarama.SyncProducer, error) { log.Infof("开始初始化Kafka配置") defaultPartitionNum := config.GetConfig().Kafka.DefaultPartitionNum - var defaultTopic = config.GetConfig().Kafka.Topic + var defaultTopic = config.GetConfig().Topic.Topic kafkaConfig := sarama.NewConfig() kafkaConfig.ClientID = util.GenUUID() kafkaConfig.Version = sarama.MaxVersion @@ -137,7 +143,7 @@ func initKafka(kafkaAdmin sarama.ClusterAdmin) (sarama.SyncProducer, error) { return producer, nil } -func SendMessage(msg interface{}, topic string) error { +func sendMessageToKafka(msg interface{}, topic string) error { if msg == nil { return errors.New("message is nil") } diff --git a/internal/knowledge-service/pkg/mq/mq.go b/internal/knowledge-service/pkg/mq/mq.go new file mode 100644 index 000000000..a191bd0fb --- /dev/null +++ b/internal/knowledge-service/pkg/mq/mq.go @@ -0,0 +1,15 @@ +package mq + +import ( + "github.com/UnicomAI/wanwu/internal/knowledge-service/pkg/config" +) + +func SendMessage(msg interface{}, topic string) error { + cfg := config.GetConfig() + if cfg.Kafka.Enabled { + return sendMessageToKafka(msg, topic) + } else if cfg.Redis.Enabled { + return sendMessageToRedis(msg, topic) + } + return nil +} diff --git a/internal/knowledge-service/pkg/mq/redis.go b/internal/knowledge-service/pkg/mq/redis.go new file mode 100644 index 000000000..357a33a8a --- /dev/null +++ b/internal/knowledge-service/pkg/mq/redis.go @@ -0,0 +1,298 @@ +package mq + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "time" + + "github.com/UnicomAI/wanwu/internal/knowledge-service/pkg" + "github.com/UnicomAI/wanwu/internal/knowledge-service/pkg/config" + "github.com/UnicomAI/wanwu/pkg/log" + "github.com/redis/go-redis/v9" +) + +var redisClient = RedisClient{} + +type RedisClient struct { + Client redis.UniversalClient + trimTicker *time.Ticker + trimStopChan chan bool + trimRunning bool +} + +func init() { + pkg.AddContainer(redisClient) +} + +func (c RedisClient) LoadType() string { + return "redis" +} + +func (c RedisClient) Load() error { + if !config.GetConfig().Redis.Enabled { + log.Infof("Redis is not enabled, skip init") + return nil + } + + log.Infof("Redis is enabled, start init") + client, err := initRedisClient() + if err != nil { + return err + } + redisClient.Client = client + + // 测试连接 + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := client.Ping(ctx).Err(); err != nil { + log.Errorf("Redis连接测试失败: %v", err) + return err + } + + log.Infof("Redis客户端初始化成功") + startTrimTask() + + return nil +} + +func (c RedisClient) Stop() error { + if !config.GetConfig().Redis.Enabled { + return nil + } + + stopTrimTask() + + if redisClient.Client != nil { + return redisClient.Client.Close() + } + return nil +} + +func (c RedisClient) StopPriority() int { + return pkg.DefaultPriority +} + +func initRedisClient() (redis.UniversalClient, error) { + cfg := config.GetConfig().Redis + log.Infof("开始初始化Redis客户端,模式: %s", cfg.Mode) + + var client redis.UniversalClient + var err error + + switch cfg.Mode { + case "standalone": + client, err = initStandaloneClient(cfg) + case "sentinel": + client, err = initSentinelClient(cfg) + case "cluster": + client, err = initClusterClient(cfg) + default: + return nil, fmt.Errorf("不支持的Redis模式: %s,支持的模式: standalone, sentinel, cluster", cfg.Mode) + } + + if err != nil { + log.Errorf("Redis客户端初始化失败: %v", err) + return nil, err + } + + log.Infof("Redis客户端创建成功,模式: %s", cfg.Mode) + return client, nil +} + +func initStandaloneClient(cfg *config.RedisConfig) (*redis.Client, error) { + if len(cfg.Addr) == 0 { + return nil, errors.New("standalone模式需要至少一个地址") + } + + // 取第一个地址作为standalone地址 + addr := cfg.Addr[0] + + client := redis.NewClient(&redis.Options{ + Addr: addr, + Password: cfg.Password, + DB: cfg.DB, + DialTimeout: time.Duration(cfg.DialTimeout) * time.Second, + ReadTimeout: time.Duration(cfg.ReadTimeout) * time.Second, + WriteTimeout: time.Duration(cfg.WriteTimeout) * time.Second, + PoolSize: cfg.PoolSize, + MinIdleConns: cfg.MinIdleConns, + MaxRetries: cfg.MaxRetries, + ConnMaxIdleTime: time.Duration(cfg.IdleTimeout) * time.Second, + }) + + return client, nil +} + +func initSentinelClient(cfg *config.RedisConfig) (*redis.Client, error) { + if len(cfg.Addr) == 0 { + return nil, errors.New("sentinel模式需要至少一个哨兵地址") + } + if cfg.MasterName == "" { + return nil, errors.New("sentinel模式需要指定master-name") + } + + client := redis.NewFailoverClient(&redis.FailoverOptions{ + MasterName: cfg.MasterName, + SentinelAddrs: cfg.Addr, + Password: cfg.Password, + DB: cfg.DB, + DialTimeout: time.Duration(cfg.DialTimeout) * time.Second, + ReadTimeout: time.Duration(cfg.ReadTimeout) * time.Second, + WriteTimeout: time.Duration(cfg.WriteTimeout) * time.Second, + PoolSize: cfg.PoolSize, + MinIdleConns: cfg.MinIdleConns, + MaxRetries: cfg.MaxRetries, + ConnMaxIdleTime: time.Duration(cfg.IdleTimeout) * time.Second, + }) + + return client, nil +} + +func initClusterClient(cfg *config.RedisConfig) (*redis.ClusterClient, error) { + if len(cfg.Addr) == 0 { + return nil, errors.New("cluster模式需要至少一个节点地址") + } + + // 注意:cluster模式下DB参数无效,Redis集群只支持DB 0 + client := redis.NewClusterClient(&redis.ClusterOptions{ + Addrs: cfg.Addr, + Password: cfg.Password, + DialTimeout: time.Duration(cfg.DialTimeout) * time.Second, + ReadTimeout: time.Duration(cfg.ReadTimeout) * time.Second, + WriteTimeout: time.Duration(cfg.WriteTimeout) * time.Second, + PoolSize: cfg.PoolSize, + MinIdleConns: cfg.MinIdleConns, + MaxRetries: cfg.MaxRetries, + ConnMaxIdleTime: time.Duration(cfg.IdleTimeout) * time.Second, + }) + + return client, nil +} + +// sendMessageToRedis 发送消息到Redis Stream +func sendMessageToRedis(msg interface{}, streamKey string) error { + if msg == nil { + return errors.New("message is nil") + } + + message, err := json.Marshal(msg) + if err != nil { + return err + } + + ctx := context.Background() + + // 使用XAdd命令将消息添加到Stream + // 格式: XADD streamKey * field value + // * 表示让Redis自动生成消息ID + result, err := redisClient.Client.XAdd(ctx, &redis.XAddArgs{ + Stream: streamKey, + Values: map[string]interface{}{ + "data": message, + "timestamp": time.Now().Unix(), + }, + }).Result() + + if err != nil { + log.Errorf("Redis Stream发送消息失败, stream: %s, error: %v", streamKey, err) + return err + } + + log.Infof("Redis Stream发送成功, stream: %s, messageId: %s, data: %s", + streamKey, result, message) + return nil +} + +// startTrimTask 启动定时修剪任务 +func startTrimTask() { + cfg := config.GetConfig().Redis + + // 检查是否需要启用定时修剪 + if cfg.StreamMaxLen <= 0 { + log.Infof("StreamMaxLen配置为0或负数,不启用定时修剪任务") + return + } + + // 初始化通道和ticker + redisClient.trimStopChan = make(chan bool) + redisClient.trimRunning = true + + // 启动定时任务(每分钟执行一次) + redisClient.trimTicker = time.NewTicker(1 * time.Minute) + + go func() { + log.Infof("Redis Stream定时修剪任务已启动") + + for { + select { + case <-redisClient.trimTicker.C: + // 执行修剪任务 + performTrimTask() + case <-redisClient.trimStopChan: + log.Infof("Redis Stream定时修剪任务已停止") + return + } + } + }() +} + +// stopTrimTask 停止定时修剪任务 +func stopTrimTask() { + if redisClient.trimRunning { + redisClient.trimRunning = false + + // 停止ticker + if redisClient.trimTicker != nil { + redisClient.trimTicker.Stop() + } + + // 发送停止信号 + if redisClient.trimStopChan != nil { + redisClient.trimStopChan <- true + close(redisClient.trimStopChan) + } + } +} + +// performTrimTask 执行修剪任务 +func performTrimTask() { + cfg := config.GetConfig().Redis + + // 获取配置中的Stream最大长度 + maxLen := cfg.StreamMaxLen + if maxLen <= 0 { + return + } + + ctx := context.Background() + + topics := []string{ + config.GetConfig().Topic.UrlAnalysisTopic, + config.GetConfig().Topic.UrlImportTopic, + config.GetConfig().Topic.Topic, + config.GetConfig().Topic.KnowledgeGraphTopic, + } + // 根据配置决定是否使用近似修剪 + var err error + if cfg.StreamApproxMaxLen { + // 使用MAXLEN近似修剪,性能更好 + for _, topic := range topics { + _, err = redisClient.Client.XTrimMaxLenApprox(ctx, topic, maxLen, 0).Result() + if err != nil { + log.Warnf("Failed to trim Redis stream: %v", err) + } + } + } else { + // 使用精确修剪 + for _, topic := range topics { + _, err = redisClient.Client.XTrimMaxLen(ctx, topic, maxLen).Result() + if err != nil { + log.Warnf("Failed to trim Redis stream: %v", err) + } + } + } + +} diff --git a/internal/knowledge-service/server/grpc/knowledge/knowledge.go b/internal/knowledge-service/server/grpc/knowledge/knowledge.go index 0567669f9..826bd5c04 100644 --- a/internal/knowledge-service/server/grpc/knowledge/knowledge.go +++ b/internal/knowledge-service/server/grpc/knowledge/knowledge.go @@ -8,11 +8,13 @@ import ( "strconv" "time" + "github.com/UnicomAI/wanwu/internal/knowledge-service/pkg/db" + db2 "github.com/UnicomAI/wanwu/pkg/db" + errs "github.com/UnicomAI/wanwu/api/proto/err-code" knowledgebase_service "github.com/UnicomAI/wanwu/api/proto/knowledgebase-service" "github.com/UnicomAI/wanwu/internal/knowledge-service/client/model" "github.com/UnicomAI/wanwu/internal/knowledge-service/client/orm" - "github.com/UnicomAI/wanwu/internal/knowledge-service/pkg/db" "github.com/UnicomAI/wanwu/internal/knowledge-service/pkg/util" rag_service "github.com/UnicomAI/wanwu/internal/knowledge-service/service" "github.com/UnicomAI/wanwu/pkg/log" @@ -387,7 +389,7 @@ func buildExportRecordListResp(knowledge *model.KnowledgeBase, list []*model.Kno retList = append(retList, &knowledgebase_service.ExportRecordInfo{ ExportRecordId: item.ExportId, Status: int32(item.Status), - ErrorMsg: item.ErrorMsg, + ErrorMsg: string(item.ErrorMsg), FilePath: item.ExportFilePath, UserId: item.UserId, ExportTime: wanwu_util.Time2Str(item.CreatedAt), @@ -740,7 +742,7 @@ func checkRepeatedMetaKey(metaList []*model.KnowledgeDocMeta) []*model.Knowledge func buildKnowledgeInfo(knowledge *model.KnowledgeBase) *knowledgebase_service.KnowledgeInfo { embeddingModelInfo := &knowledgebase_service.EmbeddingModelInfo{} _ = json.Unmarshal([]byte(knowledge.EmbeddingModel), embeddingModelInfo) - graph := orm.BuildKnowledgeGraph(knowledge.KnowledgeGraph) + graph := orm.BuildKnowledgeGraph(string(knowledge.KnowledgeGraph)) docCount := knowledge.DocCount if docCount < 0 { docCount = 0 @@ -793,8 +795,8 @@ func buildKnowledgeBaseModel(req *knowledgebase_service.CreateKnowledgeReq) (*mo Description: req.Description, OrgId: req.OrgId, UserId: req.UserId, - EmbeddingModel: string(embeddingModelInfo), - KnowledgeGraph: string(knowledgeGraph), + EmbeddingModel: db2.LongText(embeddingModelInfo), + KnowledgeGraph: db2.LongText(knowledgeGraph), KnowledgeGraphSwitch: buildKnowledgeGraphSwitch(req.KnowledgeGraph.Switch), CreatedAt: time.Now().UnixMilli(), UpdatedAt: time.Now().UnixMilli(), diff --git a/internal/knowledge-service/server/grpc/knowledge_doc/knowledge_doc.go b/internal/knowledge-service/server/grpc/knowledge_doc/knowledge_doc.go index 37d34702f..40f498c39 100644 --- a/internal/knowledge-service/server/grpc/knowledge_doc/knowledge_doc.go +++ b/internal/knowledge-service/server/grpc/knowledge_doc/knowledge_doc.go @@ -16,6 +16,7 @@ import ( "github.com/UnicomAI/wanwu/internal/knowledge-service/pkg/util" "github.com/UnicomAI/wanwu/internal/knowledge-service/service" import_service "github.com/UnicomAI/wanwu/internal/knowledge-service/task/import-service" + db2 "github.com/UnicomAI/wanwu/pkg/db" "github.com/UnicomAI/wanwu/pkg/log" pkg_util "github.com/UnicomAI/wanwu/pkg/util" util2 "github.com/UnicomAI/wanwu/pkg/util" @@ -542,7 +543,7 @@ func (s *Service) GetDocCategoryUploadTip(ctx context.Context, req *knowledgebas return &knowledgebase_doc_service.DocImportTipResp{ KnowledgeId: req.KnowledgeId, KnowledgeName: knowledge.Name, - Message: "\n" + task.ErrorMsg, + Message: string("\n" + task.ErrorMsg), UploadStatus: DocImportError, }, nil } else if task.Status == model.KnowledgeImportFinish { @@ -697,7 +698,7 @@ func buildDocInfo(item *model.KnowledgeDoc, segmentConfigMap map[string]*model.S KnowledgeId: item.KnowledgeId, UploadTime: util2.Time2Str(item.CreatedAt), Status: int32(util.BuildDocRespStatus(item.Status)), - ErrorMsg: item.ErrorMsg, + ErrorMsg: string(item.ErrorMsg), SegmentMethod: buildSegmentMethod(item, segmentConfigMap), UserId: item.UserId, GraphStatus: int32(status), @@ -842,7 +843,7 @@ func buildImportTask(req *knowledgebase_doc_service.ImportDocReq) (*model.Knowle DocAnalyzer: string(analyzer), CreatedAt: time.Now().UnixMilli(), UpdatedAt: time.Now().UnixMilli(), - DocInfo: string(docImportInfo), + DocInfo: db2.LongText(docImportInfo), OcrModelId: req.OcrModelId, DocPreProcess: string(preprocess), MetaData: docImportMetaData, @@ -899,7 +900,7 @@ func buildReImportTask(req *knowledgebase_doc_service.UpdateDocImportConfigReq, DocAnalyzer: string(analyzer), CreatedAt: time.Now().UnixMilli(), UpdatedAt: time.Now().UnixMilli(), - DocInfo: string(docImportInfo), + DocInfo: db2.LongText(docImportInfo), OcrModelId: docImportReq.OcrModelId, DocPreProcess: string(preprocess), MetaData: "", @@ -933,7 +934,7 @@ func buildReimportTask(req *knowledgebase_doc_service.ReImportDocReq, task *mode DocAnalyzer: task.DocAnalyzer, CreatedAt: time.Now().UnixMilli(), UpdatedAt: time.Now().UnixMilli(), - DocInfo: string(docImportInfo), + DocInfo: db2.LongText(docImportInfo), OcrModelId: task.OcrModelId, DocPreProcess: task.DocPreProcess, MetaData: "", @@ -1223,7 +1224,7 @@ func buildKnowledgeInfo(ctx context.Context, docId string) (*model.KnowledgeBase } //构造知识库图谱 - knowledgeGraph := orm.BuildKnowledgeGraph(knowledge.KnowledgeGraph) + knowledgeGraph := orm.BuildKnowledgeGraph(string(knowledge.KnowledgeGraph)) return knowledge, doc, knowledgeGraph, nil } diff --git a/internal/knowledge-service/server/grpc/knowledge_qa/knowledge_qa.go b/internal/knowledge-service/server/grpc/knowledge_qa/knowledge_qa.go index b6414fa45..f501d586a 100644 --- a/internal/knowledge-service/server/grpc/knowledge_qa/knowledge_qa.go +++ b/internal/knowledge-service/server/grpc/knowledge_qa/knowledge_qa.go @@ -16,6 +16,7 @@ import ( "github.com/UnicomAI/wanwu/internal/knowledge-service/pkg/util" "github.com/UnicomAI/wanwu/internal/knowledge-service/server/grpc/knowledge" "github.com/UnicomAI/wanwu/internal/knowledge-service/service" + "github.com/UnicomAI/wanwu/pkg/db" "github.com/UnicomAI/wanwu/pkg/log" util2 "github.com/UnicomAI/wanwu/pkg/util" wanwu_util "github.com/UnicomAI/wanwu/pkg/util" @@ -63,7 +64,7 @@ func (s *Service) GetQAImportTip(ctx context.Context, req *knowledgebase_qa_serv return &knowledgebase_qa_service.QAImportTipResp{ KnowledgeId: req.KnowledgeId, KnowledgeName: knowledge.Name, - Message: "\n" + task.ErrorMsg, + Message: string("\n" + task.ErrorMsg), UploadStatus: model.KnowledgeQAPairImportFail, }, nil } else if task.Status == model.KnowledgeQAPairImportSuccess { @@ -188,10 +189,10 @@ func (s *Service) UpdateQAPair(ctx context.Context, req *knowledgebase_qa_servic question := strings.Trim(req.Question, " ") answer := strings.Trim(req.Answer, " ") questionMD5 := util.MD5(question) - if qaPair.Question == question && qaPair.Answer == answer { + if string(qaPair.Question) == question && string(qaPair.Answer) == answer { return nil, nil } - questionOmitempty, answerOmitempty := qaPair.Question == question, qaPair.Answer == answer + questionOmitempty, answerOmitempty := string(qaPair.Question) == question, string(qaPair.Answer) == answer // 4.更新问答对 qaPair, ragParams := buildUpdateQAPairParams(knowledgeBase, question, answer, questionMD5, req.QaPairId, questionOmitempty, answerOmitempty) err = orm.UpdateKnowledgeQAPair(ctx, qaPair, ragParams) @@ -501,7 +502,7 @@ func buildQAPairImportTask(req *knowledgebase_qa_service.ImportQAPairReq) (*mode KnowledgeId: req.KnowledgeId, CreatedAt: time.Now().UnixMilli(), UpdatedAt: time.Now().UnixMilli(), - DocInfo: string(docImportInfo), + DocInfo: db.LongText(docImportInfo), Status: model.KnowledgeQAPairImportInit, UserId: req.UserId, OrgId: req.OrgId, @@ -530,11 +531,11 @@ func buildQAPairListResp(list []*model.KnowledgeQAPair, knowledge *model.Knowled retList = append(retList, &knowledgebase_qa_service.QAPairInfo{ QaPairId: item.QAPairId, KnowledgeId: item.KnowledgeId, - Question: item.Question, - Answer: item.Answer, + Question: string(item.Question), + Answer: string(item.Answer), Status: int32(item.Status), Switch: item.Switch, - ErrorMsg: item.ErrorMsg, + ErrorMsg: string(item.ErrorMsg), UploadTime: util2.Time2Str(item.CreatedAt), UserId: item.UserId, MetaDataList: buildMetaList(metaMap, item.QAPairId), @@ -558,8 +559,8 @@ func buildCreateQAPairParams(knowledgeBase *model.KnowledgeBase, question, answe qaPairs := []*model.KnowledgeQAPair{&model.KnowledgeQAPair{ QAPairId: qaPairId, KnowledgeId: knowledgeBase.KnowledgeId, - Question: question, - Answer: answer, + Question: db.LongText(question), + Answer: db.LongText(answer), Status: model.KnowledgeQAPairImportSuccess, Switch: true, QuestionMd5: questionMD5, @@ -583,8 +584,8 @@ func buildCreateQAPairParams(knowledgeBase *model.KnowledgeBase, question, answe func buildUpdateQAPairParams(knowledgeBase *model.KnowledgeBase, question, answer, questionMD5, qaPairId string, qesOmi, ansOmi bool) (*model.KnowledgeQAPair, *service.RagUpdateQAPairParams) { qaPair := &model.KnowledgeQAPair{ QAPairId: qaPairId, - Question: question, - Answer: answer, + Question: db.LongText(question), + Answer: db.LongText(answer), QuestionMd5: questionMD5, KnowledgeId: knowledgeBase.KnowledgeId, } @@ -628,11 +629,11 @@ func buildQAPairInfo(item *model.KnowledgeQAPair) *knowledgebase_qa_service.QAPa return &knowledgebase_qa_service.QAPairInfo{ QaPairId: item.QAPairId, KnowledgeId: item.KnowledgeId, - Question: item.Question, - Answer: item.Answer, + Question: string(item.Question), + Answer: string(item.Answer), UploadTime: util2.Time2Str(item.CreatedAt), Status: int32(item.Status), - ErrorMsg: item.ErrorMsg, + ErrorMsg: string(item.ErrorMsg), Switch: item.Switch, UserId: item.UserId, } diff --git a/internal/knowledge-service/service/rag_doc_service.go b/internal/knowledge-service/service/rag_doc_service.go index 93f9d98e3..b2b986450 100644 --- a/internal/knowledge-service/service/rag_doc_service.go +++ b/internal/knowledge-service/service/rag_doc_service.go @@ -214,7 +214,7 @@ func RagImportDoc(ctx context.Context, ragImportDocParams *RagImportDocParams) e Operation: "add", Type: "doc", Doc: ragImportDocParams, - }, config.GetConfig().Kafka.Topic) + }, config.GetConfig().Topic.Topic) } // RagBuildKnowledgeGraph 构建知识库图谱 @@ -223,7 +223,7 @@ func RagBuildKnowledgeGraph(ctx context.Context, ragImportDocParams *RagImportDo Operation: "add", Type: "doc", Doc: ragImportDocParams, - }, config.GetConfig().Kafka.KnowledgeGraphTopic) + }, config.GetConfig().Topic.KnowledgeGraphTopic) } // RagImportUrlDoc 导入url文档 diff --git a/internal/knowledge-service/service/rag_knowledge_service.go b/internal/knowledge-service/service/rag_knowledge_service.go index e995d4a0f..5c79bc567 100644 --- a/internal/knowledge-service/service/rag_knowledge_service.go +++ b/internal/knowledge-service/service/rag_knowledge_service.go @@ -192,7 +192,7 @@ func RagCreateKnowledgeReport(ctx context.Context, ragImportDocParams *RagImport Operation: "add", Type: "doc", Doc: ragImportDocParams, - }, config.GetConfig().Kafka.KnowledgeGraphTopic) + }, config.GetConfig().Topic.KnowledgeGraphTopic) } // RagKnowledgeUpdate rag更新知识库 diff --git a/internal/knowledge-service/task/doc_delete_task_service.go b/internal/knowledge-service/task/doc_delete_task_service.go index 1084cda11..95a71ff19 100644 --- a/internal/knowledge-service/task/doc_delete_task_service.go +++ b/internal/knowledge-service/task/doc_delete_task_service.go @@ -6,6 +6,8 @@ import ( "errors" "sync" + async "github.com/UnicomAI/wanwu/async" + "github.com/UnicomAI/wanwu/async/pkg/async/async_task" "github.com/UnicomAI/wanwu/internal/knowledge-service/client/model" "github.com/UnicomAI/wanwu/internal/knowledge-service/client/orm" async_task_pkg "github.com/UnicomAI/wanwu/internal/knowledge-service/pkg/async-task" @@ -13,8 +15,6 @@ import ( "github.com/UnicomAI/wanwu/internal/knowledge-service/service" "github.com/UnicomAI/wanwu/pkg/log" "github.com/UnicomAI/wanwu/pkg/util" - async "github.com/gromitlee/go-async" - "github.com/gromitlee/go-async/pkg/async/async_task" "gorm.io/gorm" ) diff --git a/internal/knowledge-service/task/doc_import_task_service.go b/internal/knowledge-service/task/doc_import_task_service.go index 291bdd840..b1cd17123 100644 --- a/internal/knowledge-service/task/doc_import_task_service.go +++ b/internal/knowledge-service/task/doc_import_task_service.go @@ -6,14 +6,14 @@ import ( "errors" "sync" + async "github.com/UnicomAI/wanwu/async" + "github.com/UnicomAI/wanwu/async/pkg/async/async_task" "github.com/UnicomAI/wanwu/internal/knowledge-service/client/model" "github.com/UnicomAI/wanwu/internal/knowledge-service/client/orm" async_task_pkg "github.com/UnicomAI/wanwu/internal/knowledge-service/pkg/async-task" import_service "github.com/UnicomAI/wanwu/internal/knowledge-service/task/import-service" "github.com/UnicomAI/wanwu/pkg/log" "github.com/UnicomAI/wanwu/pkg/util" - async "github.com/gromitlee/go-async" - "github.com/gromitlee/go-async/pkg/async/async_task" ) var docImportTask = &DocImportTask{Del: true} diff --git a/internal/knowledge-service/task/doc_re_import_task_service.go b/internal/knowledge-service/task/doc_re_import_task_service.go index 143b947a2..a4c5ce485 100644 --- a/internal/knowledge-service/task/doc_re_import_task_service.go +++ b/internal/knowledge-service/task/doc_re_import_task_service.go @@ -7,14 +7,14 @@ import ( "fmt" "sync" + async "github.com/UnicomAI/wanwu/async" + "github.com/UnicomAI/wanwu/async/pkg/async/async_task" "github.com/UnicomAI/wanwu/internal/knowledge-service/client/model" "github.com/UnicomAI/wanwu/internal/knowledge-service/client/orm" async_task_pkg "github.com/UnicomAI/wanwu/internal/knowledge-service/pkg/async-task" "github.com/UnicomAI/wanwu/internal/knowledge-service/pkg/db" "github.com/UnicomAI/wanwu/pkg/log" "github.com/UnicomAI/wanwu/pkg/util" - async "github.com/gromitlee/go-async" - "github.com/gromitlee/go-async/pkg/async/async_task" "gorm.io/gorm" ) diff --git a/internal/knowledge-service/task/doc_segment_import_task_service.go b/internal/knowledge-service/task/doc_segment_import_task_service.go index 1bdf83ad1..4cc3e28ad 100644 --- a/internal/knowledge-service/task/doc_segment_import_task_service.go +++ b/internal/knowledge-service/task/doc_segment_import_task_service.go @@ -11,14 +11,14 @@ import ( "sync" "unicode/utf8" + async "github.com/UnicomAI/wanwu/async" + "github.com/UnicomAI/wanwu/async/pkg/async/async_task" "github.com/UnicomAI/wanwu/internal/knowledge-service/client/model" "github.com/UnicomAI/wanwu/internal/knowledge-service/client/orm" async_task_pkg "github.com/UnicomAI/wanwu/internal/knowledge-service/pkg/async-task" "github.com/UnicomAI/wanwu/internal/knowledge-service/service" "github.com/UnicomAI/wanwu/pkg/log" "github.com/UnicomAI/wanwu/pkg/util" - async "github.com/gromitlee/go-async" - "github.com/gromitlee/go-async/pkg/async/async_task" ) var docSegmentImportTask = &DocSegmentImportTask{Del: true} diff --git a/internal/knowledge-service/task/import-service/file_doc_import_service.go b/internal/knowledge-service/task/import-service/file_doc_import_service.go index 0811fb080..69d33f6ba 100644 --- a/internal/knowledge-service/task/import-service/file_doc_import_service.go +++ b/internal/knowledge-service/task/import-service/file_doc_import_service.go @@ -13,6 +13,7 @@ import ( "github.com/UnicomAI/wanwu/internal/knowledge-service/pkg/util" "github.com/UnicomAI/wanwu/internal/knowledge-service/service" file_extract "github.com/UnicomAI/wanwu/internal/knowledge-service/task/file-extract" + "github.com/UnicomAI/wanwu/pkg/db" "github.com/UnicomAI/wanwu/pkg/log" wanwu_util "github.com/UnicomAI/wanwu/pkg/util" ) @@ -201,7 +202,7 @@ func buildKnowledgeDoc(importTask *model.KnowledgeImportTask, checkFileResult *C UserId: importTask.UserId, OrgId: importTask.OrgId, Status: checkFileResult.Status, - ErrorMsg: checkFileResult.ErrMessage, + ErrorMsg: db.LongText(checkFileResult.ErrMessage), } } diff --git a/internal/knowledge-service/task/import-service/url_doc_import_service.go b/internal/knowledge-service/task/import-service/url_doc_import_service.go index c8398a861..f1239bfdc 100644 --- a/internal/knowledge-service/task/import-service/url_doc_import_service.go +++ b/internal/knowledge-service/task/import-service/url_doc_import_service.go @@ -8,6 +8,7 @@ import ( "github.com/UnicomAI/wanwu/internal/knowledge-service/client/model" "github.com/UnicomAI/wanwu/internal/knowledge-service/client/orm" "github.com/UnicomAI/wanwu/internal/knowledge-service/pkg/util" + "github.com/UnicomAI/wanwu/pkg/db" "github.com/UnicomAI/wanwu/pkg/log" wanwu_util "github.com/UnicomAI/wanwu/pkg/util" ) @@ -103,7 +104,7 @@ func buildKnowledgeUrlDoc(importTask *model.KnowledgeImportTask, docInfo *CheckF FilePathMd5: util.MD5(docInfo.DocInfo.DocUrl), Name: docInfo.DocInfo.DocName, Status: docInfo.Status, - ErrorMsg: docInfo.ErrMessage, + ErrorMsg: db.LongText(docInfo.ErrMessage), FileType: UrlFileType, FileSize: fileSize, CreatedAt: time.Now().UnixMilli(), diff --git a/internal/knowledge-service/task/knowledge_delete_task_service.go b/internal/knowledge-service/task/knowledge_delete_task_service.go index 8ef0d23bd..c547d7803 100644 --- a/internal/knowledge-service/task/knowledge_delete_task_service.go +++ b/internal/knowledge-service/task/knowledge_delete_task_service.go @@ -6,14 +6,14 @@ import ( "errors" "sync" + async "github.com/UnicomAI/wanwu/async" + "github.com/UnicomAI/wanwu/async/pkg/async/async_task" "github.com/UnicomAI/wanwu/internal/knowledge-service/client/orm" async_task_pkg "github.com/UnicomAI/wanwu/internal/knowledge-service/pkg/async-task" "github.com/UnicomAI/wanwu/internal/knowledge-service/pkg/db" "github.com/UnicomAI/wanwu/internal/knowledge-service/service" "github.com/UnicomAI/wanwu/pkg/log" "github.com/UnicomAI/wanwu/pkg/util" - async "github.com/gromitlee/go-async" - "github.com/gromitlee/go-async/pkg/async/async_task" "gorm.io/gorm" ) diff --git a/internal/knowledge-service/task/knowledge_doc_export_task_service.go b/internal/knowledge-service/task/knowledge_doc_export_task_service.go index 708180628..08bc80932 100644 --- a/internal/knowledge-service/task/knowledge_doc_export_task_service.go +++ b/internal/knowledge-service/task/knowledge_doc_export_task_service.go @@ -12,6 +12,8 @@ import ( "sync" "time" + async "github.com/UnicomAI/wanwu/async" + "github.com/UnicomAI/wanwu/async/pkg/async/async_task" "github.com/UnicomAI/wanwu/internal/knowledge-service/client/model" "github.com/UnicomAI/wanwu/internal/knowledge-service/client/orm" async_task_pkg "github.com/UnicomAI/wanwu/internal/knowledge-service/pkg/async-task" @@ -19,8 +21,6 @@ import ( "github.com/UnicomAI/wanwu/internal/knowledge-service/service" "github.com/UnicomAI/wanwu/pkg/log" "github.com/UnicomAI/wanwu/pkg/util" - async "github.com/gromitlee/go-async" - "github.com/gromitlee/go-async/pkg/async/async_task" ) const fileTypeUrl = "url" diff --git a/internal/knowledge-service/task/knowledge_qa_delete_task_service.go b/internal/knowledge-service/task/knowledge_qa_delete_task_service.go index 4c1f61cad..bd3e570a9 100644 --- a/internal/knowledge-service/task/knowledge_qa_delete_task_service.go +++ b/internal/knowledge-service/task/knowledge_qa_delete_task_service.go @@ -6,6 +6,8 @@ import ( "errors" "sync" + async "github.com/UnicomAI/wanwu/async" + "github.com/UnicomAI/wanwu/async/pkg/async/async_task" "github.com/UnicomAI/wanwu/internal/knowledge-service/client/model" "github.com/UnicomAI/wanwu/internal/knowledge-service/client/orm" async_task_pkg "github.com/UnicomAI/wanwu/internal/knowledge-service/pkg/async-task" @@ -14,8 +16,6 @@ import ( "github.com/UnicomAI/wanwu/internal/knowledge-service/service" "github.com/UnicomAI/wanwu/pkg/log" "github.com/UnicomAI/wanwu/pkg/util" - async "github.com/gromitlee/go-async" - "github.com/gromitlee/go-async/pkg/async/async_task" "gorm.io/gorm" ) diff --git a/internal/knowledge-service/task/knowledge_qa_export_task_service.go b/internal/knowledge-service/task/knowledge_qa_export_task_service.go index 34255e9d9..3b73837a7 100644 --- a/internal/knowledge-service/task/knowledge_qa_export_task_service.go +++ b/internal/knowledge-service/task/knowledge_qa_export_task_service.go @@ -11,6 +11,8 @@ import ( "sync" "time" + async "github.com/UnicomAI/wanwu/async" + "github.com/UnicomAI/wanwu/async/pkg/async/async_task" "github.com/UnicomAI/wanwu/internal/knowledge-service/client/model" "github.com/UnicomAI/wanwu/internal/knowledge-service/client/orm" async_task_pkg "github.com/UnicomAI/wanwu/internal/knowledge-service/pkg/async-task" @@ -19,8 +21,6 @@ import ( "github.com/UnicomAI/wanwu/internal/knowledge-service/service" "github.com/UnicomAI/wanwu/pkg/log" "github.com/UnicomAI/wanwu/pkg/util" - async "github.com/gromitlee/go-async" - "github.com/gromitlee/go-async/pkg/async/async_task" ) const ( @@ -226,7 +226,7 @@ func exportCsvFile(ctx context.Context, knowledgeId string) (int64, int64, strin lineCount = total var records [][]string for _, qaPair := range qaPairs { - records = append(records, []string{qaPair.Question, qaPair.Answer}) + records = append(records, []string{string(qaPair.Question), string(qaPair.Answer)}) } err = writer.WriteAll(records) if err != nil { diff --git a/internal/knowledge-service/task/knowledge_qa_import_task_service.go b/internal/knowledge-service/task/knowledge_qa_import_task_service.go index 99e2e3688..616318f86 100644 --- a/internal/knowledge-service/task/knowledge_qa_import_task_service.go +++ b/internal/knowledge-service/task/knowledge_qa_import_task_service.go @@ -10,14 +10,15 @@ import ( "strings" "sync" + async "github.com/UnicomAI/wanwu/async" + "github.com/UnicomAI/wanwu/async/pkg/async/async_task" "github.com/UnicomAI/wanwu/internal/knowledge-service/client/model" "github.com/UnicomAI/wanwu/internal/knowledge-service/client/orm" async_task_pkg "github.com/UnicomAI/wanwu/internal/knowledge-service/pkg/async-task" "github.com/UnicomAI/wanwu/internal/knowledge-service/service" + "github.com/UnicomAI/wanwu/pkg/db" "github.com/UnicomAI/wanwu/pkg/log" "github.com/UnicomAI/wanwu/pkg/util" - async "github.com/gromitlee/go-async" - "github.com/gromitlee/go-async/pkg/async/async_task" ) const ( @@ -307,8 +308,8 @@ func buildQAPairBatchProcessor(knowledgeBase *model.KnowledgeBase, importTask *m QAPairId: qaPairId, ImportTaskId: importTask.ImportId, KnowledgeId: knowledgeBase.KnowledgeId, - Question: question, - Answer: answer, + Question: db.LongText(question), + Answer: db.LongText(answer), Status: model.KnowledgeQAPairImportSuccess, Switch: true, QuestionMd5: questionMD5, diff --git a/internal/knowledge-service/task/knowledge_report_task_service.go b/internal/knowledge-service/task/knowledge_report_task_service.go index 558432513..13a709bb3 100644 --- a/internal/knowledge-service/task/knowledge_report_task_service.go +++ b/internal/knowledge-service/task/knowledge_report_task_service.go @@ -7,14 +7,14 @@ import ( "fmt" "sync" + async "github.com/UnicomAI/wanwu/async" + "github.com/UnicomAI/wanwu/async/pkg/async/async_task" "github.com/UnicomAI/wanwu/internal/knowledge-service/client/model" "github.com/UnicomAI/wanwu/internal/knowledge-service/client/orm" async_task_pkg "github.com/UnicomAI/wanwu/internal/knowledge-service/pkg/async-task" "github.com/UnicomAI/wanwu/internal/knowledge-service/service" "github.com/UnicomAI/wanwu/pkg/log" "github.com/UnicomAI/wanwu/pkg/util" - async "github.com/gromitlee/go-async" - "github.com/gromitlee/go-async/pkg/async/async_task" ) var knowledgeReportTask = &KnowledgeReportTask{Del: true} diff --git a/internal/knowledge-service/task/qa_pair_delete_task_service.go b/internal/knowledge-service/task/qa_pair_delete_task_service.go index e7130b2b5..8bdd25b09 100644 --- a/internal/knowledge-service/task/qa_pair_delete_task_service.go +++ b/internal/knowledge-service/task/qa_pair_delete_task_service.go @@ -6,6 +6,8 @@ import ( "errors" "sync" + async "github.com/UnicomAI/wanwu/async" + "github.com/UnicomAI/wanwu/async/pkg/async/async_task" "github.com/UnicomAI/wanwu/internal/knowledge-service/client/model" "github.com/UnicomAI/wanwu/internal/knowledge-service/client/orm" async_task_pkg "github.com/UnicomAI/wanwu/internal/knowledge-service/pkg/async-task" @@ -13,8 +15,6 @@ import ( "github.com/UnicomAI/wanwu/internal/knowledge-service/service" "github.com/UnicomAI/wanwu/pkg/log" "github.com/UnicomAI/wanwu/pkg/util" - async "github.com/gromitlee/go-async" - "github.com/gromitlee/go-async/pkg/async/async_task" "gorm.io/gorm" ) diff --git a/internal/knowledge-service/task/report.go b/internal/knowledge-service/task/report.go index 427e0efa6..32fad10be 100644 --- a/internal/knowledge-service/task/report.go +++ b/internal/knowledge-service/task/report.go @@ -4,7 +4,7 @@ import ( "context" "sync" - "github.com/gromitlee/go-async/pkg/async/async_task" + "github.com/UnicomAI/wanwu/async/pkg/async/async_task" ) // report impl IReport diff --git a/internal/mcp-service/client/model/builtin_tool.go b/internal/mcp-service/client/model/builtin_tool.go index 8dc707fc0..d2fa86502 100644 --- a/internal/mcp-service/client/model/builtin_tool.go +++ b/internal/mcp-service/client/model/builtin_tool.go @@ -1,12 +1,14 @@ package model +import "github.com/UnicomAI/wanwu/pkg/db" + // BuiltinTool 自定义工具 type BuiltinTool struct { - ID uint32 `gorm:"column:id;primary_key;type:bigint(20) auto_increment;not null;comment:'id'"` - ToolSquareId string `gorm:"column:tool_square_id;index:idx_custom_tool_square_id;not null;comment:'自定义工具id'"` - AuthJSON string `gorm:"column:auth_json;type:longtext;comment:'鉴权json'"` - UserID string `gorm:"column:user_id;index:idx_user_id_name,priority:1;type:varchar(64);not null;comment:'用户id'"` - OrgID string `gorm:"column:org_id;type:varchar(64);not null;comment:'组织id'"` - CreatedAt int64 `gorm:"autoCreateTime:milli;comment:创建时间"` - UpdatedAt int64 `gorm:"autoUpdateTime:milli;comment:更新时间"` + ID uint32 `gorm:"column:id;primary_key;type:bigint auto_increment;not null;comment:'id'"` + ToolSquareId string `gorm:"column:tool_square_id;index:idx_custom_tool_square_id;not null;comment:'自定义工具id'"` + AuthJSON db.LongText `gorm:"column:auth_json;comment:'鉴权json'"` + UserID string `gorm:"column:user_id;index:idx_user_id_name,priority:1;type:varchar(64);not null;comment:'用户id'"` + OrgID string `gorm:"column:org_id;type:varchar(64);not null;comment:'组织id'"` + CreatedAt int64 `gorm:"autoCreateTime:milli;comment:创建时间"` + UpdatedAt int64 `gorm:"autoUpdateTime:milli;comment:更新时间"` } diff --git a/internal/mcp-service/client/model/custom_tool.go b/internal/mcp-service/client/model/custom_tool.go index 39607dc5f..1bc4337a5 100644 --- a/internal/mcp-service/client/model/custom_tool.go +++ b/internal/mcp-service/client/model/custom_tool.go @@ -1,25 +1,27 @@ package model +import "github.com/UnicomAI/wanwu/pkg/db" + const ( ApiAuthNone = "none" ) // CustomTool 自定义工具 type CustomTool struct { - ID uint32 `gorm:"primary_key"` - ToolSquareId string `gorm:"index:idx_custom_tool_square_id"` - Name string `gorm:"column:name;type:varchar(255);comment:'自定义工具名称'"` - AvatarPath string `gorm:"column:avatar_path;comment:'自定义工具头像'"` - Description string `gorm:"column:description;type:longtext;comment:'自定义工具描述'"` - Schema string `gorm:"column:schema;type:longtext;comment:'schema配置'"` - PrivacyPolicy string `gorm:"column:privacy_policy;type:longtext;comment:'隐私政策'"` - Type string `gorm:"column:type;type:varchar(255);comment:'apiAuth认证类型(none/apiKey)'"` // DEPRECATED - APIKey string `gorm:"column:api_key;type:varchar(255);comment:'api_key,0.2.6作为内置工具专属'"` - AuthType string `gorm:"column:auth_type;type:varchar(255);comment:'authType(basic/bearer/custom)'"` // DEPRECATED - CustomHeaderName string `gorm:"column:custom_header_name;type:varchar(255);comment:'自定义header名称'"` // DEPRECATED - AuthJSON string `gorm:"column:auth_json;type:longtext;comment:'鉴权json'"` - UserID string `gorm:"column:user_id;index:idx_custom_tool_user_id;comment:'用户id'"` - OrgID string `gorm:"column:org_id;index:idx_custom_tool_org_id;comment:'组织id'"` - CreatedAt int64 `gorm:"autoCreateTime:milli;comment:创建时间"` - UpdatedAt int64 `gorm:"autoUpdateTime:milli;comment:更新时间"` + ID uint32 `gorm:"primary_key"` + ToolSquareId string `gorm:"index:idx_custom_tool_square_id"` + Name string `gorm:"column:name;type:varchar(255);comment:'自定义工具名称'"` + AvatarPath string `gorm:"column:avatar_path;comment:'自定义工具头像'"` + Description db.LongText `gorm:"column:description;comment:'自定义工具描述'"` + Schema db.LongText `gorm:"column:schema;comment:'schema配置'"` + PrivacyPolicy db.LongText `gorm:"column:privacy_policy;comment:'隐私政策'"` + Type string `gorm:"column:type;type:varchar(255);comment:'apiAuth认证类型(none/apiKey)'"` // DEPRECATED + APIKey string `gorm:"column:api_key;type:varchar(255);comment:'api_key,0.2.6作为内置工具专属'"` + AuthType string `gorm:"column:auth_type;type:varchar(255);comment:'authType(basic/bearer/custom)'"` // DEPRECATED + CustomHeaderName string `gorm:"column:custom_header_name;type:varchar(255);comment:'自定义header名称'"` // DEPRECATED + AuthJSON db.LongText `gorm:"column:auth_json;comment:'鉴权json'"` + UserID string `gorm:"column:user_id;index:idx_custom_tool_user_id;comment:'用户id'"` + OrgID string `gorm:"column:org_id;index:idx_custom_tool_org_id;comment:'组织id'"` + CreatedAt int64 `gorm:"autoCreateTime:milli;comment:创建时间"` + UpdatedAt int64 `gorm:"autoUpdateTime:milli;comment:更新时间"` } diff --git a/internal/mcp-service/client/model/mcp_server.go b/internal/mcp-service/client/model/mcp_server.go index 839ba048a..0513b7b25 100644 --- a/internal/mcp-service/client/model/mcp_server.go +++ b/internal/mcp-service/client/model/mcp_server.go @@ -1,13 +1,15 @@ package model +import "github.com/UnicomAI/wanwu/pkg/db" + type MCPServer struct { - ID uint32 `gorm:"column:id;primary_key;type:bigint(20) auto_increment;not null;comment:'id'"` - MCPServerID string `gorm:"uniqueIndex:idx_unique_mcp_server_id;column:mcp_server_id;type:varchar(255);not null;comment:'mcp server id'"` - Name string `gorm:"column:name;index:idx_user_id_name,priority:2;type:varchar(255);comment:'mcp server名称'"` - Description string `gorm:"column:description;type:longtext;comment:'mcp server描述'"` - AvatarPath string `gorm:"column:avatar_path;comment:mcp server头像"` - UserID string `gorm:"column:user_id;index:idx_user_id_name,priority:1;type:varchar(64);not null;comment:'用户id'"` - OrgID string `gorm:"column:org_id;type:varchar(64);not null;comment:'组织id'"` - CreatedAt int64 `gorm:"autoCreateTime:milli;comment:创建时间"` - UpdatedAt int64 `gorm:"autoUpdateTime:milli;comment:更新时间"` + ID uint32 `gorm:"column:id;primary_key;type:bigint auto_increment;not null;comment:'id'"` + MCPServerID string `gorm:"uniqueIndex:idx_unique_mcp_server_id;column:mcp_server_id;type:varchar(255);not null;comment:'mcp server id'"` + Name string `gorm:"column:name;index:idx_user_id_name,priority:2;type:varchar(255);comment:'mcp server名称'"` + Description db.LongText `gorm:"column:description;comment:'mcp server描述'"` + AvatarPath string `gorm:"column:avatar_path;comment:mcp server头像"` + UserID string `gorm:"column:user_id;index:idx_user_id_name,priority:1;type:varchar(64);not null;comment:'用户id'"` + OrgID string `gorm:"column:org_id;type:varchar(64);not null;comment:'组织id'"` + CreatedAt int64 `gorm:"autoCreateTime:milli;comment:创建时间"` + UpdatedAt int64 `gorm:"autoUpdateTime:milli;comment:更新时间"` } diff --git a/internal/mcp-service/client/model/mcp_server_tool.go b/internal/mcp-service/client/model/mcp_server_tool.go index 9f65fd527..625ccd30a 100644 --- a/internal/mcp-service/client/model/mcp_server_tool.go +++ b/internal/mcp-service/client/model/mcp_server_tool.go @@ -1,21 +1,23 @@ package model +import "github.com/UnicomAI/wanwu/pkg/db" + type MCPServerTool struct { - ID uint32 `gorm:"column:id;primary_key;type:bigint(20) auto_increment;not null;comment:'id'"` - MCPServerToolId string `gorm:"uniqueIndex:idx_unique_mcp_server_tool_id;column:mcp_server_tool_id;type:varchar(255);not null;comment:'mcp server tool id'"` - McpServerId string `gorm:"uniqueIndex:idx_unique_mcp_server_id_name,priority:1;column:mcp_server_id;index:idx_mcp_server_id;type:varchar(255);comment:'mcp server id'"` - AppToolId string `gorm:"column:app_tool_id;type:varchar(255);comment:'应用或工具id'"` - Type string `gorm:"column:type;type:varchar(255);comment:'mcp server tool类型'"` - AppToolName string `gorm:"column:app_tool_name;type:varchar(255);comment:'应用或工具名称'"` - Name string `gorm:"uniqueIndex:idx_unique_mcp_server_id_name,priority:2;column:name;type:varchar(255);comment:'mcp server tool名称'"` - Description string `gorm:"column:description;type:longtext;comment:'mcp server tool描述'"` - Schema string `gorm:"column:schema;type:longtext;comment:'openapi schema'"` - AuthType string `gorm:"column:auth_type;type:varchar(255);comment:'鉴权类型'"` - AuthIn string `gorm:"column:auth_in;type:varchar(255);comment:'鉴权位置'"` - AuthName string `gorm:"column:auth_name;type:varchar(255);comment:'鉴权名称'"` - AuthValue string `gorm:"column:auth_value;type:varchar(255);comment:'鉴权值'"` - UserID string `gorm:"column:user_id;type:varchar(64);not null;comment:'用户id'"` - OrgID string `gorm:"column:org_id;type:varchar(64);not null;comment:'组织id'"` - CreatedAt int64 `gorm:"autoCreateTime:milli;comment:创建时间"` - UpdatedAt int64 `gorm:"autoUpdateTime:milli;comment:更新时间"` + ID uint32 `gorm:"column:id;primary_key;type:bigint auto_increment;not null;comment:'id'"` + MCPServerToolId string `gorm:"uniqueIndex:idx_unique_mcp_server_tool_id;column:mcp_server_tool_id;type:varchar(255);not null;comment:'mcp server tool id'"` + McpServerId string `gorm:"uniqueIndex:idx_unique_mcp_server_id_name,priority:1;column:mcp_server_id;index:idx_mcp_server_id;type:varchar(255);comment:'mcp server id'"` + AppToolId string `gorm:"column:app_tool_id;type:varchar(255);comment:'应用或工具id'"` + Type string `gorm:"column:type;type:varchar(255);comment:'mcp server tool类型'"` + AppToolName string `gorm:"column:app_tool_name;type:varchar(255);comment:'应用或工具名称'"` + Name string `gorm:"uniqueIndex:idx_unique_mcp_server_id_name,priority:2;column:name;type:varchar(255);comment:'mcp server tool名称'"` + Description db.LongText `gorm:"column:description;comment:'mcp server tool描述'"` + Schema db.LongText `gorm:"column:schema;comment:'openapi schema'"` + AuthType string `gorm:"column:auth_type;type:varchar(255);comment:'鉴权类型'"` + AuthIn string `gorm:"column:auth_in;type:varchar(255);comment:'鉴权位置'"` + AuthName string `gorm:"column:auth_name;type:varchar(255);comment:'鉴权名称'"` + AuthValue string `gorm:"column:auth_value;type:varchar(255);comment:'鉴权值'"` + UserID string `gorm:"column:user_id;type:varchar(64);not null;comment:'用户id'"` + OrgID string `gorm:"column:org_id;type:varchar(64);not null;comment:'组织id'"` + CreatedAt int64 `gorm:"autoCreateTime:milli;comment:创建时间"` + UpdatedAt int64 `gorm:"autoUpdateTime:milli;comment:更新时间"` } diff --git a/internal/mcp-service/server/grpc/mcp/mcp.go b/internal/mcp-service/server/grpc/mcp/mcp.go index 4707498f5..d7ac580a4 100644 --- a/internal/mcp-service/server/grpc/mcp/mcp.go +++ b/internal/mcp-service/server/grpc/mcp/mcp.go @@ -135,7 +135,7 @@ func (s *Service) GetMCPByMCPIdList(ctx context.Context, req *mcp_service.GetMCP serverToolInfos = append(serverToolInfos, &mcp_service.MCPServerInfo{ McpServerId: info.MCPServerID, Name: info.Name, - Desc: info.Description, + Desc: string(info.Description), AvatarPath: info.AvatarPath, ToolNum: toolNum, }) diff --git a/internal/mcp-service/server/grpc/mcp/mcp_server.go b/internal/mcp-service/server/grpc/mcp/mcp_server.go index a35260f2f..24b88766c 100644 --- a/internal/mcp-service/server/grpc/mcp/mcp_server.go +++ b/internal/mcp-service/server/grpc/mcp/mcp_server.go @@ -12,6 +12,7 @@ import ( "github.com/UnicomAI/wanwu/internal/mcp-service/client/model" "github.com/UnicomAI/wanwu/internal/mcp-service/config" "github.com/UnicomAI/wanwu/pkg/constant" + "github.com/UnicomAI/wanwu/pkg/db" "github.com/UnicomAI/wanwu/pkg/util" "google.golang.org/protobuf/types/known/emptypb" ) @@ -26,7 +27,7 @@ func (s *Service) CreateMCPServer(ctx context.Context, req *mcp_service.CreateMC mcpServer = &model.MCPServer{ MCPServerID: mcpServerId, Name: req.Name, - Description: req.Desc, + Description: db.LongText(req.Desc), AvatarPath: req.AvatarPath, UserID: req.Identity.UserId, OrgID: req.Identity.OrgId, @@ -42,7 +43,7 @@ func (s *Service) UpdateMCPServer(ctx context.Context, req *mcp_service.UpdateMC mcpServer := &model.MCPServer{ MCPServerID: req.McpServerId, Name: req.Name, - Description: req.Desc, + Description: db.LongText(req.Desc), AvatarPath: req.AvatarPath, } err := s.cli.UpdateMCPServer(ctx, mcpServer) @@ -61,7 +62,7 @@ func (s *Service) GetMCPServer(ctx context.Context, req *mcp_service.GetMCPServe return &mcp_service.MCPServerInfo{ Name: info.Name, McpServerId: info.MCPServerID, - Desc: info.Description, + Desc: string(info.Description), AvatarPath: info.AvatarPath, SseUrl: sseUrl, SseExample: sseExample, @@ -93,7 +94,7 @@ func (s *Service) GetMCPServerList(ctx context.Context, req *mcp_service.GetMCPS list = append(list, &mcp_service.MCPServerInfo{ McpServerId: info.MCPServerID, Name: info.Name, - Desc: info.Description, + Desc: string(info.Description), AvatarPath: info.AvatarPath, ToolNum: toolNum, SseUrl: sseUrl, @@ -116,11 +117,11 @@ func (s *Service) GetMCPServerTool(ctx context.Context, req *mcp_service.GetMCPS McpServerToolId: info.MCPServerToolId, McpServerId: info.McpServerId, Name: info.Name, - Desc: info.Description, + Desc: string(info.Description), Type: info.Type, AppToolId: info.AppToolId, AppToolName: info.AppToolName, - Schema: info.Schema, + Schema: string(info.Schema), ApiAuth: &common.ApiAuth{ AuthType: info.AuthType, AuthIn: info.AuthIn, @@ -137,11 +138,11 @@ func (s *Service) CreateMCPServerTool(ctx context.Context, req *mcp_service.Crea MCPServerToolId: util.GenUUID(), McpServerId: info.McpServerId, Name: info.Name, - Description: info.Desc, + Description: db.LongText(info.Desc), Type: info.Type, AppToolId: info.AppToolId, AppToolName: info.AppToolName, - Schema: info.Schema, + Schema: db.LongText(info.Schema), AuthType: info.ApiAuth.AuthType, AuthIn: info.ApiAuth.AuthIn, AuthName: info.ApiAuth.AuthName, @@ -159,8 +160,8 @@ func (s *Service) UpdateMCPServerTool(ctx context.Context, req *mcp_service.Upda mcpServerTool := &model.MCPServerTool{ MCPServerToolId: req.McpServerToolId, Name: req.Name, - Description: req.Desc, - Schema: req.Schema, + Description: db.LongText(req.Desc), + Schema: db.LongText(req.Schema), } err := s.cli.UpdateMCPServerTool(ctx, mcpServerTool) if err != nil { @@ -188,11 +189,11 @@ func (s *Service) GetMCPServerToolList(ctx context.Context, req *mcp_service.Get McpServerToolId: info.MCPServerToolId, McpServerId: info.McpServerId, Name: info.Name, - Desc: info.Description, + Desc: string(info.Description), Type: info.Type, AppToolId: info.AppToolId, AppToolName: info.AppToolName, - Schema: info.Schema, + Schema: string(info.Schema), ApiAuth: &common.ApiAuth{ AuthType: info.AuthType, AuthIn: info.AuthIn, diff --git a/internal/mcp-service/server/grpc/mcp/tool.go b/internal/mcp-service/server/grpc/mcp/tool.go index d22d47233..8561ea034 100644 --- a/internal/mcp-service/server/grpc/mcp/tool.go +++ b/internal/mcp-service/server/grpc/mcp/tool.go @@ -40,7 +40,7 @@ func (s *Service) GetToolByIdList(ctx context.Context, req *mcp_service.GetToolB list = append(list, &mcp_service.GetCustomToolItem{ CustomToolId: util.Int2Str(info.ID), Name: info.Name, - Description: info.Description, + Description: string(info.Description), AvatarPath: info.AvatarPath, }) } @@ -75,7 +75,7 @@ func (s *Service) GetToolSelect(ctx context.Context, req *mcp_service.GetToolSel list = append(list, &mcp_service.GetToolItem{ ToolId: util.Int2Str(info.ID), ToolName: info.Name, - Desc: info.Description, + Desc: string(info.Description), ApiKey: apiAuth.ApiKeyValue, ToolType: constant.ToolTypeCustom, NeedApiKeyInput: needApiKeyInput, diff --git a/internal/mcp-service/server/grpc/mcp/tool_builtin.go b/internal/mcp-service/server/grpc/mcp/tool_builtin.go index 3003cf568..44577967f 100644 --- a/internal/mcp-service/server/grpc/mcp/tool_builtin.go +++ b/internal/mcp-service/server/grpc/mcp/tool_builtin.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/ThinkInAIXYZ/go-mcp/protocol" + "github.com/UnicomAI/wanwu/pkg/db" "github.com/UnicomAI/wanwu/api/proto/common" errs "github.com/UnicomAI/wanwu/api/proto/err-code" @@ -72,7 +73,7 @@ func (s *Service) UpsertBuiltinToolAPIKey(ctx context.Context, req *mcp_service. }) if info != nil { // update - info.AuthJSON = string(apiAuthBytes) + info.AuthJSON = db.LongText(apiAuthBytes) if err := s.cli.UpdateBuiltinTool(ctx, info); err != nil { return nil, errStatus(errs.Code_MCPUpdateBuiltinToolErr, err) } @@ -81,7 +82,7 @@ func (s *Service) UpsertBuiltinToolAPIKey(ctx context.Context, req *mcp_service. // create if err := s.cli.CreateBuiltinTool(ctx, &model.BuiltinTool{ ToolSquareId: req.ToolSquareId, - AuthJSON: string(apiAuthBytes), + AuthJSON: db.LongText(apiAuthBytes), UserID: req.Identity.UserId, OrgID: req.Identity.OrgId, }); err != nil { diff --git a/internal/mcp-service/server/grpc/mcp/tool_custom.go b/internal/mcp-service/server/grpc/mcp/tool_custom.go index 827bfed9c..177fe2850 100644 --- a/internal/mcp-service/server/grpc/mcp/tool_custom.go +++ b/internal/mcp-service/server/grpc/mcp/tool_custom.go @@ -8,6 +8,7 @@ import ( errs "github.com/UnicomAI/wanwu/api/proto/err-code" mcp_service "github.com/UnicomAI/wanwu/api/proto/mcp-service" "github.com/UnicomAI/wanwu/internal/mcp-service/client/model" + "github.com/UnicomAI/wanwu/pkg/db" grpc_util "github.com/UnicomAI/wanwu/pkg/grpc-util" "github.com/UnicomAI/wanwu/pkg/util" "google.golang.org/protobuf/types/known/emptypb" @@ -27,10 +28,10 @@ func (s *Service) CreateCustomTool(ctx context.Context, req *mcp_service.CreateC if err := s.cli.CreateCustomTool(ctx, &model.CustomTool{ AvatarPath: req.AvatarPath, Name: req.Name, - Description: req.Description, - Schema: req.Schema, - PrivacyPolicy: req.PrivacyPolicy, - AuthJSON: string(apiAuthBytes), + Description: db.LongText(req.Description), + Schema: db.LongText(req.Schema), + PrivacyPolicy: db.LongText(req.PrivacyPolicy), + AuthJSON: db.LongText(apiAuthBytes), UserID: req.Identity.UserId, OrgID: req.Identity.OrgId, }); err != nil { @@ -58,9 +59,9 @@ func (s *Service) GetCustomToolInfo(ctx context.Context, req *mcp_service.GetCus CustomToolId: util.Int2Str(info.ID), AvatarPath: info.AvatarPath, Name: info.Name, - Description: info.Description, - Schema: info.Schema, - PrivacyPolicy: info.PrivacyPolicy, + Description: string(info.Description), + Schema: string(info.Schema), + PrivacyPolicy: string(info.PrivacyPolicy), ApiAuth: apiAuth, }, nil } @@ -78,7 +79,7 @@ func (s *Service) GetCustomToolList(ctx context.Context, req *mcp_service.GetCus list = append(list, &mcp_service.GetCustomToolItem{ CustomToolId: util.Int2Str(info.ID), Name: info.Name, - Description: info.Description, + Description: string(info.Description), AvatarPath: info.AvatarPath, }) } @@ -108,7 +109,7 @@ func (s *Service) GetCustomToolByCustomToolIdList(ctx context.Context, req *mcp_ list = append(list, &mcp_service.GetCustomToolItem{ CustomToolId: util.Int2Str(info.ID), Name: info.Name, - Description: info.Description, + Description: string(info.Description), }) } return &mcp_service.GetCustomToolListResp{ @@ -131,10 +132,10 @@ func (s *Service) UpdateCustomTool(ctx context.Context, req *mcp_service.UpdateC ID: util.MustU32(req.CustomToolId), AvatarPath: req.AvatarPath, Name: req.Name, - Description: req.Description, - Schema: req.Schema, - PrivacyPolicy: req.PrivacyPolicy, - AuthJSON: string(apiAuthBytes), + Description: db.LongText(req.Description), + Schema: db.LongText(req.Schema), + PrivacyPolicy: db.LongText(req.PrivacyPolicy), + AuthJSON: db.LongText(apiAuthBytes), }); err != nil { return nil, errStatus(errs.Code_MCPUpdateCustomToolErr, err) } diff --git a/internal/model-service/client/model/model_imported.go b/internal/model-service/client/model/model_imported.go index a934d52b1..c8f1e8c17 100644 --- a/internal/model-service/client/model/model_imported.go +++ b/internal/model-service/client/model/model_imported.go @@ -1,16 +1,18 @@ package model +import "github.com/UnicomAI/wanwu/pkg/db" + type ModelImported struct { - ID uint32 `gorm:"primary_key;auto_increment;not null;"` - UUID string `gorm:"column:uuid;type:varchar(255);uniqueIndex:idx_unique_uuid;comment:模型uuid"` - Provider string `gorm:"column:provider;index:idx_model_imported_provider_type_model,priority:1;type:varchar(100);comment:模型供应商"` - ModelType string `gorm:"column:model_type;index:idx_model_imported_provider_type_model,priority:2;type:varchar(100);comment:模型类型"` - Model string `gorm:"column:model;index:idx_model_imported_provider_type_model,priority:3;type:varchar(100);comment:模型名称"` - DisplayName string `gorm:"column:display_name;idx:idx_model_imported_model_display_name;type:varchar(100);comment:模型显示名称"` - ModelIconPath string `gorm:"column:model_icon_path;type:varchar(512);comment:模型图标路径"` - IsActive bool `gorm:"column:is_active;type:tinyint(1);default:true;comment:模型是否启用"` - ProviderConfig string `gorm:"column:provider_config;type:longtext;comment:某供应商下的模型配置"` - ModelDesc string `gorm:"column:model_desc;type:longtext;comment:模型描述"` - PublishDate string `gorm:"column:publish_date;type:varchar(100);comment:模型发布时间"` + ID uint32 `gorm:"primary_key;auto_increment;not null;"` + UUID string `gorm:"column:uuid;type:varchar(255);uniqueIndex:idx_unique_uuid;comment:模型uuid"` + Provider string `gorm:"column:provider;index:idx_model_imported_provider_type_model,priority:1;type:varchar(100);comment:模型供应商"` + ModelType string `gorm:"column:model_type;index:idx_model_imported_provider_type_model,priority:2;type:varchar(100);comment:模型类型"` + Model string `gorm:"column:model;index:idx_model_imported_provider_type_model,priority:3;type:varchar(100);comment:模型名称"` + DisplayName string `gorm:"column:display_name;idx:idx_model_imported_model_display_name;type:varchar(100);comment:模型显示名称"` + ModelIconPath string `gorm:"column:model_icon_path;type:varchar(512);comment:模型图标路径"` + IsActive bool `gorm:"column:is_active;default:true;comment:模型是否启用"` + ProviderConfig db.LongText `gorm:"column:provider_config;comment:某供应商下的模型配置"` + ModelDesc db.LongText `gorm:"column:model_desc;comment:模型描述"` + PublishDate string `gorm:"column:publish_date;type:varchar(100);comment:模型发布时间"` PublicModel } diff --git a/internal/model-service/server/grpc/model/model.go b/internal/model-service/server/grpc/model/model.go index 9e9d7fd7e..260d50b2d 100644 --- a/internal/model-service/server/grpc/model/model.go +++ b/internal/model-service/server/grpc/model/model.go @@ -6,6 +6,7 @@ import ( errs "github.com/UnicomAI/wanwu/api/proto/err-code" model_service "github.com/UnicomAI/wanwu/api/proto/model-service" "github.com/UnicomAI/wanwu/internal/model-service/client/model" + "github.com/UnicomAI/wanwu/pkg/db" "github.com/UnicomAI/wanwu/pkg/util" "google.golang.org/protobuf/types/known/emptypb" ) @@ -20,12 +21,12 @@ func (s *Service) ImportModel(ctx context.Context, req *model_service.ModelInfo) ModelIconPath: req.ModelIconPath, IsActive: req.IsActive, PublishDate: req.PublishDate, - ProviderConfig: req.ProviderConfig, + ProviderConfig: db.LongText(req.ProviderConfig), PublicModel: model.PublicModel{ OrgID: req.OrgId, UserID: req.UserId, }, - ModelDesc: req.ModelDesc, + ModelDesc: db.LongText(req.ModelDesc), }); err != nil { return nil, errStatus(errs.Code_ModelImportedModel, err) } @@ -41,8 +42,8 @@ func (s *Service) UpdateModel(ctx context.Context, req *model_service.ModelInfo) DisplayName: req.DisplayName, ModelIconPath: req.ModelIconPath, PublishDate: req.PublishDate, - ProviderConfig: req.ProviderConfig, - ModelDesc: req.ModelDesc, + ProviderConfig: db.LongText(req.ProviderConfig), + ModelDesc: db.LongText(req.ModelDesc), PublicModel: model.PublicModel{ OrgID: req.OrgId, UserID: req.UserId, @@ -156,12 +157,12 @@ func toModelInfo(modelInfo *model.ModelImported) *model_service.ModelInfo { ModelIconPath: modelInfo.ModelIconPath, IsActive: modelInfo.IsActive, PublishDate: modelInfo.PublishDate, - ProviderConfig: modelInfo.ProviderConfig, + ProviderConfig: string(modelInfo.ProviderConfig), UserId: modelInfo.UserID, OrgId: modelInfo.OrgID, CreatedAt: modelInfo.CreatedAt, UpdatedAt: modelInfo.UpdatedAt, - ModelDesc: modelInfo.ModelDesc, + ModelDesc: string(modelInfo.ModelDesc), } } diff --git a/internal/operate-service/client/model/system_custom.go b/internal/operate-service/client/model/system_custom.go index 300880ec5..764773328 100644 --- a/internal/operate-service/client/model/system_custom.go +++ b/internal/operate-service/client/model/system_custom.go @@ -1,5 +1,7 @@ package model +import "github.com/UnicomAI/wanwu/pkg/db" + type SystemCustom struct { ID uint32 `gorm:"primary_key"` CreatedAt int64 `gorm:"autoCreateTime:milli"` @@ -11,5 +13,5 @@ type SystemCustom struct { // Key Key string `gorm:"index:idx_system_custom_key"` // Value - Value string `gorm:"type:longtext"` + Value db.LongText `gorm:"column:value"` } diff --git a/internal/operate-service/client/orm/system_custom.go b/internal/operate-service/client/orm/system_custom.go index 9abd3352b..1674f90ea 100644 --- a/internal/operate-service/client/orm/system_custom.go +++ b/internal/operate-service/client/orm/system_custom.go @@ -8,8 +8,10 @@ import ( err_code "github.com/UnicomAI/wanwu/api/proto/err-code" "github.com/UnicomAI/wanwu/internal/operate-service/client/model" "github.com/UnicomAI/wanwu/internal/operate-service/client/orm/sqlopt" + "github.com/UnicomAI/wanwu/pkg/db" "github.com/UnicomAI/wanwu/pkg/log" "gorm.io/gorm" + "gorm.io/gorm/clause" ) func (c *Client) CreateSystemCustom(ctx context.Context, orgID, userID string, key SystemCustomKey, mode SystemCustomMode, custom SystemCustom) *err_code.Status { @@ -22,7 +24,7 @@ func (c *Client) CreateSystemCustom(ctx context.Context, orgID, userID string, k OrgID: orgID, UserID: userID, Key: key2WithModeKey(key, mode), - Value: mergeCustomFields(key, model.SystemCustom{}, custom), + Value: db.LongText(mergeCustomFields(key, model.SystemCustom{}, custom)), }).Error; err != nil { return toErrStatus("ope_system_custom_create", string(key), err.Error()) } @@ -46,8 +48,12 @@ func (c *Client) GetSystemCustom(ctx context.Context, mode SystemCustomMode) (*S } var records []model.SystemCustom + expr := clause.Expr{ + SQL: "? IN (?)", + Vars: []interface{}{clause.Column{Name: "key"}, keys}, + } if err := c.db.WithContext(ctx). - Where("`key` IN (?)", keys). + Where(expr). Find(&records).Error; err != nil { return nil, toErrStatus("ope_system_custom_get", err.Error()) } diff --git a/internal/rag-service/client/model/rag.go b/internal/rag-service/client/model/rag.go index 7090229a4..8e4a517fc 100644 --- a/internal/rag-service/client/model/rag.go +++ b/internal/rag-service/client/model/rag.go @@ -1,5 +1,7 @@ package model +import "github.com/UnicomAI/wanwu/pkg/db" + const ( MatchTypeDefault = "mix" KnowledgePriorityDefault = 1 @@ -11,7 +13,7 @@ const ( ) type RagInfo struct { - ID int64 `json:"id" gorm:"primaryKey;type:bigint(20);autoIncrement"` + ID int64 `json:"id" gorm:"primaryKey;type:bigint auto_increment;not null;"` RagID string `json:"ragId" gorm:"uniqueIndex:idx_unique_rag_id;column:rag_id;type:varchar(255);comment:ragId"` // 使用嵌入结构体(将字段直接映射到主表) @@ -20,7 +22,7 @@ type RagInfo struct { RerankConfig AppModelConfig `gorm:"embedded;embeddedPrefix:rerank_"` QARerankConfig AppModelConfig `gorm:"embedded;embeddedPrefix:qa_rerank_"` KnowledgeBaseConfig KnowledgeBaseConfig `gorm:"embedded;embeddedPrefix:kb_"` - QAKnowledgebaseConfig string `gorm:"column:qa_knowledgebase_config;type:longtext;comment:问答库配置"` + QAKnowledgebaseConfig db.LongText `gorm:"column:qa_knowledgebase_config;comment:问答库配置"` SensitiveConfig SensitiveConfig `gorm:"embedded;embeddedPrefix:sensitive_"` PublicModel } @@ -41,27 +43,27 @@ type AppModelConfig struct { type KnowledgeBaseConfig struct { KnowId string `json:"knowId" gorm:"column:know_id;type:text;comment:知识库ID"` - MaxHistory int64 `json:"maxHistory" gorm:"column:max_history;type:bigint(20);comment:最大历史记录"` - Threshold float64 `json:"threshold" gorm:"column:threshold;type:float(10,2);comment:阈值"` - TopK int64 `json:"topK" gorm:"column:top_k;type:bigint(20);comment:TopK"` + MaxHistory int64 `json:"maxHistory" gorm:"column:max_history;type:bigint;comment:最大历史记录"` + Threshold float64 `json:"threshold" gorm:"column:threshold;type:decimal(10,2);comment:阈值"` + TopK int64 `json:"topK" gorm:"column:top_k;type:bigint;comment:TopK"` MatchType string `json:"matchType" gorm:"column:match_type;type:varchar(32);not null;default:'';comment:matchType:vector(向量检索)、text(文本检索)、mix(混合检索:向量+文本)"` - PriorityMatch int32 `json:"priorityMatch" gorm:"column:priority_match;type:tinyint(1);not null;default:0;comment:权重匹配,只有在混合检索模式下,选择权重设置后,这个才设置为1"` - SemanticsPriority float64 `json:"semanticsPriority" gorm:"column:semantics_priority;type:float(10,2);not null;default:0;comment:语义权重"` - KeywordPriority float64 `json:"keywordPriority" gorm:"column:keyword_priority;type:float(10,2);not null;default:0;comment:关键词权重"` + PriorityMatch int32 `json:"priorityMatch" gorm:"column:priority_match;not null;default:0;comment:权重匹配,只有在混合检索模式下,选择权重设置后,这个才设置为1"` + SemanticsPriority float64 `json:"semanticsPriority" gorm:"column:semantics_priority;type:decimal(10,2);not null;default:0;comment:语义权重"` + KeywordPriority float64 `json:"keywordPriority" gorm:"column:keyword_priority;type:decimal(10,2);not null;default:0;comment:关键词权重"` TermWeight float64 `json:"term_weight" gorm:"column:term_weight;type:decimal(10,2);not null;default:1;comment:关键词系数,默认为1"` - TermWeightEnable bool `json:"term_weight_enable" gorm:"column:term_weight_enable;type:tinyint(1);not null;default:false;comment:是否启用关键词系数"` + TermWeightEnable bool `json:"term_weight_enable" gorm:"column:term_weight_enable;not null;default:false;comment:是否启用关键词系数"` MetaParams string `json:"metaParams" gorm:"column:meta_params;type:text;comment:元数据参数"` - UseGraph bool `json:"use_graph" gorm:"column:use_graph;type:tinyint(1);not null;default:false;comment:是否使用知识图谱"` + UseGraph bool `json:"use_graph" gorm:"column:use_graph;not null;default:false;comment:是否使用知识图谱"` } type SensitiveConfig struct { - Enable bool `json:"enable" gorm:"column:enable;type:tinyint(1);comment:是否启用安全护栏"` + Enable bool `json:"enable" gorm:"column:enable;comment:是否启用安全护栏"` TableIds string `json:"tableIds" gorm:"column:table_ids;type:text;comment:敏感词表ID列表"` } type PublicModel struct { - CreatedAt int64 `json:"createdAt" gorm:"autoCreateTime:milli;index:created_at;column:created_at;type:bigint(20);comment:创建时间"` - UpdatedAt int64 `json:"updatedAt" gorm:"autoUpdateTime:milli;index:updated_at;column:updated_at;type:bigint(20);comment:更新时间"` + CreatedAt int64 `json:"createdAt" gorm:"autoCreateTime:milli;index:created_at;column:created_at;type:bigint;comment:创建时间"` + UpdatedAt int64 `json:"updatedAt" gorm:"autoUpdateTime:milli;index:updated_at;column:updated_at;type:bigint;comment:更新时间"` OrgID string `gorm:"index:org_id;column:org_id;type:varchar(255);comment:组织ID" json:"orgId"` UserID string `gorm:"index:user_id;column:user_id;type:varchar(255);comment:用户ID" json:"userId"` } diff --git a/internal/rag-service/server/grpc/rag/service.go b/internal/rag-service/server/grpc/rag/service.go index f460c796f..9467aba23 100644 --- a/internal/rag-service/server/grpc/rag/service.go +++ b/internal/rag-service/server/grpc/rag/service.go @@ -13,6 +13,7 @@ import ( "github.com/UnicomAI/wanwu/internal/rag-service/client/model" "github.com/UnicomAI/wanwu/internal/rag-service/client/orm" message_builder "github.com/UnicomAI/wanwu/internal/rag-service/service/message-builder" + "github.com/UnicomAI/wanwu/pkg/db" grpc_util "github.com/UnicomAI/wanwu/pkg/grpc-util" http_client "github.com/UnicomAI/wanwu/pkg/http-client" "github.com/UnicomAI/wanwu/pkg/log" @@ -278,7 +279,7 @@ func (s *Service) UpdateRagConfig(ctx context.Context, in *rag_service.UpdateRag MetaParams: metaParams, UseGraph: kbGlobalConfig.UseGraph, }, - QAKnowledgebaseConfig: qaKnowledgeConfig, + QAKnowledgebaseConfig: db.LongText(qaKnowledgeConfig), SensitiveConfig: model.SensitiveConfig{ Enable: in.SensitiveConfig.Enable, TableIds: sensitiveIds, diff --git a/pkg/db/types.go b/pkg/db/types.go new file mode 100644 index 000000000..4c36256cd --- /dev/null +++ b/pkg/db/types.go @@ -0,0 +1,74 @@ +package db + +import ( + "database/sql/driver" + "encoding/json" + "fmt" + + "gorm.io/gorm" + "gorm.io/gorm/schema" +) + +// JSON 映射类型 +type JSONMap map[string]interface{} + +func (j JSONMap) Value() (driver.Value, error) { + return json.Marshal(j) +} + +func (j *JSONMap) Scan(value interface{}) error { + if value == nil { + return nil + } + bytes, ok := value.([]byte) + if !ok { + return fmt.Errorf("failed to unmarshal JSON value") + } + return json.Unmarshal(bytes, j) +} + +func (JSONMap) GormDBDataType(db *gorm.DB, field *schema.Field) string { + switch db.Dialector.Name() { + case "mysql": + return "JSON" + case "postgres": + return "JSONB" + default: + return "TEXT" + } +} + +// LongText 长文本类型 +type LongText string + +func (l LongText) Value() (driver.Value, error) { + return string(l), nil +} + +func (l *LongText) Scan(value interface{}) error { + if value == nil { + *l = "" + return nil + } + + switch v := value.(type) { + case string: + *l = LongText(v) + case []byte: + *l = LongText(v) + default: + return fmt.Errorf("failed to scan LongText value: %v", value) + } + return nil +} + +func (LongText) GormDBDataType(db *gorm.DB, field *schema.Field) string { + switch db.Dialector.Name() { + case "mysql": + return "LONGTEXT" + case "postgres": + return "TEXT" + default: + return "TEXT" + } +} diff --git a/pkg/es/api_assistant.go b/pkg/es/api_assistant.go deleted file mode 100644 index 0c6e3a698..000000000 --- a/pkg/es/api_assistant.go +++ /dev/null @@ -1,216 +0,0 @@ -package es - -import ( - "context" - "fmt" - - "github.com/UnicomAI/wanwu/pkg/log" -) - -var ( - _esAssistant *client -) - -func InitAssistant(ctx context.Context, cfg Config) error { - if _esAssistant != nil { - return fmt.Errorf("ES assistant客户端已经初始化") - } - c, err := newClient(ctx, cfg) - if err != nil { - return err - } - _esAssistant = c - return nil -} - -func StopAssistant() { - if _esAssistant != nil { - _esAssistant.Stop() - _esAssistant = nil - } -} - -func Assistant() *client { - return _esAssistant -} - -func InitESIndexTemplate(ctx context.Context) error { - templateName := "conversation_detail_infos_template" - - // 检查模板是否已存在 - exists, err := Assistant().IndexTemplateExists(ctx, templateName) - if err != nil { - return fmt.Errorf("检查ES索引模板失败: %v", err) - } - - if exists { - log.Infof("ES索引模板已存在: %s", templateName) - return nil - } - - // 创建索引模板 - template := `{ - "index_patterns": [ - "conversation_detail_infos_*" - ], - "template": { - "mappings": { - "properties": { - "id": { - "type": "keyword", - "index": true - }, - "assistantId": { - "type": "keyword", - "index": true - }, - "conversationId": { - "type": "keyword", - "index": true - }, - "prompt": { - "type": "text", - "fields": { - "keyword": { - "type": "keyword" - } - } - }, - "sysPrompt": { - "type": "text", - "fields": { - "keyword": { - "type": "keyword" - } - } - }, - "algPrompt": { - "type": "text", - "fields": { - "keyword": { - "type": "keyword" - } - } - }, - "requestFileIds": { - "type": "keyword", - "index": false - }, - "requestFileUrls": { - "type": "keyword", - "index": false - }, - "response": { - "type": "text", - "fields": { - "keyword": { - "type": "keyword" - } - } - }, - "algResponse": { - "type": "text", - "fields": { - "keyword": { - "type": "keyword" - } - } - }, - "responseFileIds": { - "type": "keyword", - "index": false - }, - "responseFileUrls": { - "type": "keyword", - "index": false - }, - "searchList": { - "type": "keyword", - "index": false - }, - "fileInfo": { - "type": "object", - "properties": { - "fileName": { - "type": "text", - "fields": { - "keyword": { - "type": "keyword" - } - } - }, - "fileSize": { - "type": "long" - }, - "fileUrl": { - "type": "text", - "index": false - } - } - }, - "createdBy": { - "type": "keyword", - "index": true - }, - "ts": { - "type": "date" - }, - "timestamp": { - "type": "long" - }, - "qaType": { - "type": "integer" - }, - "createdAt": { - "type": "date" - }, - "modelId": { - "type": "keyword", - "fields": { - "keyword": { - "type": "keyword" - } - } - }, - "modelVersion": { - "type": "keyword", - "fields": { - "keyword": { - "type": "keyword" - } - } - }, - "finish": { - "type": "integer", - "index": true - }, - "fileFormat": { - "type": "text", - "index": false - }, - "fileSize": { - "type": "long", - "index": false - }, - "fileName": { - "type": "text", - "index": false - }, - "videoStatus": { - "type": "integer" - }, - "responseId": { - "type": "keyword", - "index": true - } - } - } - } - }` - if err := Assistant().CreateIndexTemplate(ctx, templateName, template); err != nil { - return fmt.Errorf("创建ES索引模板失败: %v", err) - } - - log.Infof("成功创建ES索引模板: %s", templateName) - return nil -} diff --git a/pkg/es/client.go b/pkg/es/client.go deleted file mode 100644 index d1a8b658c..000000000 --- a/pkg/es/client.go +++ /dev/null @@ -1,278 +0,0 @@ -package es - -import ( - "context" - "crypto/tls" - "encoding/json" - "fmt" - "net/http" - "strings" - "sync" - - "github.com/UnicomAI/wanwu/pkg/log" - "github.com/elastic/go-elasticsearch/v8" -) - -type Config struct { - Address string `json:"address" mapstructure:"address"` - Username string `json:"username" mapstructure:"username"` - Password string `json:"password" mapstructure:"password"` -} - -type client struct { - ctx context.Context - cli *elasticsearch.Client - - mutex sync.Mutex - stopped bool - stop chan struct{} -} - -func newClient(ctx context.Context, c Config) (*client, error) { - // 智能判断协议,如果地址没有协议前缀,则尝试HTTPS,失败后尝试HTTP - addresses := []string{} - - // 如果地址已经包含协议,直接使用 - if strings.HasPrefix(c.Address, "http://") || strings.HasPrefix(c.Address, "https://") { - addresses = append(addresses, c.Address) - } else { - // 优先尝试HTTPS,然后HTTP - addresses = append(addresses, "https://"+c.Address, "http://"+c.Address) - } - - var lastErr error - - // 尝试每个地址 - for _, addr := range addresses { - cfg := elasticsearch.Config{ - Addresses: []string{addr}, - Username: c.Username, - Password: c.Password, - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - }, - } - - esClient, err := elasticsearch.NewClient(cfg) - if err != nil { - lastErr = fmt.Errorf("创建ES客户端失败 [%s]: %v", addr, err) - log.Warnf("创建ES客户端失败,地址: %s, 错误: %v", addr, err) - continue - } - - // 测试连接 - res, err := esClient.Info() - if err != nil { - lastErr = fmt.Errorf("ES连接测试失败 [%s]: %v", addr, err) - log.Warnf("ES连接测试失败,地址: %s, 错误: %v", addr, err) - continue - } - - if res != nil { - defer res.Body.Close() - - if res.IsError() { - lastErr = fmt.Errorf("ES连接响应错误 [%s]: %s", addr, res.String()) - log.Warnf("ES连接响应错误,地址: %s, 响应: %s", addr, res.String()) - continue - } - } - - log.Infof("ES连接成功,地址: %s", addr) - return &client{ - ctx: ctx, - cli: esClient, - stop: make(chan struct{}, 1), - }, nil - } - - // 所有地址都失败了 - if lastErr != nil { - return nil, lastErr - } - - return nil, fmt.Errorf("无法连接到ES,尝试的地址: %v", addresses) -} - -func (c *client) Stop() { - c.mutex.Lock() - if c.stopped { - log.Errorf("ES客户端已经停止") - c.mutex.Unlock() - return - } - c.stopped = true - close(c.stop) - c.mutex.Unlock() - log.Infof("ES客户端停止") -} - -func (c *client) Cli() *elasticsearch.Client { - return c.cli -} - -// 写入数据到指定索引 -func (c *client) IndexDocument(ctx context.Context, index string, document interface{}) error { - docJSON, err := json.Marshal(document) - if err != nil { - return fmt.Errorf("序列化文档失败: %v", err) - } - - res, err := c.cli.Index( - index, - strings.NewReader(string(docJSON)), - c.cli.Index.WithContext(ctx), - c.cli.Index.WithRefresh("true"), - ) - if err != nil { - return fmt.Errorf("写入ES失败: %v", err) - } - defer res.Body.Close() - - if res.IsError() { - return fmt.Errorf("ES写入响应错误: %s", res.String()) - } - - log.Infof("成功写入ES,索引: %s", index) - return nil -} - -// 根据指定字段条件查询所有数据 -func (c *client) SearchByFields(ctx context.Context, index string, fieldConditions map[string]interface{}, from, size int) ([]json.RawMessage, int64, error) { - query := map[string]interface{}{ - "query": map[string]interface{}{ - "bool": map[string]interface{}{ - "must": buildMustQuery(fieldConditions), - }, - }, - "from": from, - "size": size, - "sort": []map[string]interface{}{ - { - "createdAt": map[string]interface{}{ - "order": "desc", - }, - }, - }, - } - - queryJSON, err := json.Marshal(query) - if err != nil { - return nil, 0, fmt.Errorf("序列化查询失败: %v", err) - } - - res, err := c.cli.Search( - c.cli.Search.WithContext(ctx), - c.cli.Search.WithIndex(index), - c.cli.Search.WithBody(strings.NewReader(string(queryJSON))), - ) - if err != nil { - return nil, 0, fmt.Errorf("ES查询失败: %v", err) - } - defer res.Body.Close() - - if res.IsError() { - return nil, 0, fmt.Errorf("ES查询响应错误: %s", res.String()) - } - - var result map[string]interface{} - if err := json.NewDecoder(res.Body).Decode(&result); err != nil { - return nil, 0, fmt.Errorf("解析查询结果失败: %v", err) - } - - hits, ok := result["hits"].(map[string]interface{}) - if !ok { - return nil, 0, fmt.Errorf("无效的查询结果格式") - } - - total, ok := hits["total"].(map[string]interface{}) - if !ok { - return nil, 0, fmt.Errorf("无效的总数格式") - } - - totalValue, ok := total["value"].(float64) - if !ok { - return nil, 0, fmt.Errorf("无效的总数值") - } - - hitsList, ok := hits["hits"].([]interface{}) - if !ok { - return nil, 0, fmt.Errorf("无效的命中列表格式") - } - - var documents []json.RawMessage - for _, hit := range hitsList { - hitMap, ok := hit.(map[string]interface{}) - if !ok { - continue - } - source, ok := hitMap["_source"] - if !ok { - continue - } - sourceJSON, err := json.Marshal(source) - if err != nil { - continue - } - documents = append(documents, sourceJSON) - } - - log.Infof("ES查询成功,索引: %s, 总数: %d, 返回: %d", index, int64(totalValue), len(documents)) - return documents, int64(totalValue), nil -} - -// 创建索引模板 -func (c *client) CreateIndexTemplate(ctx context.Context, templateName string, templateBody string) error { - res, err := c.cli.Indices.PutIndexTemplate( - templateName, - strings.NewReader(templateBody), - c.cli.Indices.PutIndexTemplate.WithContext(ctx), - ) - if err != nil { - return fmt.Errorf("创建索引模板失败: %v", err) - } - defer res.Body.Close() - - if res.IsError() { - return fmt.Errorf("创建索引模板响应错误: %s", res.String()) - } - - log.Infof("成功创建索引模板: %s", templateName) - return nil -} - -// 检查索引模板是否存在 -func (c *client) IndexTemplateExists(ctx context.Context, templateName string) (bool, error) { - res, err := c.cli.Indices.GetIndexTemplate( - c.cli.Indices.GetIndexTemplate.WithName(templateName), - c.cli.Indices.GetIndexTemplate.WithContext(ctx), - ) - if err != nil { - return false, fmt.Errorf("检查索引模板失败: %v", err) - } - defer res.Body.Close() - - if res.StatusCode == 404 { - return false, nil - } - - if res.IsError() { - return false, fmt.Errorf("检查索引模板响应错误: %s", res.String()) - } - - return true, nil -} - -func buildMustQuery(conditions map[string]interface{}) []map[string]interface{} { - var mustQuery []map[string]interface{} - for field, value := range conditions { - mustQuery = append(mustQuery, map[string]interface{}{ - "term": map[string]interface{}{ - field: value, - }, - }) - } - return mustQuery -}