Skip to content

Commit e0a69ee

Browse files
authored
Therapy bug fixes (#236)
* Bug fixes * Bug fixes * Tests * Clean up
1 parent 61a7b34 commit e0a69ee

2 files changed

Lines changed: 118 additions & 46 deletions

File tree

internal/app/therapy.go

Lines changed: 66 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"log"
99
"net/http"
1010
"os"
11+
"strings"
1112
"time"
1213

1314
"github.com/google/uuid"
@@ -64,22 +65,28 @@ func callTherapySessionEndpoint(text string, session *Session) *string {
6465
initReq.Header.Set("Authorization", "Bearer "+token)
6566
initReq.Header.Set("Content-Type", "application/json")
6667

67-
initResp, err := client.Do(initReq)
68-
if err != nil {
69-
log.Printf("[TherapySession] init request error: %v", err)
70-
return nil
71-
}
72-
func() {
73-
defer initResp.Body.Close()
74-
// Drain body for logging on non-2xx
75-
if initResp.StatusCode < 200 || initResp.StatusCode >= 300 {
76-
body, _ := io.ReadAll(initResp.Body)
77-
log.Printf("[TherapySession] init non-2xx: %d body=%s", initResp.StatusCode, string(body))
78-
}
79-
}()
80-
if initResp.StatusCode < 200 || initResp.StatusCode >= 300 {
81-
return nil
82-
}
68+
initResp, err := client.Do(initReq)
69+
if err != nil {
70+
log.Printf("[TherapySession] init request error: %v", err)
71+
return nil
72+
}
73+
defer initResp.Body.Close()
74+
proceed := false
75+
if initResp.StatusCode >= 200 && initResp.StatusCode < 300 {
76+
proceed = true
77+
} else {
78+
body, _ := io.ReadAll(initResp.Body)
79+
// Allow existing session scenario to proceed
80+
if initResp.StatusCode == 400 && strings.Contains(string(body), "Session already exists") {
81+
log.Printf("[TherapySession] init session exists, proceeding: %s", therapySessionID)
82+
proceed = true
83+
} else {
84+
log.Printf("[TherapySession] init non-2xx: %d body=%s", initResp.StatusCode, string(body))
85+
}
86+
}
87+
if !proceed {
88+
return nil
89+
}
8390

8491
// 2) Send user message via run_sse
8592
runURL := fmt.Sprintf("%s/run_sse", baseURL)
@@ -119,31 +126,50 @@ func callTherapySessionEndpoint(text string, session *Session) *string {
119126
log.Printf("[TherapySession] run non-2xx: %d body=%s", runResp.StatusCode, string(runRespBody))
120127
return nil
121128
}
122-
respStr := string(runRespBody)
123-
if respStr == "" {
124-
return nil
125-
}
126-
return &respStr
127-
}
129+
respStr := string(runRespBody)
130+
if respStr == "" {
131+
return nil
132+
}
128133

129-
func httpPostJSON(url string, payload string) (string, error) {
130-
//coverage:ignore
131-
req, err := http.NewRequest("POST", url, bytes.NewBuffer([]byte(payload)))
132-
if err != nil {
133-
return "", err
134-
}
135-
req.Header.Set("Content-Type", "application/json")
136-
client := &http.Client{Timeout: 15 * time.Second}
137-
resp, err := client.Do(req)
138-
if err != nil {
139-
return "", err
140-
}
141-
defer resp.Body.Close()
142-
body, err := io.ReadAll(resp.Body)
143-
if err != nil {
144-
return "", err
145-
}
146-
return string(body), nil
134+
// Try to extract plain text from JSON response
135+
// Support responses that are either raw JSON or lines prefixed with "data: "
136+
extractJSON := func(s string) string {
137+
s = strings.TrimSpace(s)
138+
if strings.HasPrefix(s, "data:") {
139+
// If multiple lines, pick the last data line
140+
lines := strings.Split(s, "\n")
141+
for i := len(lines) - 1; i >= 0; i-- {
142+
line := strings.TrimSpace(lines[i])
143+
if strings.HasPrefix(line, "data:") {
144+
return strings.TrimSpace(strings.TrimPrefix(line, "data:"))
145+
}
146+
}
147+
return strings.TrimSpace(strings.TrimPrefix(lines[len(lines)-1], "data:"))
148+
}
149+
return s
150+
}
151+
152+
type runSseContentPart struct {
153+
Text string `json:"text"`
154+
}
155+
type runSseContent struct {
156+
Parts []runSseContentPart `json:"parts"`
157+
}
158+
type runSseResponse struct {
159+
Content runSseContent `json:"content"`
160+
}
161+
162+
jsonCandidate := extractJSON(respStr)
163+
var parsed runSseResponse
164+
if err := json.Unmarshal([]byte(jsonCandidate), &parsed); err == nil {
165+
if len(parsed.Content.Parts) > 0 && parsed.Content.Parts[0].Text != "" {
166+
onlyText := parsed.Content.Parts[0].Text
167+
return &onlyText
168+
}
169+
}
170+
171+
// Fallback: return body as-is
172+
return &respStr
147173
}
148174

149175
// Relay a user message to the therapy session backend and append the reply

internal/app/therapy_test.go

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ func TestEndTherapySession(t *testing.T) {
5555

5656
func TestRelayTherapyMessage(t *testing.T) {
5757
// Create a fake therapy session backend implementing both init and run endpoints
58-
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
58+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
5959
token := r.Header.Get("Authorization")
6060
if token != "Bearer test-token" {
6161
http.Error(w, "unauthorized", http.StatusUnauthorized)
@@ -69,8 +69,9 @@ func TestRelayTherapyMessage(t *testing.T) {
6969
return
7070
case r.Method == http.MethodPost && r.URL.Path == "/run_sse":
7171
// Message sending endpoint
72-
w.WriteHeader(http.StatusOK)
73-
_, _ = w.Write([]byte("Hello, I'm here for you."))
72+
w.Header().Set("Content-Type", "text/event-stream")
73+
w.WriteHeader(http.StatusOK)
74+
_, _ = w.Write([]byte("data: {\"content\":{\"parts\":[{\"text\":\"Hello, I'm here for you.\"}],\"role\":\"model\"}}\n\n"))
7475
return
7576
default:
7677
http.NotFound(w, r)
@@ -114,7 +115,7 @@ func TestHandleSession_AutoEndWhenExpired(t *testing.T) {
114115

115116
func TestHandleSession_ForwardDuringActive(t *testing.T) {
116117
// Fake backend implementing both init and run endpoints
117-
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
118+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
118119
token := r.Header.Get("Authorization")
119120
if token != "Bearer test-token" {
120121
http.Error(w, "unauthorized", http.StatusUnauthorized)
@@ -126,8 +127,9 @@ func TestHandleSession_ForwardDuringActive(t *testing.T) {
126127
_, _ = w.Write([]byte(`{"ok":true}`))
127128
return
128129
case r.Method == http.MethodPost && r.URL.Path == "/run_sse":
129-
w.WriteHeader(http.StatusOK)
130-
_, _ = w.Write([]byte("Therapist reply"))
130+
w.Header().Set("Content-Type", "text/event-stream")
131+
w.WriteHeader(http.StatusOK)
132+
_, _ = w.Write([]byte("data: {\"content\":{\"parts\":[{\"text\":\"Therapist reply\"}],\"role\":\"model\"}}\n\n"))
131133
return
132134
default:
133135
http.NotFound(w, r)
@@ -155,6 +157,50 @@ func TestHandleSession_ForwardDuringActive(t *testing.T) {
155157
}
156158
}
157159

160+
func TestRelayTherapyMessage_ExistingSessionContinues(t *testing.T) {
161+
// Fake backend: init returns 400 Session already exists; run_sse returns a reply
162+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
163+
token := r.Header.Get("Authorization")
164+
if token != "Bearer test-token" {
165+
http.Error(w, "unauthorized", http.StatusUnauthorized)
166+
return
167+
}
168+
switch {
169+
case r.Method == http.MethodPost && strings.HasPrefix(r.URL.Path, "/apps/capymind_agent/users/u1/sessions/"):
170+
w.WriteHeader(http.StatusBadRequest)
171+
_, _ = w.Write([]byte(`{"detail":"Session already exists: abc-123"}`))
172+
return
173+
case r.Method == http.MethodPost && r.URL.Path == "/run_sse":
174+
w.Header().Set("Content-Type", "text/event-stream")
175+
w.WriteHeader(http.StatusOK)
176+
_, _ = w.Write([]byte("data: {\"content\":{\"parts\":[{\"text\":\"Hello again\"}],\"role\":\"model\"}}\n\n"))
177+
return
178+
default:
179+
http.NotFound(w, r)
180+
return
181+
}
182+
}))
183+
defer ts.Close()
184+
os.Setenv("CAPY_THERAPY_SESSION_URL", ts.URL)
185+
os.Setenv("CAPY_AGENT_TOKEN", "test-token")
186+
defer os.Unsetenv("CAPY_THERAPY_SESSION_URL")
187+
defer os.Unsetenv("CAPY_AGENT_TOKEN")
188+
189+
ctx := context.Background()
190+
locale := "en"
191+
user := &database.User{ID: "u1", Locale: &locale}
192+
session := createSession(&Job{Command: None}, user, nil, &ctx)
193+
194+
relayTherapyMessage("hi", session)
195+
196+
if len(session.Job.Output) == 0 {
197+
t.Fatalf("expected at least one output")
198+
}
199+
if session.Job.Output[0].TextID != "Hello again" {
200+
t.Fatalf("unexpected relay text: %s", session.Job.Output[0].TextID)
201+
}
202+
}
203+
158204
func TestHandleSession_EndOnOtherCommand(t *testing.T) {
159205
ctx := context.Background()
160206
future := time.Now().Add(5 * time.Minute)

0 commit comments

Comments
 (0)