Skip to content

Commit 2543c18

Browse files
authored
Add tests for SSE client connection and transport handling (#15)
- Introduced unit tests for SSE protocol connection scenarios in client_test.go. - Enhanced error handling assertions to ensure connection failures are correctly identified. - Updated client.go to support SSE transport creation with optional headers. - Improved transport handling logic for different protocols, including HTTP and streamable-HTTP.
1 parent 84118d5 commit 2543c18

2 files changed

Lines changed: 301 additions & 25 deletions

File tree

internal/upstream/client.go

Lines changed: 50 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
const (
2626
transportHTTP = "http"
2727
transportStreamableHTTP = "streamable-http"
28+
transportSSE = "sse"
2829
transportStdio = "stdio"
2930
osWindows = "windows"
3031
)
@@ -162,33 +163,57 @@ func (c *Client) Connect(ctx context.Context) error {
162163

163164
switch transportType {
164165
case transportHTTP, transportStreamableHTTP:
165-
httpTransport, err := transport.NewStreamableHTTP(c.config.URL)
166-
if err != nil {
167-
c.mu.Lock()
168-
c.lastError = err
169-
c.retryCount++
170-
c.lastRetryTime = time.Now()
171-
c.mu.Unlock()
172-
return fmt.Errorf("failed to create HTTP transport: %w", err)
166+
// Create streamable HTTP transport with headers if provided
167+
if len(c.config.Headers) > 0 {
168+
httpTransport, err := transport.NewStreamableHTTP(c.config.URL,
169+
transport.WithHTTPHeaders(c.config.Headers))
170+
if err != nil {
171+
c.mu.Lock()
172+
c.lastError = err
173+
c.retryCount++
174+
c.lastRetryTime = time.Now()
175+
c.mu.Unlock()
176+
return fmt.Errorf("failed to create HTTP transport: %w", err)
177+
}
178+
c.client = client.NewClient(httpTransport)
179+
} else {
180+
httpTransport, err := transport.NewStreamableHTTP(c.config.URL)
181+
if err != nil {
182+
c.mu.Lock()
183+
c.lastError = err
184+
c.retryCount++
185+
c.lastRetryTime = time.Now()
186+
c.mu.Unlock()
187+
return fmt.Errorf("failed to create HTTP transport: %w", err)
188+
}
189+
c.client = client.NewClient(httpTransport)
173190
}
174-
c.client = client.NewClient(httpTransport)
175-
case "sse":
176-
// For SSE, we need to handle Cloudflare's two-step connection pattern
177-
// First connect to /sse to get session info, then use that for actual communication
178-
c.logger.Debug("Creating SSE transport with Cloudflare compatibility",
179-
zap.String("url", c.config.URL))
180-
181-
// Create SSE transport with special handling for Cloudflare endpoints
182-
httpTransport, err := transport.NewStreamableHTTP(c.config.URL)
183-
if err != nil {
184-
c.mu.Lock()
185-
c.lastError = err
186-
c.retryCount++
187-
c.lastRetryTime = time.Now()
188-
c.mu.Unlock()
189-
return fmt.Errorf("failed to create SSE transport: %w", err)
191+
case transportSSE:
192+
// Create SSE client with headers if provided
193+
if len(c.config.Headers) > 0 {
194+
sseClient, err := client.NewSSEMCPClient(c.config.URL,
195+
client.WithHeaders(c.config.Headers))
196+
if err != nil {
197+
c.mu.Lock()
198+
c.lastError = err
199+
c.retryCount++
200+
c.lastRetryTime = time.Now()
201+
c.mu.Unlock()
202+
return fmt.Errorf("failed to create SSE client: %w", err)
203+
}
204+
c.client = sseClient
205+
} else {
206+
sseClient, err := client.NewSSEMCPClient(c.config.URL)
207+
if err != nil {
208+
c.mu.Lock()
209+
c.lastError = err
210+
c.retryCount++
211+
c.lastRetryTime = time.Now()
212+
c.mu.Unlock()
213+
return fmt.Errorf("failed to create SSE client: %w", err)
214+
}
215+
c.client = sseClient
190216
}
191-
c.client = client.NewClient(httpTransport)
192217
case transportStdio:
193218
var originalCommand string
194219
var originalArgs []string

internal/upstream/client_test.go

Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
package upstream
2+
3+
import (
4+
"context"
5+
"strings"
6+
"testing"
7+
"time"
8+
9+
"github.com/stretchr/testify/assert"
10+
"github.com/stretchr/testify/require"
11+
"go.uber.org/zap"
12+
13+
"mcpproxy-go/internal/config"
14+
)
15+
16+
func TestClient_Connect_SSE_NotSupported(t *testing.T) {
17+
// Create a test config with SSE protocol
18+
cfg := &config.ServerConfig{
19+
Name: "test-sse-server",
20+
URL: "http://localhost:8080/sse",
21+
Protocol: "sse",
22+
Enabled: true,
23+
Created: time.Now(),
24+
}
25+
26+
// Create test logger
27+
logger, err := zap.NewDevelopment()
28+
require.NoError(t, err)
29+
30+
// Create client with all required parameters
31+
client, err := NewClient("test-client", cfg, logger, nil, nil)
32+
require.NoError(t, err)
33+
require.NotNil(t, client)
34+
35+
// Attempt to connect - should fail at connection, not at transport creation
36+
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
37+
defer cancel()
38+
39+
err = client.Connect(ctx)
40+
41+
// Verify the error is about connection failure, not SSE not supported
42+
require.Error(t, err)
43+
assert.NotContains(t, err.Error(), "SSE transport is not supported")
44+
// Should be a connection error since there's no actual SSE server
45+
assert.True(t,
46+
strings.Contains(err.Error(), "connection") ||
47+
strings.Contains(err.Error(), "dial") ||
48+
strings.Contains(err.Error(), "refused") ||
49+
strings.Contains(err.Error(), "timeout"),
50+
"Error should be about connection failure, not SSE support")
51+
}
52+
53+
func TestClient_DetermineTransportType_SSE(t *testing.T) {
54+
cfg := &config.ServerConfig{
55+
Protocol: "sse",
56+
URL: "http://localhost:8080/sse",
57+
}
58+
59+
logger, err := zap.NewDevelopment()
60+
require.NoError(t, err)
61+
62+
client, err := NewClient("test-client", cfg, logger, nil, nil)
63+
require.NoError(t, err)
64+
65+
// Test that determineTransportType returns "sse" for SSE protocol
66+
transportType := client.determineTransportType()
67+
assert.Equal(t, "sse", transportType)
68+
}
69+
70+
func TestClient_Connect_SSE_ErrorContainsAlternatives(t *testing.T) {
71+
cfg := &config.ServerConfig{
72+
Name: "test-sse-server",
73+
URL: "http://localhost:8080/sse",
74+
Protocol: "sse",
75+
Enabled: true,
76+
Created: time.Now(),
77+
}
78+
79+
logger, err := zap.NewDevelopment()
80+
require.NoError(t, err)
81+
82+
client, err := NewClient("test-client", cfg, logger, nil, nil)
83+
require.NoError(t, err)
84+
85+
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
86+
defer cancel()
87+
88+
err = client.Connect(ctx)
89+
90+
require.Error(t, err)
91+
92+
// Verify that the error is about connection failure, not SSE not supported
93+
errorMsg := err.Error()
94+
assert.NotContains(t, errorMsg, "SSE transport is not supported")
95+
assert.NotContains(t, errorMsg, "streamable-http")
96+
97+
// Should be a connection error since there's no actual SSE server
98+
assert.True(t,
99+
strings.Contains(errorMsg, "connection") ||
100+
strings.Contains(errorMsg, "dial") ||
101+
strings.Contains(errorMsg, "refused") ||
102+
strings.Contains(errorMsg, "timeout"),
103+
"Error should be about connection failure, not SSE support")
104+
}
105+
106+
func TestClient_Connect_WorkingTransports(t *testing.T) {
107+
tests := []struct {
108+
name string
109+
protocol string
110+
url string
111+
command string
112+
args []string
113+
shouldConnect bool
114+
errorContains string
115+
}{
116+
{
117+
name: "SSE protocol should work (until actual connection)",
118+
protocol: "sse",
119+
url: "http://localhost:8080/sse",
120+
shouldConnect: false, // Will fail at actual connection, but transport creation should work
121+
errorContains: "", // Won't check error for SSE as it depends on server availability
122+
},
123+
{
124+
name: "HTTP protocol should work (until actual connection)",
125+
protocol: "http",
126+
url: "http://localhost:8080",
127+
shouldConnect: false, // Will fail at actual connection, but transport creation should work
128+
errorContains: "", // Won't check error for HTTP as it depends on server availability
129+
},
130+
{
131+
name: "Streamable-HTTP protocol should work (until actual connection)",
132+
protocol: "streamable-http",
133+
url: "http://localhost:8080",
134+
shouldConnect: false, // Will fail at actual connection, but transport creation should work
135+
errorContains: "", // Won't check error for streamable-http as it depends on server availability
136+
},
137+
}
138+
139+
for _, tt := range tests {
140+
t.Run(tt.name, func(t *testing.T) {
141+
cfg := &config.ServerConfig{
142+
Name: "test-server",
143+
Protocol: tt.protocol,
144+
URL: tt.url,
145+
Command: tt.command,
146+
Args: tt.args,
147+
Enabled: true,
148+
Created: time.Now(),
149+
}
150+
151+
logger, err := zap.NewDevelopment()
152+
require.NoError(t, err)
153+
154+
client, err := NewClient("test-client", cfg, logger, nil, nil)
155+
require.NoError(t, err)
156+
157+
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
158+
defer cancel()
159+
160+
err = client.Connect(ctx)
161+
162+
if tt.shouldConnect {
163+
assert.NoError(t, err)
164+
} else if tt.errorContains != "" {
165+
require.Error(t, err)
166+
assert.Contains(t, err.Error(), tt.errorContains)
167+
}
168+
})
169+
}
170+
}
171+
172+
func TestClient_Headers_Support(t *testing.T) {
173+
tests := []struct {
174+
name string
175+
protocol string
176+
url string
177+
headers map[string]string
178+
expectErr bool
179+
}{
180+
{
181+
name: "SSE with headers",
182+
protocol: "sse",
183+
url: "http://localhost:8080/sse",
184+
headers: map[string]string{
185+
"Authorization": "Bearer token123",
186+
"X-Custom": "custom-value",
187+
},
188+
expectErr: true, // Will fail at connection, but headers should be processed
189+
},
190+
{
191+
name: "Streamable-HTTP with headers",
192+
protocol: "streamable-http",
193+
url: "http://localhost:8080",
194+
headers: map[string]string{
195+
"Authorization": "Bearer token456",
196+
"Content-Type": "application/json",
197+
},
198+
expectErr: true, // Will fail at connection, but headers should be processed
199+
},
200+
{
201+
name: "SSE without headers",
202+
protocol: "sse",
203+
url: "http://localhost:8080/sse",
204+
headers: nil,
205+
expectErr: true, // Will fail at connection
206+
},
207+
{
208+
name: "Streamable-HTTP without headers",
209+
protocol: "streamable-http",
210+
url: "http://localhost:8080",
211+
headers: nil,
212+
expectErr: true, // Will fail at connection
213+
},
214+
}
215+
216+
for _, tt := range tests {
217+
t.Run(tt.name, func(t *testing.T) {
218+
cfg := &config.ServerConfig{
219+
Name: "test-headers-server",
220+
Protocol: tt.protocol,
221+
URL: tt.url,
222+
Headers: tt.headers,
223+
Enabled: true,
224+
Created: time.Now(),
225+
}
226+
227+
logger, err := zap.NewDevelopment()
228+
require.NoError(t, err)
229+
230+
client, err := NewClient("test-client", cfg, logger, nil, nil)
231+
require.NoError(t, err)
232+
require.NotNil(t, client)
233+
234+
// Test that headers are stored in config
235+
assert.Equal(t, tt.headers, client.config.Headers)
236+
237+
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
238+
defer cancel()
239+
240+
err = client.Connect(ctx)
241+
242+
if tt.expectErr {
243+
require.Error(t, err)
244+
// Should not be a "not supported" error
245+
assert.NotContains(t, err.Error(), "not supported")
246+
} else {
247+
assert.NoError(t, err)
248+
}
249+
})
250+
}
251+
}

0 commit comments

Comments
 (0)