Skip to content

Commit c7a751f

Browse files
committed
feat(database): support custom PostgreSQL schema via DATABASE_SCHEMA
Allows multiple projects to share a single PostgreSQL database (e.g. Supabase) by isolating tables in a configurable schema. When DATABASE_SCHEMA is set, startup creates the schema if missing and pins search_path on every pool connection via libpq options. Empty value preserves existing behavior. Closes #119
1 parent 4694c54 commit c7a751f

6 files changed

Lines changed: 121 additions & 6 deletions

File tree

.env.example

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@ DATABASE_PORT=5432
6161
DATABASE_USER=codex2api
6262
DATABASE_PASSWORD=codex2api
6363
DATABASE_NAME=codex2api
64+
# 可选:自定义 schema(适用于 Supabase 等多项目共享 database 的场景)
65+
# 启动时会自动 CREATE SCHEMA IF NOT EXISTS,并在所有连接上 SET search_path 到该 schema
66+
# 仅允许字母/数字/下划线,长度 ≤63;留空时使用数据库默认 search_path(通常是 public)
67+
# DATABASE_SCHEMA=codex2api
6468

6569
# ---- Redis ----
6670
CACHE_DRIVER=redis

config/config.go

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,24 @@ package config
33
import (
44
"fmt"
55
"os"
6+
"regexp"
67
"strconv"
78
"strings"
89

910
"github.com/joho/godotenv"
1011
)
1112

13+
// schemaNameRegex 限定 PostgreSQL schema 名为 ASCII 标识符,避免 DSN/DDL 注入。
14+
var schemaNameRegex = regexp.MustCompile(`^[A-Za-z_][A-Za-z0-9_]*$`)
15+
16+
// IsValidSchemaName 校验 PostgreSQL schema 名(首字母为字母或下划线,余下为字母/数字/下划线,长度 ≤63)。
17+
func IsValidSchemaName(name string) bool {
18+
if name == "" || len(name) > 63 {
19+
return false
20+
}
21+
return schemaNameRegex.MatchString(name)
22+
}
23+
1224
// DatabaseConfig 数据库核心配置。
1325
type DatabaseConfig struct {
1426
Driver string
@@ -18,6 +30,7 @@ type DatabaseConfig struct {
1830
User string
1931
Password string
2032
DBName string
33+
Schema string // PostgreSQL schema(search_path);空值保持数据库默认行为
2134
SSLMode string
2235
}
2336

@@ -30,8 +43,14 @@ func (d *DatabaseConfig) DSN() string {
3043
if sslMode == "" {
3144
sslMode = "disable"
3245
}
33-
return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
46+
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
3447
d.Host, d.Port, d.User, d.Password, d.DBName, sslMode)
48+
if d.Schema != "" {
49+
// 通过 libpq options 在连接启动时设置 search_path,覆盖连接池中的所有连接。
50+
// schema 已在 Load() 阶段做白名单校验,此处可安全拼接。
51+
dsn += fmt.Sprintf(" options='-c search_path=%s,public'", d.Schema)
52+
}
53+
return dsn
3554
}
3655

