Skip to content

Commit 25cfb2a

Browse files
committed
rebase
1 parent 1c77324 commit 25cfb2a

1 file changed

Lines changed: 73 additions & 21 deletions

File tree

tests/tool.go

Lines changed: 73 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3561,7 +3561,11 @@ func RunMySQLListTablesMissingUniqueIndexes(t *testing.T, ctx context.Context, p
35613561
}
35623562
}
35633563

3564-
func RunMySQLListTableStatsTest(t *testing.T, ctx context.Context, pool *sql.DB, databaseName string, tableNameParam string, tableNameAuth string) {
3564+
func RunMySQLListTableStatsTest(t *testing.T, ctx context.Context, pool *sql.DB, databaseName string, tableNameParam string, tableNameAuth string, opts ...ToolExecOption) {
3565+
config := &ToolExecConfig{}
3566+
for _, opt := range opts {
3567+
opt(config)
3568+
}
35653569
type tableStatsDetails struct {
35663570
TableSchema string `json:"table_schema"`
35673571
TableName string `json:"table_name"`
@@ -3658,25 +3662,68 @@ func RunMySQLListTableStatsTest(t *testing.T, ctx context.Context, pool *sql.DB,
36583662

36593663
for _, tc := range invokeTcs {
36603664
t.Run(tc.name, func(t *testing.T) {
3661-
const api = "http://127.0.0.1:5000/api/tool/list_table_stats/invoke"
3662-
resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil)
3663-
if resp.StatusCode != tc.wantStatusCode {
3664-
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody))
3665-
}
3666-
if tc.wantStatusCode != http.StatusOK {
3667-
return
3668-
}
3669-
3670-
var bodyWrapper struct {
3671-
Result json.RawMessage `json:"result"`
3672-
}
3673-
if err := json.Unmarshal(respBody, &bodyWrapper); err != nil {
3674-
t.Fatalf("error decoding response wrapper: %v", err)
3675-
}
3676-
36773665
var resultString string
3678-
if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil {
3679-
resultString = string(bodyWrapper.Result)
3666+
3667+
if config.isMCP {
3668+
reqBytes, _ := io.ReadAll(tc.requestBody)
3669+
var args map[string]any
3670+
if len(reqBytes) > 0 {
3671+
_ = json.Unmarshal(reqBytes, &args)
3672+
}
3673+
if args == nil {
3674+
args = make(map[string]any)
3675+
}
3676+
3677+
statusCode, mcpResp, err := InvokeMCPTool(t, "list_table_stats", args, nil)
3678+
3679+
// For the error case (expecting 500 in REST), we expect 200 OK in MCP with IsError=true
3680+
expectedStatus := tc.wantStatusCode
3681+
if tc.wantStatusCode == http.StatusInternalServerError {
3682+
expectedStatus = http.StatusOK
3683+
}
3684+
3685+
if statusCode != expectedStatus {
3686+
t.Fatalf("wrong status code: got %d, want %d, err: %v", statusCode, expectedStatus, err)
3687+
}
3688+
3689+
if tc.wantStatusCode == http.StatusInternalServerError {
3690+
if !mcpResp.Result.IsError {
3691+
t.Fatalf("expected error result for list_table_stats")
3692+
}
3693+
return // Error case, no need to check result body
3694+
}
3695+
3696+
if mcpResp.Result.IsError {
3697+
t.Fatalf("list_table_stats returned error result: %v", mcpResp.Result)
3698+
}
3699+
3700+
gotObj := getMCPResultText(t, mcpResp)
3701+
if len(gotObj) == 0 {
3702+
resultString = "null"
3703+
} else {
3704+
gotBytes, _ := json.Marshal(gotObj)
3705+
resultString = string(gotBytes)
3706+
}
3707+
} else {
3708+
const api = "http://127.0.0.1:5000/api/tool/list_table_stats/invoke"
3709+
resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil)
3710+
if resp.StatusCode != tc.wantStatusCode {
3711+
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody))
3712+
}
3713+
if tc.wantStatusCode != http.StatusOK {
3714+
return
3715+
}
3716+
3717+
var bodyWrapper struct {
3718+
Result json.RawMessage `json:"result"`
3719+
}
3720+
if err := json.Unmarshal(respBody, &bodyWrapper); err != nil {
3721+
t.Fatalf("error decoding response wrapper: %v", err)
3722+
}
3723+
3724+
if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil {
3725+
resultString = string(bodyWrapper.Result)
3726+
}
36803727
}
36813728

36823729
var got any
@@ -3695,7 +3742,12 @@ func RunMySQLListTableStatsTest(t *testing.T, ctx context.Context, pool *sql.DB,
36953742
}
36963743
}
36973744

3698-
func RunMySQLListTableFragmentationTest(t *testing.T, databaseName, tableNameParam, tableNameAuth string) {
3745+
func RunMySQLListTableFragmentationTest(t *testing.T, databaseName, tableNameParam, tableNameAuth string, opts ...ToolExecOption) {
3746+
config := &ToolExecConfig{}
3747+
for _, opt := range opts {
3748+
opt(config)
3749+
}
3750+
36993751
type tableFragmentationDetails struct {
37003752
TableSchema string `json:"table_schema"`
37013753
TableName string `json:"table_name"`
@@ -3775,7 +3827,7 @@ func RunMySQLListTableFragmentationTest(t *testing.T, databaseName, tableNamePar
37753827
t.Run(tc.name, func(t *testing.T) {
37763828
var resultString string
37773829

3778-
if config.IsMCP {
3830+
if config.isMCP {
37793831
reqBytes, _ := io.ReadAll(tc.requestBody)
37803832
var args map[string]any
37813833
if len(reqBytes) > 0 {

0 commit comments

Comments
 (0)