Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 48 additions & 7 deletions internal/cloud/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,18 @@ var ErrInvalidDashboardSessionToken = errors.New("invalid dashboard session toke
var ErrProjectNotAllowed = errors.New("project is not allowed for this token")

type Service struct {
store *cloudstore.CloudStore
expectedToken string
dashboardAuth map[string]struct{}
allowed map[string]struct{}
jwtSecret []byte
now func() time.Time
store *cloudstore.CloudStore
expectedToken string
dashboardAuth map[string]struct{}
allowed map[string]struct{}
allowedAll bool
jwtSecret []byte
now func() time.Time
}

type ProjectScopeAuthorizer struct {
allowed map[string]struct{}
allowed map[string]struct{}
allowedAll bool
}

func NewService(store *cloudstore.CloudStore, jwtSecret string) (*Service, error) {
Expand Down Expand Up @@ -153,7 +155,12 @@ func (s *Service) SetDashboardSessionTokens(tokens []string) {

func (s *Service) SetAllowedProjects(projects []string) {
s.allowed = make(map[string]struct{})
s.allowedAll = false
for _, project := range projects {
if strings.TrimSpace(project) == "*" {
s.allowedAll = true
return
}
normalized, _ := store.NormalizeProject(project)
normalized = strings.TrimSpace(normalized)
if normalized == "" {
Expand All @@ -164,22 +171,42 @@ func (s *Service) SetAllowedProjects(projects []string) {
}

func (s *Service) AuthorizeProject(project string) error {
if s.allowedAll {
normalized, _ := store.NormalizeProject(project)
normalized = strings.TrimSpace(normalized)
if normalized == "" {
return fmt.Errorf("project is required")
}
return nil
}
return authorizeProjectAgainstAllowlist(project, s.allowed)
}

// EnrolledProjects returns the sorted list of projects that this Service is
// authorized to serve. Used by cloudserver's mutation pull to filter mutations
// to the caller's enrolled projects (REQ-202).
//
// When the wildcard "*" is configured, nil is returned to signal "no project
// filter" — callers must treat nil as "allow all" (matching the ListMutationsSince
// nil-means-all contract).
//
// The interface is cloudserver.EnrolledProjectsProvider; this method makes
// *Service satisfy it without importing cloudserver (structural assertion).
func (s *Service) EnrolledProjects() []string {
if s.allowedAll {
return nil
}
return sortedAllowlist(s.allowed)
}

func (a *ProjectScopeAuthorizer) SetAllowedProjects(projects []string) {
a.allowed = make(map[string]struct{})
a.allowedAll = false
for _, project := range projects {
if strings.TrimSpace(project) == "*" {
a.allowedAll = true
return
}
normalized, _ := store.NormalizeProject(project)
normalized = strings.TrimSpace(normalized)
if normalized == "" {
Expand All @@ -190,14 +217,28 @@ func (a *ProjectScopeAuthorizer) SetAllowedProjects(projects []string) {
}

func (a *ProjectScopeAuthorizer) AuthorizeProject(project string) error {
if a.allowedAll {
normalized, _ := store.NormalizeProject(project)
normalized = strings.TrimSpace(normalized)
if normalized == "" {
return fmt.Errorf("project is required")
}
return nil
}
return authorizeProjectAgainstAllowlist(project, a.allowed)
}

// EnrolledProjects returns the sorted list of projects this authorizer allows.
// Matches the cloudserver.EnrolledProjectsProvider contract so mutation pull
// can filter server-side by the caller's enrolled projects (REQ-202) rather
// than fail-closing to an empty result set.
//
// When the wildcard "*" is configured, nil is returned to signal "no project
// filter" (matching the ListMutationsSince nil-means-all contract).
func (a *ProjectScopeAuthorizer) EnrolledProjects() []string {
if a.allowedAll {
return nil
}
return sortedAllowlist(a.allowed)
}

Expand Down
61 changes: 61 additions & 0 deletions internal/cloud/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,67 @@ func TestAuthorizeBearerTokenConstantTimeComparison(t *testing.T) {
}
}

// TestAuthorizeProjectWildcard tests that a single "*" in the allowlist permits any project.
func TestAuthorizeProjectWildcard(t *testing.T) {
svc, err := NewService(&cloudstore.CloudStore{}, strings.Repeat("x", 32))
if err != nil {
t.Fatalf("new service: %v", err)
}

// "*" alone must allow any project.
svc.SetAllowedProjects([]string{"*"})
if err := svc.AuthorizeProject("any-project"); err != nil {
t.Fatalf("wildcard allowlist must permit any project, got %v", err)
}
if err := svc.AuthorizeProject("ANOTHER-ONE"); err != nil {
t.Fatalf("wildcard allowlist must permit uppercased project, got %v", err)
}
if err := svc.AuthorizeProject("team-foo"); err != nil {
t.Fatalf("wildcard allowlist must permit prefixed project, got %v", err)
}
}

// TestAuthorizeProjectWildcardMixedWithExact tests that "*" in a mixed list still allows all.
func TestAuthorizeProjectWildcardMixedWithExact(t *testing.T) {
svc, err := NewService(&cloudstore.CloudStore{}, strings.Repeat("x", 32))
if err != nil {
t.Fatalf("new service: %v", err)
}

svc.SetAllowedProjects([]string{"proj-a", "*"})
if err := svc.AuthorizeProject("anything-at-all"); err != nil {
t.Fatalf("wildcard in mixed list must still permit any project, got %v", err)
}
}

// TestProjectScopeAuthorizerWildcard tests that NewProjectScopeAuthorizer also respects "*".
func TestProjectScopeAuthorizerWildcard(t *testing.T) {
authorizer := NewProjectScopeAuthorizer([]string{"*"})
if err := authorizer.AuthorizeProject("any-project"); err != nil {
t.Fatalf("wildcard authorizer must permit any project, got %v", err)
}
if err := authorizer.AuthorizeProject("team-foo"); err != nil {
t.Fatalf("wildcard authorizer must permit team-prefixed project, got %v", err)
}
}

// TestAuthorizeProjectExactMatchStillWorksAfterWildcardChange verifies backward compatibility.
func TestAuthorizeProjectExactMatchStillWorksAfterWildcardChange(t *testing.T) {
svc, err := NewService(&cloudstore.CloudStore{}, strings.Repeat("x", 32))
if err != nil {
t.Fatalf("new service: %v", err)
}

// Exact allowlist: only listed projects pass.
svc.SetAllowedProjects([]string{"proj-a", "proj-b"})
if err := svc.AuthorizeProject("proj-a"); err != nil {
t.Fatalf("exact match must still be allowed, got %v", err)
}
if err := svc.AuthorizeProject("proj-c"); !errors.Is(err, ErrProjectNotAllowed) {
t.Fatalf("unlisted project must be rejected, got %v", err)
}
}

func TestDashboardSessionTokenSupportsAdditionalDashboardCredential(t *testing.T) {
svc, err := NewService(&cloudstore.CloudStore{}, strings.Repeat("x", 32))
if err != nil {
Expand Down
8 changes: 8 additions & 0 deletions internal/cloud/cloudstore/cloudstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
type CloudStore struct {
db *sql.DB
dashboardAllowedScopes map[string]struct{}
dashboardAllowedAll bool
dashboardReadModelMu sync.RWMutex
dashboardReadModel dashboardReadModel
dashboardReadModelOK bool
Expand Down Expand Up @@ -65,9 +66,16 @@ func (cs *CloudStore) SetDashboardAllowedProjects(projects []string) {
if cs == nil {
return
}
cs.dashboardAllowedAll = false
cs.dashboardAllowedScopes = make(map[string]struct{})
for _, project := range projects {
project = strings.TrimSpace(project)
if project == "*" {
cs.dashboardAllowedAll = true
cs.dashboardAllowedScopes = nil
cs.invalidateDashboardReadModel()
return
}
if project == "" {
continue
}
Expand Down
15 changes: 11 additions & 4 deletions internal/cloud/cloudstore/dashboard_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -734,9 +734,13 @@ func applyDashboardMutation(
}

func (m dashboardReadModel) scoped(allowed map[string]struct{}) dashboardReadModel {
// Empty map or wildcard sentinel "*" means no filtering.
if len(allowed) == 0 {
return m
}
if _, ok := allowed["*"]; ok {
return m
}
projects := make([]DashboardProjectRow, 0, len(m.projects))
projectDetails := make(map[string]DashboardProjectDetail)
totalChunks := 0
Expand Down Expand Up @@ -1047,6 +1051,9 @@ func (cs *CloudStore) normalizeDashboardProject(project string) (string, error)
if project == "" {
return "", fmt.Errorf("%w", ErrDashboardProjectInvalid)
}
if cs.dashboardAllowedAll {
return project, nil
}
if len(cs.dashboardAllowedScopes) > 0 {
if _, ok := cs.dashboardAllowedScopes[project]; !ok {
return "", fmt.Errorf("%w", ErrDashboardProjectForbidden)
Expand Down Expand Up @@ -1095,7 +1102,7 @@ func (cs *CloudStore) loadChunkRows(project string) ([]dashboardChunkRow, error)
project = strings.TrimSpace(project)
query := `SELECT chunk_id, project_name, created_by, created_at, payload FROM cloud_chunks`
args := []any{}
if project == "" && len(cs.dashboardAllowedScopes) > 0 {
if project == "" && !cs.dashboardAllowedAll && len(cs.dashboardAllowedScopes) > 0 {
allowed := make([]string, 0, len(cs.dashboardAllowedScopes))
for name := range cs.dashboardAllowedScopes {
allowed = append(allowed, name)
Expand All @@ -1105,7 +1112,7 @@ func (cs *CloudStore) loadChunkRows(project string) ([]dashboardChunkRow, error)
args = append(args, allowed)
}
if project != "" {
if len(cs.dashboardAllowedScopes) > 0 {
if !cs.dashboardAllowedAll && len(cs.dashboardAllowedScopes) > 0 {
if _, ok := cs.dashboardAllowedScopes[project]; !ok {
return []dashboardChunkRow{}, nil
}
Expand Down Expand Up @@ -1152,7 +1159,7 @@ func (cs *CloudStore) loadMutationRows(project string) ([]dashboardMutationRow,
project = strings.TrimSpace(project)
query := `SELECT seq, project, entity, entity_key, op, payload::text, occurred_at FROM cloud_mutations`
args := []any{}
if project == "" && len(cs.dashboardAllowedScopes) > 0 {
if project == "" && !cs.dashboardAllowedAll && len(cs.dashboardAllowedScopes) > 0 {
allowed := make([]string, 0, len(cs.dashboardAllowedScopes))
for name := range cs.dashboardAllowedScopes {
allowed = append(allowed, name)
Expand All @@ -1162,7 +1169,7 @@ func (cs *CloudStore) loadMutationRows(project string) ([]dashboardMutationRow,
args = append(args, allowed)
}
if project != "" {
if len(cs.dashboardAllowedScopes) > 0 {
if !cs.dashboardAllowedAll && len(cs.dashboardAllowedScopes) > 0 {
if _, ok := cs.dashboardAllowedScopes[project]; !ok {
return []dashboardMutationRow{}, nil
}
Expand Down
71 changes: 71 additions & 0 deletions internal/cloud/cloudstore/wildcard_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package cloudstore

import (
"errors"
"testing"
"time"
)

// TestScopedWildcardPassesAllProjects verifies that a wildcard allowlist does not
// filter the dashboard read model — all projects must survive scoped().
func TestScopedWildcardPassesAllProjects(t *testing.T) {
t1 := time.Date(2026, 4, 23, 10, 0, 0, 0, time.UTC)
chunks := []dashboardChunkRow{
{chunkID: "c1", project: "team-alpha", createdBy: "alice", createdAt: t1,
parsed: parseMustChunk(t, []byte(`{"sessions":[{"id":"s1","project":"team-alpha","started_at":"2026-04-23T08:00:00Z"}],"observations":[],"prompts":[]}`))},
{chunkID: "c2", project: "team-beta", createdBy: "bob", createdAt: t1,
parsed: parseMustChunk(t, []byte(`{"sessions":[{"id":"s2","project":"team-beta","started_at":"2026-04-23T09:00:00Z"}],"observations":[],"prompts":[]}`))},
{chunkID: "c3", project: "other-project", createdBy: "charlie", createdAt: t1,
parsed: parseMustChunk(t, []byte(`{"sessions":[{"id":"s3","project":"other-project","started_at":"2026-04-23T10:00:00Z"}],"observations":[],"prompts":[]}`))},
}

model, err := buildDashboardReadModel(chunks)
if err != nil {
t.Fatalf("buildDashboardReadModel: %v", err)
}

// Wildcard "*" map — represents the wildcard sentinel.
wildcard := map[string]struct{}{"*": {}}
scoped := model.scoped(wildcard)
if len(scoped.projects) != 3 {
t.Fatalf("wildcard allowlist must pass all 3 projects through scoped(), got %d: %v", len(scoped.projects), scoped.projects)
}
}

// TestScopedWithExactAllowlist ensures that scoped() with an explicit list
// (no wildcard) filters the dashboard correctly — backward compatibility guard.
func TestScopedWithExactAllowlist(t *testing.T) {
t1 := time.Date(2026, 4, 23, 10, 0, 0, 0, time.UTC)
chunks := []dashboardChunkRow{
{chunkID: "c1", project: "team-alpha", createdBy: "alice", createdAt: t1,
parsed: parseMustChunk(t, []byte(`{"sessions":[{"id":"s1","project":"team-alpha","started_at":"2026-04-23T08:00:00Z"}],"observations":[],"prompts":[]}`))},
{chunkID: "c2", project: "team-beta", createdBy: "bob", createdAt: t1,
parsed: parseMustChunk(t, []byte(`{"sessions":[{"id":"s2","project":"team-beta","started_at":"2026-04-23T09:00:00Z"}],"observations":[],"prompts":[]}`))},
}

model, err := buildDashboardReadModel(chunks)
if err != nil {
t.Fatalf("buildDashboardReadModel: %v", err)
}

// Explicit list: only "team-alpha" must survive scoped().
scoped := model.scoped(map[string]struct{}{"team-alpha": {}})
if len(scoped.projects) != 1 || scoped.projects[0].Project != "team-alpha" {
t.Fatalf("exact allowlist must keep only team-alpha, got %v", scoped.projects)
}
}

// TestNormalizeDashboardProjectWildcardAllowsAnyProject verifies that with wildcard
// set, any project passes normalizeDashboardProject (no ErrDashboardProjectForbidden).
func TestNormalizeDashboardProjectWildcardAllowsAnyProject(t *testing.T) {
cs := &CloudStore{}
cs.SetDashboardAllowedProjects([]string{"*"})

_, err := cs.normalizeDashboardProject("any-project")
if err != nil && !errors.Is(err, ErrDashboardProjectNotFound) {
// ErrDashboardProjectNotFound is fine (no DB) — ErrDashboardProjectForbidden is not.
if errors.Is(err, ErrDashboardProjectForbidden) {
t.Fatalf("wildcard allowlist must not forbid any project, got %v", err)
}
}
}
Loading