11package sql_workbench
22
33import (
4+ "strings"
45 "testing"
56
67 "github.com/actiontech/dms/internal/dms/biz"
@@ -27,6 +28,7 @@ func Test_convertDBType(t *testing.T) {
2728 "PolarDB For MySQL" : {input : "PolarDB For MySQL" , expected : "MYSQL" },
2829 "GaussDB" : {input : "GaussDB" , expected : "GAUSSDB" },
2930 "MongoDB" : {input : "MongoDB" , expected : "MONGODB" },
31+ "DB2" : {input : "DB2" , expected : "DB2" },
3032 "Unknown passthrough" : {input : "UnknownDB" , expected : "UnknownDB" },
3133 }
3234 for name , tc := range cases {
@@ -58,6 +60,7 @@ func Test_SupportDBType(t *testing.T) {
5860 "PolarDB For MySQL supported" : {input : pkgConst .DBTypePolarDBForMySQL , expected : true },
5961 "GaussDB supported" : {input : pkgConst .DBTypeGaussDB , expected : true },
6062 "GaussDBForMySQL unsupported" : {input : pkgConst .DBTypeGaussDBForMySQL , expected : false },
63+ "DB2 unsupported" : {input : pkgConst .DBTypeDB2 , expected : false },
6164 "empty string unsupported" : {input : pkgConst .DBType ("" ), expected : false },
6265 "unknown type unsupported" : {input : pkgConst .DBType ("UnknownDBType" ), expected : false },
6366 }
@@ -137,3 +140,105 @@ func Test_buildMongoDatasourceOptions_tlsOnly(t *testing.T) {
137140 }
138141}
139142
143+ // Test_buildDatasourceBaseInfo_DB2 覆盖 buildDatasourceBaseInfo 中 DB2 / 回归 4 组 case:
144+ //
145+ // (a) DB2 正例:AdditionalParam database_name=testdb → baseInfo.DefaultSchema=="testdb"
146+ // (b) DB2 负例:缺 database_name → 返回 err 且 err 含 "database_name"
147+ // (c) MySQL 回归:DefaultSchema == nil 且无 err
148+ // (d) Oracle 回归:ServiceName != nil 且无 err
149+ //
150+ // 通过 fillDatasourceBaseInfo(无 IO helper)进行 mock-only 单测,避免触达 projectUsecase / DB。
151+ func Test_buildDatasourceBaseInfo_DB2 (t * testing.T ) {
152+ svc := & SqlWorkbenchService {}
153+ const envID = int64 (1 )
154+ const datasourceName = "proj:ds"
155+
156+ cases := map [string ]struct {
157+ dbService * biz.DBService
158+ expectErr bool
159+ expectErrSubstr string
160+ expectDefaultSchema * string
161+ expectServiceName * string
162+ }{
163+ "DB2 happy path" : {
164+ dbService : & biz.DBService {
165+ Name : "db2-1" ,
166+ DBType : "DB2" ,
167+ AdditionalParams : pkgParams.Params {
168+ {Key : "database_name" , Value : "testdb" },
169+ },
170+ },
171+ expectErr : false ,
172+ expectDefaultSchema : strPtr ("testdb" ),
173+ expectServiceName : nil ,
174+ },
175+ "DB2 missing database_name" : {
176+ dbService : & biz.DBService {
177+ Name : "db2-2" ,
178+ DBType : "DB2" ,
179+ AdditionalParams : pkgParams.Params {},
180+ },
181+ expectErr : true ,
182+ expectErrSubstr : "database_name" ,
183+ },
184+ "MySQL regression" : {
185+ dbService : & biz.DBService {
186+ Name : "mysql-1" ,
187+ DBType : "MySQL" ,
188+ AdditionalParams : pkgParams.Params {},
189+ },
190+ expectErr : false ,
191+ expectDefaultSchema : nil ,
192+ expectServiceName : nil ,
193+ },
194+ "Oracle regression" : {
195+ dbService : & biz.DBService {
196+ Name : "oracle-1" ,
197+ DBType : "Oracle" ,
198+ AdditionalParams : pkgParams.Params {
199+ {Key : "service_name" , Value : "ORCL" },
200+ },
201+ },
202+ expectErr : false ,
203+ expectDefaultSchema : nil ,
204+ expectServiceName : strPtr ("ORCL" ),
205+ },
206+ }
207+
208+ for name , tc := range cases {
209+ t .Run (name , func (t * testing.T ) {
210+ got , err := svc .fillDatasourceBaseInfo (datasourceName , tc .dbService , envID )
211+ if tc .expectErr {
212+ if err == nil {
213+ t .Fatalf ("expected error, got nil; baseInfo=%+v" , got )
214+ }
215+ if tc .expectErrSubstr != "" && ! strings .Contains (err .Error (), tc .expectErrSubstr ) {
216+ t .Errorf ("error %q does not contain %q" , err .Error (), tc .expectErrSubstr )
217+ }
218+ return
219+ }
220+ if err != nil {
221+ t .Fatalf ("unexpected error: %v" , err )
222+ }
223+ if got == nil {
224+ t .Fatalf ("expected non-nil baseInfo" )
225+ }
226+ // DefaultSchema 对比
227+ if (got .DefaultSchema == nil ) != (tc .expectDefaultSchema == nil ) {
228+ t .Errorf ("DefaultSchema nil mismatch: got=%v, want=%v" , got .DefaultSchema , tc .expectDefaultSchema )
229+ } else if got .DefaultSchema != nil && tc .expectDefaultSchema != nil && * got .DefaultSchema != * tc .expectDefaultSchema {
230+ t .Errorf ("DefaultSchema = %q, want %q" , * got .DefaultSchema , * tc .expectDefaultSchema )
231+ }
232+ // ServiceName 对比
233+ if (got .ServiceName == nil ) != (tc .expectServiceName == nil ) {
234+ t .Errorf ("ServiceName nil mismatch: got=%v, want=%v" , got .ServiceName , tc .expectServiceName )
235+ } else if got .ServiceName != nil && tc .expectServiceName != nil && * got .ServiceName != * tc .expectServiceName {
236+ t .Errorf ("ServiceName = %q, want %q" , * got .ServiceName , * tc .expectServiceName )
237+ }
238+ })
239+ }
240+ }
241+
242+ func strPtr (s string ) * string {
243+ return & s
244+ }
0 commit comments