Skip to content

Commit 5783501

Browse files
authored
Merge pull request #636 from actiontech/feat-839
feat(sql_workbench): support DB2 in DMS->ODC bridge (Fixes actiontech/dms-ee#839)
2 parents ef29cba + 562b144 commit 5783501

2 files changed

Lines changed: 125 additions & 1 deletion

File tree

internal/sql_workbench/service/sql_workbench_service.go

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -872,6 +872,12 @@ func (sqlWorkbenchService *SqlWorkbenchService) buildDatasourceBaseInfo(ctx cont
872872
return nil, err
873873
}
874874

875+
return sqlWorkbenchService.fillDatasourceBaseInfo(datasourceName, dbService, environmentID)
876+
}
877+
878+
// fillDatasourceBaseInfo 根据 dbService 字段填充 datasourceBaseInfo,
879+
// 不包含外部 IO(不查 project / 不连 DB),便于单元测试覆盖 DBType 分支逻辑。
880+
func (sqlWorkbenchService *SqlWorkbenchService) fillDatasourceBaseInfo(datasourceName string, dbService *biz.DBService, environmentID int64) (*datasourceBaseInfo, error) {
875881
baseInfo := &datasourceBaseInfo{
876882
Name: datasourceName,
877883
Type: sqlWorkbenchService.convertDBType(dbService.DBType),
@@ -892,6 +898,16 @@ func (sqlWorkbenchService *SqlWorkbenchService) buildDatasourceBaseInfo(ctx cont
892898
baseInfo.DefaultSchema, baseInfo.Properties, baseInfo.JDBCParams = buildMongoDatasourceOptions(dbService)
893899
}
894900

901+
// DB2 特殊处理:从 AdditionalParams.database_name 取默认 schema 透传到 ODC
902+
if dbService.DBType == "DB2" {
903+
databaseNameParam := dbService.AdditionalParams.GetParam("database_name")
904+
if databaseNameParam == nil || databaseNameParam.Value == "" {
905+
return nil, fmt.Errorf("DB2 数据源 %s 缺少 AdditionalParam database_name,请在数据源 AdditionalParams 中补充", dbService.Name)
906+
}
907+
databaseName := databaseNameParam.Value
908+
baseInfo.DefaultSchema = &databaseName
909+
}
910+
895911
return baseInfo, nil
896912
}
897913

@@ -946,7 +962,8 @@ func (sqlWorkbenchService *SqlWorkbenchService) buildUpdateDatasourceRequest(ctx
946962
func (sqlWorkbenchService *SqlWorkbenchService) convertDBType(dmsDBType string) string {
947963
// 这里需要根据实际的数据库类型映射关系进行转换
948964
// ODC ConnectType 枚举值: OB_MYSQL, OB_ORACLE, ORACLE, MYSQL, ODP_SHARDING_OB_MYSQL,
949-
// DORIS, POSTGRESQL, HIVE, DM, TIDB, SQL_SERVER, MONGODB, GAUSSDB 等
965+
// DORIS, POSTGRESQL, HIVE, DM, TIDB, SQL_SERVER, MONGODB, GAUSSDB, DB2 等
966+
// 其余调用创建数据源接口会直接失败
950967
switch dmsDBType {
951968
case "MySQL":
952969
return "MYSQL"
@@ -976,6 +993,8 @@ func (sqlWorkbenchService *SqlWorkbenchService) convertDBType(dmsDBType string)
976993
return "MYSQL"
977994
case "MongoDB":
978995
return "MONGODB"
996+
case "DB2":
997+
return "DB2"
979998
default:
980999
return dmsDBType
9811000
}

internal/sql_workbench/service/sql_workbench_service_test.go

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package sql_workbench
22

33
import (
4+
"strings"
45
"testing"
56

67
"github.com/actiontech/dms/internal/dms/biz"
@@ -27,6 +28,7 @@ func Test_convertDBType(t *testing.T) {
2728
"PolarDB For MySQL": {input: "PolarDB For MySQL", expected: "MYSQL"},
2829
"GaussDB": {input: "GaussDB", expected: "GAUSSDB"},
2930
"MongoDB": {input: "MongoDB", expected: "MONGODB"},
31+
"DB2": {input: "DB2", expected: "DB2"},
3032
"Unknown passthrough": {input: "UnknownDB", expected: "UnknownDB"},
3133
}
3234
for name, tc := range cases {
@@ -58,6 +60,7 @@ func Test_SupportDBType(t *testing.T) {
5860
"PolarDB For MySQL supported": {input: pkgConst.DBTypePolarDBForMySQL, expected: true},
5961
"GaussDB supported": {input: pkgConst.DBTypeGaussDB, expected: true},
6062
"GaussDBForMySQL unsupported": {input: pkgConst.DBTypeGaussDBForMySQL, expected: false},
63+
"DB2 unsupported": {input: pkgConst.DBTypeDB2, expected: false},
6164
"empty string unsupported": {input: pkgConst.DBType(""), expected: false},
6265
"unknown type unsupported": {input: pkgConst.DBType("UnknownDBType"), expected: false},
6366
}
@@ -137,3 +140,105 @@ func Test_buildMongoDatasourceOptions_tlsOnly(t *testing.T) {
137140
}
138141
}
139142

143+
// Test_buildDatasourceBaseInfo_DB2 覆盖 buildDatasourceBaseInfo 中 DB2 / 回归 4 组 case:
144+
//
145+
// (a) DB2 正例:AdditionalParam database_name=testdb → baseInfo.DefaultSchema=="testdb"
146+
// (b) DB2 负例:缺 database_name → 返回 err 且 err 含 "database_name"
147+
// (c) MySQL 回归:DefaultSchema == nil 且无 err
148+
// (d) Oracle 回归:ServiceName != nil 且无 err
149+
//
150+
// 通过 fillDatasourceBaseInfo(无 IO helper)进行 mock-only 单测,避免触达 projectUsecase / DB。
151+
func Test_buildDatasourceBaseInfo_DB2(t *testing.T) {
152+
svc := &SqlWorkbenchService{}
153+
const envID = int64(1)
154+
const datasourceName = "proj:ds"
155+
156+
cases := map[string]struct {
157+
dbService *biz.DBService
158+
expectErr bool
159+
expectErrSubstr string
160+
expectDefaultSchema *string
161+
expectServiceName *string
162+
}{
163+
"DB2 happy path": {
164+
dbService: &biz.DBService{
165+
Name: "db2-1",
166+
DBType: "DB2",
167+
AdditionalParams: pkgParams.Params{
168+
{Key: "database_name", Value: "testdb"},
169+
},
170+
},
171+
expectErr: false,
172+
expectDefaultSchema: strPtr("testdb"),
173+
expectServiceName: nil,
174+
},
175+
"DB2 missing database_name": {
176+
dbService: &biz.DBService{
177+
Name: "db2-2",
178+
DBType: "DB2",
179+
AdditionalParams: pkgParams.Params{},
180+
},
181+
expectErr: true,
182+
expectErrSubstr: "database_name",
183+
},
184+
"MySQL regression": {
185+
dbService: &biz.DBService{
186+
Name: "mysql-1",
187+
DBType: "MySQL",
188+
AdditionalParams: pkgParams.Params{},
189+
},
190+
expectErr: false,
191+
expectDefaultSchema: nil,
192+
expectServiceName: nil,
193+
},
194+
"Oracle regression": {
195+
dbService: &biz.DBService{
196+
Name: "oracle-1",
197+
DBType: "Oracle",
198+
AdditionalParams: pkgParams.Params{
199+
{Key: "service_name", Value: "ORCL"},
200+
},
201+
},
202+
expectErr: false,
203+
expectDefaultSchema: nil,
204+
expectServiceName: strPtr("ORCL"),
205+
},
206+
}
207+
208+
for name, tc := range cases {
209+
t.Run(name, func(t *testing.T) {
210+
got, err := svc.fillDatasourceBaseInfo(datasourceName, tc.dbService, envID)
211+
if tc.expectErr {
212+
if err == nil {
213+
t.Fatalf("expected error, got nil; baseInfo=%+v", got)
214+
}
215+
if tc.expectErrSubstr != "" && !strings.Contains(err.Error(), tc.expectErrSubstr) {
216+
t.Errorf("error %q does not contain %q", err.Error(), tc.expectErrSubstr)
217+
}
218+
return
219+
}
220+
if err != nil {
221+
t.Fatalf("unexpected error: %v", err)
222+
}
223+
if got == nil {
224+
t.Fatalf("expected non-nil baseInfo")
225+
}
226+
// DefaultSchema 对比
227+
if (got.DefaultSchema == nil) != (tc.expectDefaultSchema == nil) {
228+
t.Errorf("DefaultSchema nil mismatch: got=%v, want=%v", got.DefaultSchema, tc.expectDefaultSchema)
229+
} else if got.DefaultSchema != nil && tc.expectDefaultSchema != nil && *got.DefaultSchema != *tc.expectDefaultSchema {
230+
t.Errorf("DefaultSchema = %q, want %q", *got.DefaultSchema, *tc.expectDefaultSchema)
231+
}
232+
// ServiceName 对比
233+
if (got.ServiceName == nil) != (tc.expectServiceName == nil) {
234+
t.Errorf("ServiceName nil mismatch: got=%v, want=%v", got.ServiceName, tc.expectServiceName)
235+
} else if got.ServiceName != nil && tc.expectServiceName != nil && *got.ServiceName != *tc.expectServiceName {
236+
t.Errorf("ServiceName = %q, want %q", *got.ServiceName, *tc.expectServiceName)
237+
}
238+
})
239+
}
240+
}
241+
242+
func strPtr(s string) *string {
243+
return &s
244+
}

0 commit comments

Comments
 (0)