3756
// Label 返回用于展示的数据库标签。
@@ -139,6 +158,12 @@ func Load(envPath string) (*Config, error) {
139158
cfg.Database.User = os.Getenv("DATABASE_USER")
140159
cfg.Database.Password = os.Getenv("DATABASE_PASSWORD")
141160
cfg.Database.DBName = os.Getenv("DATABASE_NAME")
161+
if v := strings.TrimSpace(os.Getenv("DATABASE_SCHEMA")); v != "" {
162+
if !IsValidSchemaName(v) {
163+
return nil, fmt.Errorf("非法的 DATABASE_SCHEMA: %q(仅允许字母、数字、下划线,且不能以数字开头,长度不超过 63)", v)
164+
}
165+
cfg.Database.Schema = v
166+
}
142167
if v := os.Getenv("DATABASE_SSLMODE"); v != "" {
143168
cfg.Database.SSLMode = v
144169
}

config/config_test.go

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
package config
22

3-
import "testing"
3+
import (
4+
"strings"
5+
"testing"
6+
)
47

58
func TestLoadDefaultsToPostgresAndRedis(t *testing.T) {
69
keys := []string{
@@ -301,3 +304,63 @@ func TestLoadReadsRedisTLSSettings(t *testing.T) {
301304
t.Fatal("Redis.InsecureSkipVerify = false, want true")
302305
}
303306
}
307+
308+
func TestLoadAcceptsValidDatabaseSchema(t *testing.T) {
309+
t.Setenv("DATABASE_DRIVER", "")
310+
t.Setenv("DATABASE_HOST", "postgres")
311+
t.Setenv("DATABASE_NAME", "postgres")
312+
t.Setenv("DATABASE_SCHEMA", "codex2api")
313+
t.Setenv("CACHE_DRIVER", "")
314+
t.Setenv("REDIS_ADDR", "redis:6379")
315+
316+
cfg, err := Load("__not_exists__.env")
317+
if err != nil {
318+
t.Fatalf("Load() 返回错误: %v", err)
319+
}
320+
if got := cfg.Database.Schema; got != "codex2api" {
321+
t.Fatalf("Database.Schema = %q, want codex2api", got)
322+
}
323+
dsn := cfg.Database.DSN()
324+
if !strings.Contains(dsn, "options='-c search_path=codex2api,public'") {
325+
t.Fatalf("DSN 未包含 search_path 选项: %s", dsn)
326+
}
327+
}
328+
329+
func TestLoadRejectsInvalidDatabaseSchema(t *testing.T) {
330+
cases := []string{
331+
"public; DROP TABLE users",
332+
"with space",
333+
"1leading-digit",
334+
"with-dash",
335+
"中文",
336+
strings.Repeat("a", 64),
337+
}
338+
for _, name := range cases {
339+
t.Run(name, func(t *testing.T) {
340+
t.Setenv("DATABASE_DRIVER", "")
341+
t.Setenv("DATABASE_HOST", "postgres")
342+
t.Setenv("DATABASE_SCHEMA", name)
343+
t.Setenv("CACHE_DRIVER", "")
344+
t.Setenv("REDIS_ADDR", "redis:6379")
345+
346+
if _, err := Load("__not_exists__.env"); err == nil {
347+
t.Fatalf("非法 schema %q 应当被拒绝,但 Load() 通过了", name)
348+
}
349+
})
350+
}
351+
}
352+
353+
func TestDSNOmitsSchemaWhenEmpty(t *testing.T) {
354+
d := DatabaseConfig{
355+
Driver: "postgres",
356+
Host: "h",
357+
Port: 5432,
358+
User: "u",
359+
Password: "p",
360+
DBName: "db",
361+
SSLMode: "disable",
362+
}
363+
if got := d.DSN(); strings.Contains(got, "search_path") {
364+
t.Fatalf("空 schema 时 DSN 不应包含 search_path: %s", got)
365+
}
366+
}

database/postgres.go

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import (
1111
"sync/atomic"
1212
"time"
1313

14-
_ "github.com/lib/pq"
14+
"github.com/lib/pq"
1515
_ "modernc.org/sqlite"
1616
)
1717

@@ -181,13 +181,19 @@ type usageLogEntry struct {
181181
}
182182

183183
// New 创建数据库连接并自动建表。
184-
func New(driver string, dsn string) (*DB, error) {
184+
// schema 仅对 PostgreSQL 生效;为空时保持数据库默认 search_path。
185+
func New(driver string, dsn string, schema ...string) (*DB, error) {
185186
driver = normalizeDriver(driver)
186187
driverName := driver
187188
if driver == "sqlite" {
188189
driverName = "sqlite"
189190
}
190191

192+
pgSchema := ""
193+
if len(schema) > 0 {
194+
pgSchema = strings.TrimSpace(schema[0])
195+
}
196+
191197
conn, err := sql.Open(driverName, dsn)
192198
if err != nil {
193199
return nil, fmt.Errorf("连接数据库失败: %w", err)
@@ -228,6 +234,18 @@ func New(driver string, dsn string) (*DB, error) {
228234
if _, err := conn.ExecContext(ctx, "SET timezone = 'UTC'"); err != nil {
229235
return nil, fmt.Errorf("设置数据库时区失败: %w", err)
230236
}
237+
// 自定义 schema:确保 schema 存在并确认当前会话 search_path 已生效。
238+
// search_path 已通过 DSN 的 options=-c search_path=... 在所有连接启动时设置;
239+
// 这里仅做一次幂等的 CREATE SCHEMA + SET 兜底,便于首次部署时自动建好 schema。
240+
if pgSchema != "" {
241+
quoted := pq.QuoteIdentifier(pgSchema)
242+
if _, err := conn.ExecContext(ctx, "CREATE SCHEMA IF NOT EXISTS "+quoted); err != nil {
243+
return nil, fmt.Errorf("创建数据库 schema 失败: %w", err)
244+
}
245+
if _, err := conn.ExecContext(ctx, "SET search_path TO "+quoted+", public"); err != nil {
246+
return nil, fmt.Errorf("设置 search_path 失败: %w", err)
247+
}
248+
}
231249
}
232250
if err := db.migrate(ctx); err != nil {
233251
return nil, fmt.Errorf("数据库迁移失败: %w", err)

docs/CONFIGURATION.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ Codex2API 采用三层配置架构:
6969
| `DATABASE_USER` || - | PostgreSQL 用户名 |
7070
| `DATABASE_PASSWORD` || - | PostgreSQL 密码 |
7171
| `DATABASE_NAME` || - | PostgreSQL 数据库名 |
72+
| `DATABASE_SCHEMA` || - | PostgreSQL schema;适合 Supabase 等多项目共享 database 的场景。配置后启动时自动 `CREATE SCHEMA IF NOT EXISTS` 并将所有连接的 `search_path` 指向该 schema。仅允许字母/数字/下划线,长度 ≤63;留空保持默认(通常是 `public`)。|
7273
| `DATABASE_SSLMODE` || disable | SSL 模式: disable/require/verify-full |
7374

7475
### 生图工作台

main.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ func main() {
4141
log.Printf("物理层配置加载成功: port=%d, database=%s, cache=%s", cfg.Port, cfg.Database.Label(), cfg.Cache.Label())
4242

4343
// 2. 初始化数据库
44-
db, err := database.New(cfg.Database.Driver, cfg.Database.DSN())
44+
db, err := database.New(cfg.Database.Driver, cfg.Database.DSN(), cfg.Database.Schema)
4545
if err != nil {
4646
log.Fatalf("数据库初始化失败: %v", err)
4747
}
@@ -50,7 +50,11 @@ func main() {
5050
case "sqlite":
5151
log.Printf("%s 连接成功: %s", cfg.Database.Label(), cfg.Database.Path)
5252
default:
53-
log.Printf("%s 连接成功: %s:%d/%s", cfg.Database.Label(), cfg.Database.Host, cfg.Database.Port, cfg.Database.DBName)
53+
if cfg.Database.Schema != "" {
54+
log.Printf("%s 连接成功: %s:%d/%s (schema=%s)", cfg.Database.Label(), cfg.Database.Host, cfg.Database.Port, cfg.Database.DBName, cfg.Database.Schema)
55+
} else {
56+
log.Printf("%s 连接成功: %s:%d/%s", cfg.Database.Label(), cfg.Database.Host, cfg.Database.Port, cfg.Database.DBName)
57+
}
5458
}
5559

5660
// 3. 读取运行时的系统逻辑设置(需在缓存初始化之前,以获取连接池大小)

0 commit comments

Comments
 (0)