Skip to content

Commit 6976950

Browse files
feat(pi): add native Engram memory tools
1 parent b094052 commit 6976950

13 files changed

Lines changed: 909 additions & 261 deletions

File tree

DOCS.md

Lines changed: 203 additions & 110 deletions
Large diffs are not rendered by default.

internal/mcp/mcp.go

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1786,12 +1786,8 @@ func handleJudge(s *store.Store, activity *SessionActivity) server.ToolHandlerFu
17861786
}
17871787
var confidence *float64
17881788
if v, ok := req.GetArguments()["confidence"].(float64); ok {
1789-
// Clamp to [0, 1] per design §6.3.
1790-
if v < 0 {
1791-
v = 0
1792-
}
1793-
if v > 1 {
1794-
v = 1
1789+
if v < 0 || v > 1 {
1790+
return mcp.NewToolResultError("confidence must be between 0.0 and 1.0"), nil
17951791
}
17961792
confidence = &v
17971793
}
@@ -1862,12 +1858,8 @@ func handleCompare(s *store.Store, _ *SessionActivity) server.ToolHandlerFunc {
18621858
if !okConf {
18631859
return mcp.NewToolResultError("confidence is required (float 0.0..1.0)"), nil
18641860
}
1865-
// Clamp to [0, 1].
1866-
if rawConf < 0 {
1867-
rawConf = 0
1868-
}
1869-
if rawConf > 1 {
1870-
rawConf = 1
1861+
if rawConf < 0 || rawConf > 1 {
1862+
return mcp.NewToolResultError("confidence must be between 0.0 and 1.0"), nil
18711863
}
18721864

18731865
// --- optional model ---

internal/server/server.go

Lines changed: 274 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ import (
1818
"strings"
1919
"time"
2020

21+
"github.com/Gentleman-Programming/engram/internal/diagnostic"
22+
projectpkg "github.com/Gentleman-Programming/engram/internal/project"
2123
"github.com/Gentleman-Programming/engram/internal/store"
2224
)
2325

@@ -173,10 +175,12 @@ func (s *Server) routes() {
173175
s.mux.HandleFunc("GET /export", s.handleExport)
174176
s.mux.HandleFunc("POST /import", s.handleImport)
175177

176-
// Stats
178+
// Stats / diagnostics
177179
s.mux.HandleFunc("GET /stats", s.handleStats)
180+
s.mux.HandleFunc("GET /doctor", s.handleDoctor)
178181

179-
// Project migration
182+
// Project detection / migration
183+
s.mux.HandleFunc("GET /project/current", s.handleCurrentProject)
180184
s.mux.HandleFunc("POST /projects/migrate", s.handleMigrateProject)
181185

182186
// Sync status (degraded-state visibility for autosync)
@@ -187,6 +191,8 @@ func (s *Server) routes() {
187191
s.mux.HandleFunc("GET /conflicts/stats", s.handleConflictsStats)
188192
s.mux.HandleFunc("GET /conflicts/deferred", s.handleListDeferred)
189193
s.mux.HandleFunc("POST /conflicts/scan", s.handleScanConflicts)
194+
s.mux.HandleFunc("POST /conflicts/judge", s.handleJudgeConflict)
195+
s.mux.HandleFunc("POST /conflicts/compare", s.handleCompareMemories)
190196
s.mux.HandleFunc("POST /conflicts/deferred/replay", s.handleReplayDeferred)
191197
s.mux.HandleFunc("GET /conflicts/{relation_id}", s.handleGetConflict)
192198
}
@@ -279,6 +285,9 @@ func (s *Server) handleAddObservation(w http.ResponseWriter, r *http.Request) {
279285
jsonError(w, http.StatusBadRequest, "session_id, title, and content are required")
280286
return
281287
}
288+
if !s.validateSessionProject(w, body.SessionID, body.Project) {
289+
return
290+
}
282291

283292
id, err := s.store.AddObservation(body)
284293
if err != nil {
@@ -300,6 +309,9 @@ func (s *Server) handlePassiveCapture(w http.ResponseWriter, r *http.Request) {
300309
jsonError(w, http.StatusBadRequest, "session_id is required")
301310
return
302311
}
312+
if !s.validateSessionProject(w, body.SessionID, body.Project) {
313+
return
314+
}
303315

304316
result, err := s.store.PassiveCapture(body)
305317
if err != nil {
@@ -466,6 +478,9 @@ func (s *Server) handleAddPrompt(w http.ResponseWriter, r *http.Request) {
466478
jsonError(w, http.StatusBadRequest, "session_id and content are required")
467479
return
468480
}
481+
if !s.validateSessionProject(w, body.SessionID, body.Project) {
482+
return
483+
}
469484

470485
id, err := s.store.AddPrompt(body)
471486
if err != nil {
@@ -636,6 +651,102 @@ func (s *Server) handleStats(w http.ResponseWriter, r *http.Request) {
636651
jsonResponse(w, http.StatusOK, stats)
637652
}
638653

654+
func (s *Server) handleDoctor(w http.ResponseWriter, r *http.Request) {
655+
projectName := strings.TrimSpace(r.URL.Query().Get("project"))
656+
if projectName != "" {
657+
projectName, _ = store.NormalizeProject(projectName)
658+
exists, err := s.store.ProjectExists(projectName)
659+
if err != nil {
660+
jsonError(w, http.StatusInternalServerError, err.Error())
661+
return
662+
}
663+
if !exists {
664+
available, err := s.store.ListProjectNames()
665+
if err != nil {
666+
jsonError(w, http.StatusInternalServerError, err.Error())
667+
return
668+
}
669+
jsonErrorWithFields(w, http.StatusNotFound, fmt.Sprintf("project %q not found", projectName), map[string]any{
670+
"code": "unknown_project",
671+
"available_projects": available,
672+
})
673+
return
674+
}
675+
} else {
676+
cwd := strings.TrimSpace(r.URL.Query().Get("cwd"))
677+
if cwd == "" {
678+
var err error
679+
cwd, err = os.Getwd()
680+
if err != nil {
681+
jsonError(w, http.StatusInternalServerError, "failed to detect cwd: "+err.Error())
682+
return
683+
}
684+
}
685+
res := projectpkg.DetectProjectFull(cwd)
686+
if res.Error != nil {
687+
code := "project_detection_failed"
688+
if len(res.AvailableProjects) > 0 {
689+
code = "ambiguous_project"
690+
}
691+
jsonErrorWithFields(w, http.StatusBadRequest, "project detection failed: "+res.Error.Error(), map[string]any{
692+
"code": code,
693+
"available_projects": res.AvailableProjects,
694+
})
695+
return
696+
}
697+
projectName, _ = store.NormalizeProject(res.Project)
698+
}
699+
700+
check := strings.TrimSpace(r.URL.Query().Get("check"))
701+
runner := diagnostic.NewRunner()
702+
scope := diagnostic.Scope{Store: s.store, Project: projectName, Now: time.Now()}
703+
var (
704+
report diagnostic.Report
705+
err error
706+
)
707+
if check != "" {
708+
report, err = runner.RunOne(r.Context(), scope, check)
709+
} else {
710+
report, err = runner.RunAll(r.Context(), scope)
711+
}
712+
if err != nil {
713+
report = diagnostic.ErrorReport(projectName, err)
714+
}
715+
716+
jsonResponse(w, http.StatusOK, report)
717+
}
718+
719+
// ─── Project Detection ───────────────────────────────────────────────────────
720+
721+
func (s *Server) handleCurrentProject(w http.ResponseWriter, r *http.Request) {
722+
cwd := strings.TrimSpace(r.URL.Query().Get("cwd"))
723+
if cwd == "" {
724+
var err error
725+
cwd, err = os.Getwd()
726+
if err != nil {
727+
jsonError(w, http.StatusInternalServerError, "failed to detect cwd: "+err.Error())
728+
return
729+
}
730+
}
731+
732+
res := projectpkg.DetectProjectFull(cwd)
733+
payload := map[string]any{
734+
"project": res.Project,
735+
"project_source": res.Source,
736+
"project_path": res.Path,
737+
"cwd": cwd,
738+
"available_projects": res.AvailableProjects,
739+
}
740+
if res.Warning != "" {
741+
payload["warning"] = res.Warning
742+
}
743+
if res.Error != nil {
744+
payload["error_hint"] = res.Error.Error()
745+
}
746+
747+
jsonResponse(w, http.StatusOK, payload)
748+
}
749+
639750
// ─── Sync Status ─────────────────────────────────────────────────────────────
640751

641752
func (s *Server) handleSyncStatus(w http.ResponseWriter, r *http.Request) {
@@ -967,6 +1078,133 @@ func (s *Server) handleScanConflicts(w http.ResponseWriter, r *http.Request) {
9671078
jsonResponse(w, http.StatusOK, resp)
9681079
}
9691080

1081+
// handleJudgeConflict serves POST /conflicts/judge.
1082+
// Body: {"judgment_id":"rel-...","relation":"related|compatible|scoped|conflicts_with|supersedes|not_conflict", ...}
1083+
func (s *Server) handleJudgeConflict(w http.ResponseWriter, r *http.Request) {
1084+
var body struct {
1085+
JudgmentID string `json:"judgment_id"`
1086+
Relation string `json:"relation"`
1087+
Reason string `json:"reason"`
1088+
Evidence string `json:"evidence"`
1089+
Confidence *float64 `json:"confidence"`
1090+
SessionID string `json:"session_id"`
1091+
}
1092+
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
1093+
jsonError(w, http.StatusBadRequest, "invalid json: "+err.Error())
1094+
return
1095+
}
1096+
if strings.TrimSpace(body.JudgmentID) == "" {
1097+
jsonError(w, http.StatusBadRequest, "judgment_id is required")
1098+
return
1099+
}
1100+
if strings.TrimSpace(body.Relation) == "" {
1101+
jsonError(w, http.StatusBadRequest, "relation is required")
1102+
return
1103+
}
1104+
1105+
var reason *string
1106+
if body.Reason != "" {
1107+
reason = &body.Reason
1108+
}
1109+
var evidence *string
1110+
if body.Evidence != "" {
1111+
evidence = &body.Evidence
1112+
}
1113+
if body.Confidence != nil && (*body.Confidence < 0 || *body.Confidence > 1) {
1114+
jsonError(w, http.StatusBadRequest, "confidence must be between 0.0 and 1.0")
1115+
return
1116+
}
1117+
1118+
relation, err := s.store.JudgeRelation(store.JudgeRelationParams{
1119+
JudgmentID: body.JudgmentID,
1120+
Relation: body.Relation,
1121+
Reason: reason,
1122+
Evidence: evidence,
1123+
Confidence: body.Confidence,
1124+
MarkedByActor: "agent",
1125+
MarkedByKind: "agent",
1126+
SessionID: body.SessionID,
1127+
})
1128+
if err != nil {
1129+
jsonError(w, http.StatusBadRequest, err.Error())
1130+
return
1131+
}
1132+
1133+
s.notifyWrite()
1134+
jsonResponse(w, http.StatusOK, map[string]any{"relation": relation})
1135+
}
1136+
1137+
// handleCompareMemories serves POST /conflicts/compare.
1138+
// Body: {"memory_id_a":1,"memory_id_b":2,"relation":"related", "confidence":0.9, "reasoning":"..."}
1139+
func (s *Server) handleCompareMemories(w http.ResponseWriter, r *http.Request) {
1140+
var body struct {
1141+
MemoryIDA int64 `json:"memory_id_a"`
1142+
MemoryIDB int64 `json:"memory_id_b"`
1143+
Relation string `json:"relation"`
1144+
Confidence *float64 `json:"confidence"`
1145+
Reasoning string `json:"reasoning"`
1146+
Model string `json:"model"`
1147+
}
1148+
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
1149+
jsonError(w, http.StatusBadRequest, "invalid json: "+err.Error())
1150+
return
1151+
}
1152+
if body.MemoryIDA == 0 {
1153+
jsonError(w, http.StatusBadRequest, "memory_id_a is required")
1154+
return
1155+
}
1156+
if body.MemoryIDB == 0 {
1157+
jsonError(w, http.StatusBadRequest, "memory_id_b is required")
1158+
return
1159+
}
1160+
if strings.TrimSpace(body.Relation) == "" {
1161+
jsonError(w, http.StatusBadRequest, "relation is required")
1162+
return
1163+
}
1164+
if strings.TrimSpace(body.Reasoning) == "" {
1165+
jsonError(w, http.StatusBadRequest, "reasoning is required")
1166+
return
1167+
}
1168+
if body.Confidence == nil {
1169+
jsonError(w, http.StatusBadRequest, "confidence is required")
1170+
return
1171+
}
1172+
confidence := *body.Confidence
1173+
if confidence < 0 || confidence > 1 {
1174+
jsonError(w, http.StatusBadRequest, "confidence must be between 0.0 and 1.0")
1175+
return
1176+
}
1177+
1178+
obsA, err := s.store.GetObservation(body.MemoryIDA)
1179+
if err != nil {
1180+
jsonError(w, http.StatusNotFound, fmt.Sprintf("observation id=%d not found: %s", body.MemoryIDA, err))
1181+
return
1182+
}
1183+
obsB, err := s.store.GetObservation(body.MemoryIDB)
1184+
if err != nil {
1185+
jsonError(w, http.StatusNotFound, fmt.Sprintf("observation id=%d not found: %s", body.MemoryIDB, err))
1186+
return
1187+
}
1188+
1189+
syncID, err := s.store.JudgeBySemantic(store.JudgeBySemanticParams{
1190+
SourceID: obsA.SyncID,
1191+
TargetID: obsB.SyncID,
1192+
Relation: body.Relation,
1193+
Confidence: confidence,
1194+
Reasoning: body.Reasoning,
1195+
Model: body.Model,
1196+
})
1197+
if err != nil {
1198+
jsonError(w, http.StatusBadRequest, err.Error())
1199+
return
1200+
}
1201+
1202+
if syncID != "" {
1203+
s.notifyWrite()
1204+
}
1205+
jsonResponse(w, http.StatusOK, map[string]any{"sync_id": syncID})
1206+
}
1207+
9701208
// handleReplayDeferred serves POST /conflicts/deferred/replay
9711209
func (s *Server) handleReplayDeferred(w http.ResponseWriter, r *http.Request) {
9721210
result, err := s.store.ReplayDeferred()
@@ -1019,6 +1257,32 @@ func (s *Server) handleGetConflict(w http.ResponseWriter, r *http.Request) {
10191257

10201258
// ─── Helpers ─────────────────────────────────────────────────────────────────
10211259

1260+
func (s *Server) validateSessionProject(w http.ResponseWriter, sessionID, projectName string) bool {
1261+
if strings.TrimSpace(projectName) == "" {
1262+
return true
1263+
}
1264+
projectName, _ = store.NormalizeProject(projectName)
1265+
session, err := s.store.GetSession(sessionID)
1266+
if err != nil {
1267+
if errors.Is(err, sql.ErrNoRows) {
1268+
jsonError(w, http.StatusNotFound, "session not found")
1269+
return false
1270+
}
1271+
jsonError(w, http.StatusInternalServerError, err.Error())
1272+
return false
1273+
}
1274+
sessionProject, _ := store.NormalizeProject(session.Project)
1275+
if sessionProject != "" && sessionProject != projectName {
1276+
jsonErrorWithFields(w, http.StatusBadRequest, "session project does not match requested project", map[string]any{
1277+
"code": "session_project_mismatch",
1278+
"session_project": sessionProject,
1279+
"project": projectName,
1280+
})
1281+
return false
1282+
}
1283+
return true
1284+
}
1285+
10221286
func jsonResponse(w http.ResponseWriter, status int, data any) {
10231287
w.Header().Set("Content-Type", "application/json")
10241288
w.WriteHeader(status)
@@ -1029,6 +1293,14 @@ func jsonError(w http.ResponseWriter, status int, msg string) {
10291293
jsonResponse(w, status, map[string]string{"error": msg})
10301294
}
10311295

1296+
func jsonErrorWithFields(w http.ResponseWriter, status int, msg string, fields map[string]any) {
1297+
payload := map[string]any{"error": msg}
1298+
for key, value := range fields {
1299+
payload[key] = value
1300+
}
1301+
jsonResponse(w, status, payload)
1302+
}
1303+
10321304
func queryInt(r *http.Request, key string, defaultVal int) int {
10331305
v := r.URL.Query().Get(key)
10341306
if v == "" {

0 commit comments

Comments
 (0)