Skip to content

Commit 6b50ddb

Browse files
committed
fix(gen): fix gen model
1 parent 21c21b2 commit 6b50ddb

8 files changed

Lines changed: 593 additions & 33 deletions

File tree

cmd/gen.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,36 @@ var genCmd = &cobra.Command{
3737
config.C.Gen.Home = filepath.Join(home, ".jzero", "templates", Version)
3838
}
3939
embeded.Home = config.C.Gen.Home
40+
41+
// 兼容 model flag, MarkDeprecated
42+
if config.C.Gen.ModelMysqlStrict {
43+
config.C.Gen.ModelStrict = config.C.Gen.ModelMysqlStrict
44+
}
45+
if len(config.C.Gen.ModelMysqlIgnoreColumns) > 0 {
46+
config.C.Gen.ModelIgnoreColumns = config.C.Gen.ModelMysqlIgnoreColumns
47+
}
48+
if config.C.Gen.ModelMysqlDDLDatabase != "" {
49+
config.C.Gen.ModelDDLDatabase = config.C.Gen.ModelMysqlDDLDatabase
50+
}
51+
if config.C.Gen.ModelMysqlDatasource {
52+
config.C.Gen.ModelDatasource = config.C.Gen.ModelMysqlDatasource
53+
}
54+
if config.C.Gen.ModelMysqlDatasourceUrl != "" {
55+
config.C.Gen.ModelDatasourceUrl = config.C.Gen.ModelMysqlDatasourceUrl
56+
}
57+
if len(config.C.Gen.ModelMysqlDatasourceTable) > 0 {
58+
config.C.Gen.ModelDatasourceTable = config.C.Gen.ModelMysqlDatasourceTable
59+
}
60+
if config.C.Gen.ModelMysqlCache {
61+
config.C.Gen.ModelCache = config.C.Gen.ModelMysqlCache
62+
}
63+
if config.C.Gen.ModelMysqlCachePrefix != "" {
64+
config.C.Gen.ModelCachePrefix = config.C.Gen.ModelMysqlCachePrefix
65+
}
66+
if config.C.Gen.ModelMysqlCreateTableDDL {
67+
config.C.Gen.ModelCreateTableDDL = config.C.Gen.ModelMysqlCreateTableDDL
68+
}
69+
4070
return gen.Run(false)
4171
},
4272
SilenceUsage: true,

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ require (
88
github.com/dave/dst v0.27.3
99
github.com/fsgo/go_fmt v0.6.3
1010
github.com/go-git/go-git/v5 v5.16.0
11-
github.com/go-sql-driver/mysql v1.9.2
1211
github.com/golang-migrate/migrate/v4 v4.18.3
1312
github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3
1413
github.com/hashicorp/go-version v1.7.0
@@ -63,6 +62,7 @@ require (
6362
github.com/go-git/go-billy/v5 v5.6.2 // indirect
6463
github.com/go-logr/logr v1.4.2 // indirect
6564
github.com/go-logr/stdr v1.2.2 // indirect
65+
github.com/go-sql-driver/mysql v1.9.2 // indirect
6666
github.com/go-viper/mapstructure/v2 v2.2.1 // indirect
6767
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect
6868
github.com/golang/protobuf v1.5.4 // indirect

internal/gen/genmodel/gen.go

Lines changed: 19 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"bytes"
55
"context"
66
"fmt"
7+
"github.com/jzero-io/jzero/pkg/dsn"
78
"go/ast"
89
goformat "go/format"
910
goparser "go/parser"
@@ -15,7 +16,6 @@ import (
1516
"strings"
1617
"sync"
1718

18-
"github.com/go-sql-driver/mysql"
1919
"github.com/jzero-io/jzero-contrib/filex"
2020
"github.com/pkg/errors"
2121
"github.com/samber/lo"
@@ -72,17 +72,14 @@ func (jm *JzeroModel) Gen() error {
7272
return errors.New("postgres model only support datasource mode")
7373
}
7474

75-
if config.C.Gen.ModelMysqlDatasource || config.C.Gen.ModelDatasource {
75+
if config.C.Gen.ModelDatasource {
7676
if jm.IsNew {
7777
fmt.Printf("%s you are using mysql datesource to generate model code, please manual execute jzero gen command\n", color.WithColor("Detected", color.FgRed))
7878
return nil
7979
}
8080

8181
switch config.C.Gen.ModelDriver {
8282
case "mysql":
83-
if config.C.Gen.ModelMysqlDatasourceUrl != "" {
84-
config.C.Gen.ModelDatasourceUrl = config.C.Gen.ModelMysqlDatasourceUrl
85-
}
8683
sqlConn = sqlx.NewMysql(config.C.Gen.ModelDatasourceUrl)
8784
case "postgres":
8885
sqlConn = postgres.New(config.C.Gen.ModelDatasourceUrl)
@@ -95,13 +92,13 @@ func (jm *JzeroModel) Gen() error {
9592
return err
9693
}
9794

98-
fmt.Printf("%s to generate ddl from %s\n", color.WithColor("Start", color.FgGreen), config.C.Gen.ModelMysqlDatasourceUrl)
95+
fmt.Printf("%s to generate ddl from %s\n", color.WithColor("Start", color.FgGreen), config.C.Gen.ModelDatasourceUrl)
9996

10097
writeTables, err := jm.GenDDL(sqlConn, tables)
10198
if err != nil {
10299
return err
103100
}
104-
if !config.C.Gen.ModelMysqlCreateTableDDL || !config.C.Gen.ModelCreateTableDDL {
101+
if !config.C.Gen.ModelCreateTableDDL {
105102
defer func() {
106103
for _, v := range writeTables {
107104
if err = os.Remove(v); err != nil {
@@ -128,7 +125,7 @@ func (jm *JzeroModel) Gen() error {
128125
}
129126

130127
switch {
131-
case config.C.Gen.GitChange && filex.DirExists(filepath.Join(config.C.Wd(), ".git")) && len(config.C.Gen.Desc) == 0 && !config.C.Gen.ModelMysqlDatasource:
128+
case config.C.Gen.GitChange && filex.DirExists(filepath.Join(config.C.Wd(), ".git")) && len(config.C.Gen.Desc) == 0 && !config.C.Gen.ModelDatasource:
132129
m, _, err := gitstatus.ChangedFiles(config.C.SqlDir(), ".sql")
133130
if err == nil {
134131
genCodeSqlFiles = append(genCodeSqlFiles, m...)
@@ -178,18 +175,18 @@ func (jm *JzeroModel) Gen() error {
178175
var mu sync.Mutex
179176

180177
if len(genCodeSqlFiles) != 0 {
181-
if config.C.Gen.ModelMysqlDatasource || config.C.Gen.ModelDatasource {
178+
if config.C.Gen.ModelDatasource {
182179
tables, err := getAllTables(sqlConn, config.C.Gen.ModelDriver)
183180
if err != nil {
184181
return err
185182
}
186-
if (len(config.C.Gen.ModelMysqlDatasourceTable) != 0 && config.C.Gen.ModelMysqlDatasourceTable[0] != "*") || (len(config.C.Gen.ModelDatasourceTable) != 0 && config.C.Gen.ModelDatasourceTable[0] != "*") {
183+
if len(config.C.Gen.ModelDatasourceTable) != 0 && config.C.Gen.ModelDatasourceTable[0] != "*" {
187184
for _, v := range tables {
188-
if lo.Contains(config.C.Gen.ModelMysqlDatasourceTable, cast.ToString(v)) || lo.Contains(config.C.Gen.ModelDatasourceTable, cast.ToString(v)) {
185+
if lo.Contains(config.C.Gen.ModelDatasourceTable, cast.ToString(v)) {
189186
allTables = append(allTables, v)
190187
}
191188
}
192-
} else if (len(config.C.Gen.ModelMysqlDatasourceTable) != 0 && config.C.Gen.ModelMysqlDatasourceTable[0] == "*") || (len(config.C.Gen.ModelDatasourceTable) != 0 && config.C.Gen.ModelDatasourceTable[0] == "*") {
189+
} else if len(config.C.Gen.ModelDatasourceTable) != 0 && config.C.Gen.ModelDatasourceTable[0] == "*" {
193190
allTables = tables
194191
}
195192
for _, f := range allFiles {
@@ -203,7 +200,7 @@ func (jm *JzeroModel) Gen() error {
203200
var eg errgroup.Group
204201
for _, f := range allFiles {
205202
eg.Go(func() error {
206-
tableParsers, err := parser.Parse(filepath.Join(config.C.Wd(), f), "", config.C.Gen.ModelMysqlStrict)
203+
tableParsers, err := parser.Parse(filepath.Join(config.C.Wd(), f), "", config.C.Gen.ModelStrict)
207204
if err != nil {
208205
return err
209206
}
@@ -237,14 +234,15 @@ func (jm *JzeroModel) Gen() error {
237234
modelDir := filepath.Join("internal", "model", strings.ToLower(bf[0:len(bf)-len(path.Ext(bf))]))
238235

239236
var ddlDatabase string
240-
if config.C.Gen.ModelMysqlDDLDatabase != "" {
241-
ddlDatabase = config.C.Gen.ModelMysqlDDLDatabase
242-
} else if config.C.Gen.ModelMysqlDatasourceUrl != "" {
243-
mysqlDsn, err := mysql.ParseDSN(config.C.Gen.ModelMysqlDatasourceUrl)
237+
if config.C.Gen.ModelDatasource {
238+
meta, err := dsn.ParseDSN(config.C.Gen.ModelDriver, config.C.Gen.ModelDatasourceUrl)
244239
if err != nil {
245240
return err
246241
}
247-
ddlDatabase = mysqlDsn.DBName
242+
ddlDatabase = meta[dsn.DBName]
243+
}
244+
if config.C.Gen.ModelDDLDatabase != "" {
245+
ddlDatabase = config.C.Gen.ModelDDLDatabase
248246
}
249247

250248
if config.C.Gen.ModelDriver == "postgres" {
@@ -255,18 +253,15 @@ func (jm *JzeroModel) Gen() error {
255253
return errors.Errorf("gen model code meet error. Err: %s:%s", err.Error(), resp)
256254
}
257255
} else {
258-
if config.C.Gen.ModelMysqlCache {
259-
config.C.Gen.ModelCache = true
260-
}
261-
cmd := exec.Command("goctl", "model", "mysql", "ddl", "--database", ddlDatabase, "--src", f, "--dir", modelDir, "--home", goctlHome, "--style", config.C.Gen.Style, "-i", strings.Join(config.C.Gen.ModelMysqlIgnoreColumns, ","), "--cache="+fmt.Sprintf("%t", config.C.Gen.ModelCache), "--strict="+fmt.Sprintf("%t", config.C.Gen.ModelMysqlStrict))
256+
cmd := exec.Command("goctl", "model", "mysql", "ddl", "--database", ddlDatabase, "--src", f, "--dir", modelDir, "--home", goctlHome, "--style", config.C.Gen.Style, "-i", strings.Join(config.C.Gen.ModelIgnoreColumns, ","), "--cache="+fmt.Sprintf("%t", config.C.Gen.ModelCache), "--strict="+fmt.Sprintf("%t", config.C.Gen.ModelStrict))
262257
logx.Debug(cmd.String())
263258
resp, err := cmd.CombinedOutput()
264259
if err != nil {
265260
return errors.Errorf("gen model code meet error. Err: %s:%s", err.Error(), resp)
266261
}
267262
}
268263

269-
if (config.C.Gen.ModelMysqlCachePrefix != "" && config.C.Gen.ModelMysqlCache) || (config.C.Gen.ModelCachePrefix != "" && config.C.Gen.ModelCache) {
264+
if config.C.Gen.ModelCachePrefix != "" && config.C.Gen.ModelCache {
270265
for _, tp := range tableParsers {
271266
namingFormat, err := format.FileNamingFormat(config.C.Gen.Style, tp.Name.Source())
272267
if err != nil {
@@ -309,11 +304,7 @@ func (jm *JzeroModel) addModelCachePrefix(fp string) error {
309304
if strings.HasPrefix(name.Name, "cache") && strings.HasSuffix(name.Name, "Prefix") {
310305
value := valueSpec.Values[i]
311306
if basicLit, ok := value.(*ast.BasicLit); ok {
312-
if config.C.Gen.ModelCachePrefix != "" {
313-
basicLit.Value = fmt.Sprintf(`"%s%s"`, config.C.Gen.ModelCachePrefix, strings.ReplaceAll(basicLit.Value, "\"", ""))
314-
} else if config.C.Gen.ModelMysqlCachePrefix != "" {
315-
basicLit.Value = fmt.Sprintf(`"%s%s"`, config.C.Gen.ModelMysqlCachePrefix, strings.ReplaceAll(basicLit.Value, "\"", ""))
316-
}
307+
basicLit.Value = fmt.Sprintf(`"%s%s"`, config.C.Gen.ModelCachePrefix, strings.ReplaceAll(basicLit.Value, "\"", ""))
317308
}
318309
}
319310
}

internal/gen/genmodel/plugins.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ func (jm *JzeroModel) GenRegister(tables []string) error {
3838
template, err := templatex.ParseTemplate(map[string]any{
3939
"Imports": imports,
4040
"TablePackages": tablePackages,
41-
"withCache": config.C.Gen.ModelMysqlCache,
41+
"withCache": config.C.Gen.ModelCache,
4242
}, embeded.ReadTemplateFile(filepath.Join("plugins", "model", "model.go.tpl")))
4343
if err != nil {
4444
return err
@@ -85,14 +85,14 @@ func (jm *JzeroModel) GenDDL(sqlConn sqlx.SqlConn, tables []string) ([]string, e
8585
var writeTables []string
8686
for _, v := range tables {
8787
if s, ok := tableDDLMap.Load(v); ok {
88-
if len(config.C.Gen.ModelMysqlDatasourceTable) != 0 && config.C.Gen.ModelMysqlDatasourceTable[0] != "*" {
89-
if lo.Contains(config.C.Gen.ModelMysqlDatasourceTable, cast.ToString(v)) {
88+
if len(config.C.Gen.ModelDatasourceTable) != 0 && config.C.Gen.ModelDatasourceTable[0] != "*" {
89+
if lo.Contains(config.C.Gen.ModelDatasourceTable, cast.ToString(v)) {
9090
writeTables = append(writeTables, filepath.Join("desc", "sql", fmt.Sprintf("%s.sql", v)))
9191
if err := os.WriteFile(filepath.Join("desc", "sql", fmt.Sprintf("%s.sql", v)), []byte(cast.ToString(s)), 0o644); err != nil {
9292
return nil, err
9393
}
9494
}
95-
} else if len(config.C.Gen.ModelMysqlDatasourceTable) != 0 && config.C.Gen.ModelMysqlDatasourceTable[0] == "*" {
95+
} else if len(config.C.Gen.ModelDatasourceTable) != 0 && config.C.Gen.ModelDatasourceTable[0] == "*" {
9696
writeTables = append(writeTables, filepath.Join("desc", "sql", fmt.Sprintf("%s.sql", v)))
9797
if err := os.WriteFile(filepath.Join("desc", "sql", fmt.Sprintf("%s.sql", v)), []byte(cast.ToString(s)), 0o644); err != nil {
9898
return nil, err

pkg/dsn/dsn.go

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
// Unless explicitly stated otherwise all files in this repository are licensed
2+
// under the Apache License Version 2.0.
3+
// This product includes software developed at Datadog (https://www.datadoghq.com/).
4+
// Copyright 2016 Datadog, Inc.
5+
6+
package dsn
7+
8+
import (
9+
"net"
10+
"net/url"
11+
"strings"
12+
)
13+
14+
const (
15+
// DBApplication indicates the application using the database.
16+
DBApplication = "db.application"
17+
// DBName indicates the database name.
18+
DBName = "db.name"
19+
// DBUser indicates the user name of Database, e.g. "readonly_user" or "reporting_user".
20+
DBUser = "db.user"
21+
TargetHost = "out.host"
22+
TargetPort = "out.port"
23+
)
24+
25+
// ParseDSN parses various supported DSN types into a map of key/value pairs which can be used as valid tags.
26+
func ParseDSN(driverName, dsn string) (meta map[string]string, err error) {
27+
meta = make(map[string]string)
28+
switch driverName {
29+
case "mysql":
30+
meta, err = parseMySQLDSN(dsn)
31+
if err != nil {
32+
return
33+
}
34+
case "postgres", "pgx":
35+
meta, err = parsePostgresDSN(dsn)
36+
if err != nil {
37+
return
38+
}
39+
default:
40+
// Try to parse the DSN and see if the scheme contains a known driver name.
41+
u, e := url.Parse(dsn)
42+
if e != nil {
43+
// dsn is not a valid URL, so just ignore
44+
return
45+
}
46+
if driverName != u.Scheme {
47+
// In some cases the driver is registered under a non-official name.
48+
// For example, "Test" may be the registered name with a DSN of "postgres://postgres:postgres@127.0.0.1:5432/fakepreparedb"
49+
// for the purposes of testing/mocking.
50+
// In these cases, we try to parse the DSN based upon the DSN itself, instead of the registered driver name
51+
return ParseDSN(u.Scheme, dsn)
52+
}
53+
}
54+
return reduceKeys(meta), nil
55+
}
56+
57+
// reduceKeys takes a map containing parsed DSN information and returns a new
58+
// map containing only the keys relevant as tracing tags, if any.
59+
func reduceKeys(meta map[string]string) map[string]string {
60+
var keysOfInterest = map[string]string{
61+
"user": DBUser,
62+
"application_name": DBApplication,
63+
"dbname": DBName,
64+
"host": TargetHost,
65+
"port": TargetPort,
66+
}
67+
m := make(map[string]string)
68+
for k, v := range meta {
69+
if nk, ok := keysOfInterest[k]; ok {
70+
m[nk] = v
71+
}
72+
}
73+
return m
74+
}
75+
76+
// parseMySQLDSN parses a mysql-type dsn into a map.
77+
func parseMySQLDSN(dsn string) (m map[string]string, err error) {
78+
var cfg *mySQLConfig
79+
if cfg, err = mySQLConfigFromDSN(dsn); err == nil {
80+
host, port, _ := net.SplitHostPort(cfg.Addr)
81+
m = map[string]string{
82+
"user": cfg.User,
83+
"host": host,
84+
"port": port,
85+
"dbname": cfg.DBName,
86+
}
87+
return m, nil
88+
}
89+
return nil, err
90+
}
91+
92+
// parsePostgresDSN parses a postgres-type dsn into a map.
93+
func parsePostgresDSN(dsn string) (map[string]string, error) {
94+
var err error
95+
if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
96+
// url form, convert to opts
97+
dsn, err = parseURL(dsn)
98+
if err != nil {
99+
return nil, err
100+
}
101+
}
102+
meta := make(map[string]string)
103+
if err := parseOpts(dsn, meta); err != nil {
104+
return nil, err
105+
}
106+
// remove sensitive information
107+
delete(meta, "password")
108+
return meta, nil
109+
}

pkg/dsn/dsn_test.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
package dsn
2+
3+
import (
4+
"fmt"
5+
"testing"
6+
)
7+
8+
func TestParseDSN(t *testing.T) {
9+
// Add tests for ParseDSN function
10+
t.Run("Test ParseDSN", func(t *testing.T) {
11+
// Add test cases for ParseDSN function
12+
t.Run("Test ParseDSN with mysql", func(t *testing.T) {
13+
// Add test cases for ParseDSN function with mysql
14+
t.Run("Test ParseDSN with mysql and valid dsn", func(t *testing.T) {
15+
// Add test cases for ParseDSN function with mysql and valid dsn
16+
meta, err := ParseDSN("mysql", "user:password@tcp(localhost:3306)/dbname")
17+
if err != nil {
18+
t.Errorf("ParseDSN() error = %v", err)
19+
}
20+
if meta[DBUser] != "user" {
21+
t.Errorf("ParseDSN() user = %v, want %v", meta["user"], "user")
22+
}
23+
if meta[TargetHost] != "localhost" {
24+
t.Errorf("ParseDSN() host = %v, want %v", meta["host"], "localhost")
25+
}
26+
if meta[TargetPort] != "3306" {
27+
t.Errorf("ParseDSN() port = %v, want %v", meta["port"], "3306")
28+
}
29+
if meta[DBName] != "dbname" {
30+
t.Errorf("ParseDSN() dbname = %v, want %v", meta["dbname"], "dbname")
31+
}
32+
})
33+
})
34+
})
35+
36+
t.Run("Test ParseDSN with postgres", func(t *testing.T) {
37+
// Add test cases for ParseDSN function with postgres
38+
t.Run("Test ParseDSN with postgres and valid dsn", func(t *testing.T) {
39+
// Add test cases for ParseDSN function with postgres and valid dsn
40+
meta, err := ParseDSN("postgres", "postgres://user:password@localhost:5432/dbname")
41+
if err != nil {
42+
t.Errorf("ParseDSN() error = %v", err)
43+
}
44+
fmt.Println(meta)
45+
if meta[DBUser] != "user" {
46+
t.Errorf("ParseDSN() user = %v, want %v", meta["user"], "user")
47+
}
48+
if meta[TargetHost] != "localhost" {
49+
t.Errorf("ParseDSN() host = %v, want %v", meta["host"], "localhost")
50+
}
51+
if meta[TargetPort] != "5432" {
52+
t.Errorf("ParseDSN() port = %v, want %v", meta["port"], "5432")
53+
}
54+
if meta[DBName] != "dbname" {
55+
t.Errorf("ParseDSN() dbname = %v, want %v", meta["dbname"], "dbname")
56+
}
57+
})
58+
})
59+
}

0 commit comments

Comments
 (0)