Skip to content

Commit 7f5e17e

Browse files
committed
fix(rpc): fallback to DB when session is offline for task and context queries
Session sweep removes dead sessions from memory, causing all task/context RPC handlers to fail with "Session ID not found". Add DB+disk fallback for read-only task handlers, context save operations, and session recovery in SpiteStream so that offline sessions no longer block data retrieval.
1 parent dce3134 commit 7f5e17e

4 files changed

Lines changed: 321 additions & 16 deletions

File tree

server/rpc/rpc-context.go

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"os"
88

9+
"github.com/chainreactors/IoM-go/client"
910
"github.com/chainreactors/IoM-go/consts"
1011
"github.com/chainreactors/IoM-go/proto/client/clientpb"
1112
errs "github.com/chainreactors/IoM-go/types"
@@ -69,7 +70,8 @@ func getTaskFromContext(req *clientpb.Context) (*core.Task, error) {
6970

7071
sess, err := core.Sessions.Get(sessionID)
7172
if err != nil {
72-
return nil, err
73+
// Session not in memory (dead/offline), fall back to DB.
74+
return getTaskFromDB(sessionID, req.Task.TaskId)
7375
}
7476

7577
task := sess.Tasks.Get(req.Task.TaskId)
@@ -80,6 +82,37 @@ func getTaskFromContext(req *clientpb.Context) (*core.Task, error) {
8082
return task, nil
8183
}
8284

85+
// getTaskFromDB constructs a minimal core.Task from DB when session is not in memory.
86+
func getTaskFromDB(sessionID string, taskID uint32) (*core.Task, error) {
87+
dbTask, err := db.GetTaskBySessionAndSeq(sessionID, taskID)
88+
if err != nil {
89+
return nil, errs.ErrNotFoundTask
90+
}
91+
task := core.FromTaskProtobuf(dbTask.ToProtobuf())
92+
// Build a minimal Session so that context handlers (HandleScreenshot, etc.)
93+
// can resolve the file path and bindResolvedTask can produce session metadata.
94+
task.Session = &core.Session{
95+
ID: sessionID,
96+
SessionContext: &client.SessionContext{},
97+
}
98+
// Try to enrich from DB; non-fatal if it fails.
99+
if dbSess, err := db.FindSession(sessionID); err == nil && dbSess != nil {
100+
task.Session.Target = dbSess.Target
101+
task.Session.PipelineID = dbSess.PipelineID
102+
task.Session.ListenerID = dbSess.ListenerID
103+
task.Session.Name = dbSess.ProfileName
104+
task.Session.Note = dbSess.Note
105+
task.Session.Group = dbSess.GroupName
106+
task.Session.Type = dbSess.Type
107+
task.Session.RawID = dbSess.RawID
108+
task.Session.CreatedAt = dbSess.CreatedAt
109+
if dbSess.Data != nil {
110+
task.Session.SessionContext = dbSess.Data
111+
}
112+
}
113+
return task, nil
114+
}
115+
83116
func bindResolvedTask(req *clientpb.Context, task *core.Task) *clientpb.Context {
84117
if req == nil || task == nil || task.Session == nil {
85118
return req

server/rpc/rpc-listener.go

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"github.com/chainreactors/IoM-go/types"
1111
"github.com/chainreactors/logs"
1212
"github.com/chainreactors/malice-network/server/internal/core"
13+
"github.com/chainreactors/malice-network/server/internal/db"
1314
"google.golang.org/protobuf/proto"
1415
"io"
1516
"time"
@@ -67,8 +68,20 @@ func (rpc *Server) SpiteStream(stream listenerrpc.ListenerRPC_SpiteStreamServer)
6768

6869
sess, err := core.Sessions.Get(msg.SessionId)
6970
if err != nil {
70-
logs.Log.Warnf("session %s not found", msg.SessionId)
71-
continue
71+
// Session not in memory — try to recover from DB since the implant
72+
// is actively sending data, meaning it's still alive.
73+
dbSess, dbErr := db.FindSession(msg.SessionId)
74+
if dbErr != nil || dbSess == nil {
75+
logs.Log.Warnf("session %s not found in memory or DB", msg.SessionId)
76+
continue
77+
}
78+
sess, err = core.RecoverSession(dbSess)
79+
if err != nil {
80+
logs.Log.Warnf("session %s recovery failed: %v", msg.SessionId, err)
81+
continue
82+
}
83+
core.Sessions.Add(sess)
84+
logs.Log.Importantf("session %s recovered from DB via SpiteStream", msg.SessionId)
7285
}
7386
sess.SetLastCheckin(time.Now().Unix())
7487
if sess.MarkAlive() {

server/rpc/rpc-task.go

Lines changed: 72 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,66 @@ func getAllTaskContextsFromDisk(sess *core.Session, task *core.Task) (*clientpb.
139139
}, nil
140140
}
141141

142+
// getTaskContextFromDB fetches task context from DB + disk when session is not in memory.
143+
func getTaskContextFromDB(sessionID string, taskID uint32, index int32) (*clientpb.TaskContext, error) {
144+
dbSess, err := db.FindSession(sessionID)
145+
if err != nil || dbSess == nil {
146+
return nil, types.ErrNotFoundSession
147+
}
148+
dbTask, err := db.GetTaskBySessionAndSeq(sessionID, taskID)
149+
if err != nil {
150+
return nil, types.ErrNotFoundTask
151+
}
152+
entries, err := readTaskSpitesFromDisk(sessionID, dbTask.Seq)
153+
if err != nil || len(entries) == 0 {
154+
return nil, types.ErrNotFoundTaskContent
155+
}
156+
var spite *implantpb.Spite
157+
if index == -1 {
158+
spite = entries[len(entries)-1].Spite
159+
} else {
160+
for _, e := range entries {
161+
if e.Index == int(index) {
162+
spite = e.Spite
163+
break
164+
}
165+
}
166+
}
167+
if spite == nil {
168+
return nil, types.ErrNotFoundTaskContent
169+
}
170+
return &clientpb.TaskContext{
171+
Task: dbTask.ToProtobuf(),
172+
Session: dbSess.ToProtobuf(),
173+
Spite: spite,
174+
}, nil
175+
}
176+
177+
// getAllTaskContextFromDB fetches all task contexts from DB + disk when session is not in memory.
178+
func getAllTaskContextFromDB(sessionID string, taskID uint32) (*clientpb.TaskContexts, error) {
179+
dbSess, err := db.FindSession(sessionID)
180+
if err != nil || dbSess == nil {
181+
return nil, types.ErrNotFoundSession
182+
}
183+
dbTask, err := db.GetTaskBySessionAndSeq(sessionID, taskID)
184+
if err != nil {
185+
return nil, types.ErrNotFoundTask
186+
}
187+
entries, err := readTaskSpitesFromDisk(sessionID, dbTask.Seq)
188+
if err != nil || len(entries) == 0 {
189+
return nil, types.ErrNotFoundTaskContent
190+
}
191+
spites := make([]*implantpb.Spite, 0, len(entries))
192+
for _, e := range entries {
193+
spites = append(spites, e.Spite)
194+
}
195+
return &clientpb.TaskContexts{
196+
Task: dbTask.ToProtobuf(),
197+
Session: dbSess.ToProtobuf(),
198+
Spites: spites,
199+
}, nil
200+
}
201+
142202
func getTaskContext(sess *core.Session, task *core.Task, index int32) (*clientpb.TaskContext, error) {
143203
var msg *implantpb.Spite
144204
var ok bool
@@ -174,7 +234,12 @@ func (rpc *Server) GetTasks(ctx context.Context, req *clientpb.TaskRequest) (*cl
174234
} else {
175235
sess, err := core.Sessions.Get(req.SessionId)
176236
if err != nil {
177-
return nil, types.ErrNotFoundSession
237+
// Fallback to DB when session is not in memory (e.g., dead session)
238+
modelTasks, dbErr := db.ListTasksBySession(req.SessionId)
239+
if dbErr != nil {
240+
return nil, types.ErrNotFoundSession
241+
}
242+
return modelTasks.ToProtobuf(), nil
178243
}
179244
return sess.Tasks.ToProtobuf(), nil
180245
}
@@ -198,7 +263,7 @@ func (rpc *Server) GetTaskContent(ctx context.Context, req *clientpb.Task) (*cli
198263
}
199264
sess, err := core.Sessions.Get(req.SessionId)
200265
if err != nil {
201-
return nil, types.ErrNotFoundSession
266+
return getTaskContextFromDB(req.SessionId, req.TaskId, req.Need)
202267
}
203268
task := sess.Tasks.GetOrRecover(sess, req.TaskId)
204269
if task == nil {
@@ -217,7 +282,8 @@ func (rpc *Server) WaitTaskContent(ctx context.Context, req *clientpb.Task) (*cl
217282
}
218283
sess, err := core.Sessions.Get(req.SessionId)
219284
if err != nil {
220-
return nil, types.ErrNotFoundSession
285+
// Session not in memory (dead), try DB+disk directly
286+
return getTaskContextFromDB(req.SessionId, req.TaskId, req.Need)
221287
}
222288
task := sess.Tasks.GetOrRecover(sess, req.TaskId)
223289
if task == nil {
@@ -269,7 +335,8 @@ func (rpc *Server) WaitTaskFinish(ctx context.Context, req *clientpb.Task) (*cli
269335
}
270336
sess, err := core.Sessions.Get(req.SessionId)
271337
if err != nil {
272-
return nil, types.ErrNotFoundSession
338+
// Session not in memory (dead), try DB+disk directly
339+
return getTaskContextFromDB(req.SessionId, req.TaskId, -1)
273340
}
274341
task := sess.Tasks.GetOrRecover(sess, req.TaskId)
275342
if task == nil {
@@ -309,7 +376,7 @@ func (rpc *Server) GetAllTaskContent(ctx context.Context, req *clientpb.Task) (*
309376
}
310377
sess, err := core.Sessions.Get(req.SessionId)
311378
if err != nil {
312-
return nil, types.ErrNotFoundSession
379+
return getAllTaskContextFromDB(req.SessionId, req.TaskId)
313380
}
314381
task := sess.Tasks.GetOrRecover(sess, req.TaskId)
315382
if task == nil {

0 commit comments

Comments
 (0)