Skip to content

Commit b65835c

Browse files
Let any user access anonymous AI messages if they have the ID (#9245)
1 parent c8a554b commit b65835c

2 files changed

Lines changed: 79 additions & 10 deletions

File tree

runtime/ai/ai.go

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,17 @@ func (r *Runner) Session(ctx context.Context, opts *SessionOptions) (res *Sessio
113113
if err != nil {
114114
return nil, fmt.Errorf("failed to find session %q: %w", opts.SessionID, err)
115115
}
116-
retrieveUntilMessageID := session.SharedUntilMessageID
117-
if session.OwnerID == opts.Claims.UserID || opts.Claims.SkipChecks {
118-
// If the user owns the session or skipCheck enabled, they can see all messages.
119-
retrieveUntilMessageID = ""
116+
117+
// Check access: you can access anonymous sessions, your own sessions, and shared sessions.
118+
// For shared sessions, if you are not the owner, you can only see messages up to the SharedUntilMessageID (inclusive).
119+
// For sessions without an owner (unauthenticated users using a public project), we don't check access and rely on security by obscurity (generally a decent trade-off, but specifically introduced to get citation links over MCP working for unauthenticated demos).
120+
// It's important to respect SkipChecks to ensure access in Rill Developer (where auth is disabled, but SkipChecks is true).
121+
var retrieveUntilMessageID string
122+
if session.OwnerID != "" && session.OwnerID != opts.Claims.UserID && !opts.Claims.SkipChecks {
123+
if session.SharedUntilMessageID == "" {
124+
return nil, fmt.Errorf("access denied to session %q", session.ID)
125+
}
126+
retrieveUntilMessageID = session.SharedUntilMessageID
120127
}
121128

122129
ms, err := catalog.FindAIMessages(ctx, opts.SessionID)
@@ -158,12 +165,6 @@ func (r *Runner) Session(ctx context.Context, opts *SessionOptions) (res *Sessio
158165
}
159166
}
160167

161-
// Check access: for now, only allow users to access their own sessions or shared sessions with trimmed messages.
162-
// Checking !SkipChecks to ensure access for superusers and for Rill Developer (where auth is disabled and SkipChecks is true).
163-
if opts.Claims.UserID != session.OwnerID && !opts.Claims.SkipChecks && session.SharedUntilMessageID == "" {
164-
return nil, fmt.Errorf("access denied to session %q", session.ID)
165-
}
166-
167168
// Setup logger
168169
logger := r.Runtime.Logger.Named("ai").With(
169170
zap.String("instance_id", opts.InstanceID),

runtime/server/chat_test.go

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,74 @@ measures:
319319
require.Len(t, get3.Messages, len(res4.Messages))
320320
}
321321

322+
func TestAnonymousSessionAccess(t *testing.T) {
323+
rt, instanceID := testruntime.NewInstance(t)
324+
srv, err := server.NewServer(context.Background(), &server.Options{}, rt, zap.NewNop(), ratelimit.NewNoop(), activity.NewNoopClient(), nil)
325+
require.NoError(t, err)
326+
327+
// Claims for an anonymous and authenticated user.
328+
anonClaims := &runtime.SecurityClaims{
329+
UserID: "",
330+
Permissions: []runtime.Permission{runtime.ReadObjects, runtime.ReadMetrics, runtime.UseAI},
331+
}
332+
authedClaims := &runtime.SecurityClaims{
333+
UserID: "foo",
334+
Permissions: []runtime.Permission{runtime.ReadObjects, runtime.ReadMetrics, runtime.UseAI},
335+
}
336+
337+
// Create an anonymous session with a message directly to avoid LLM calls and keep the test cheap.
338+
runner := ai.NewRunner(rt, activity.NewNoopClient())
339+
sess, err := runner.Session(t.Context(), &ai.SessionOptions{
340+
InstanceID: instanceID,
341+
CreateIfNotExists: true,
342+
Claims: anonClaims,
343+
UserAgent: "test",
344+
})
345+
require.NoError(t, err)
346+
msg, err := sess.CallTool(t.Context(), ai.RoleUser, ai.ListMetricsViewsName, nil, &ai.ListMetricsViewsArgs{})
347+
require.NoError(t, err)
348+
err = sess.Flush(t.Context())
349+
require.NoError(t, err)
350+
351+
// Anonymous user can access the anonymous session
352+
anonCtx := auth.WithClaims(t.Context(), anonClaims)
353+
convRes, err := srv.GetConversation(anonCtx, &runtimev1.GetConversationRequest{
354+
InstanceId: instanceID,
355+
ConversationId: msg.Call.SessionID,
356+
})
357+
require.NoError(t, err)
358+
require.NotNil(t, convRes.Conversation)
359+
require.Equal(t, msg.Call.SessionID, convRes.Conversation.Id)
360+
361+
// Anonymous user can access the message in the anonymous session
362+
msgRes, err := srv.GetAIMessage(anonCtx, &runtimev1.GetAIMessageRequest{
363+
InstanceId: instanceID,
364+
ConversationId: msg.Call.SessionID,
365+
MessageId: msg.Call.ID,
366+
})
367+
require.NoError(t, err)
368+
require.Equal(t, msg.Call.ID, msgRes.Message.Id)
369+
370+
// Authenticated user can also access the anonymous session
371+
authedCtx := auth.WithClaims(t.Context(), authedClaims)
372+
convRes, err = srv.GetConversation(authedCtx, &runtimev1.GetConversationRequest{
373+
InstanceId: instanceID,
374+
ConversationId: msg.Call.SessionID,
375+
})
376+
require.NoError(t, err)
377+
require.NotNil(t, convRes.Conversation)
378+
require.Equal(t, msg.Call.SessionID, convRes.Conversation.Id)
379+
380+
// Authenticated user can access the message in the anonymous session
381+
msgRes, err = srv.GetAIMessage(authedCtx, &runtimev1.GetAIMessageRequest{
382+
InstanceId: instanceID,
383+
ConversationId: msg.Call.SessionID,
384+
MessageId: msg.Call.ID,
385+
})
386+
require.NoError(t, err)
387+
require.Equal(t, msg.Call.ID, msgRes.Message.Id)
388+
}
389+
322390
func TestAgentChoiceAndContext(t *testing.T) {
323391
// Skip in CI since we make real LLM calls.
324392
testmode.Expensive(t)

0 commit comments

Comments
 (0)