Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
303 changes: 303 additions & 0 deletions internal/engines/check_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2069,4 +2069,307 @@ var _ = Describe("check-engine", func() {
}
})
})

// DEPTH CHECK SAMPLE (3-level deep check)
depthCheckSchema := `
entity user {}

entity bottom {
relation member @user
permission check = member
}

entity middle {
relation parent @bottom
permission check = parent.check
}

entity top {
relation parent @middle
permission check = parent.check
}
`

Context("Depth Check Sample: Check", func() {
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did these 2 tests fail, once written, before the fixes were made in invoke.go and utils.go?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

before the changes, the case that should pass was failing. i added these tests because this scenario specifically requires a depth of 3. i included both failing and passing cases to verify the behavior.

It("Depth Check Sample: Case 1 - Depth 3 should pass for 3-level deep check", func() {
db, err := factories.DatabaseFactory(
config.Database{
Engine: "memory",
},
)

Expect(err).ShouldNot(HaveOccurred())
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So we're using the "strict" mocking here? Where we don't just stub things, but expect things to be called in a certain order, etc. ?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes


conf, err := newSchema(depthCheckSchema)
Expect(err).ShouldNot(HaveOccurred())

schemaWriter := factories.SchemaWriterFactory(db)
err = schemaWriter.WriteSchema(context.Background(), conf)

Expect(err).ShouldNot(HaveOccurred())

type check struct {
entity string
subject string
depth int32
assertions map[string]base.CheckResult
}

relationships := []string{
"top:1#parent@middle:1#...",
"middle:1#parent@bottom:1#...",
"bottom:1#member@user:1",

"top:2#parent@middle:2#...",
"middle:2#parent@bottom:2#...",
"bottom:2#member@user:2",

"top:3#parent@middle:3#...",
"middle:3#parent@bottom:3#...",
"bottom:3#member@user:3",

"top:4#parent@middle:4#...",
"middle:4#parent@bottom:4#...",
"bottom:4#member@user:4",

"top:5#parent@middle:5#...",
"middle:5#parent@bottom:5#...",
"bottom:5#member@user:5",

"top:6#parent@middle:6#...",
"middle:6#parent@bottom:6#...",
"bottom:6#member@user:6",

"top:7#parent@middle:7#...",
"middle:7#parent@bottom:7#...",
"bottom:7#member@user:7",

"top:8#parent@middle:8#...",
"middle:8#parent@bottom:8#...",
"bottom:8#member@user:8",

"top:9#parent@middle:9#...",
"middle:9#parent@bottom:9#...",
"bottom:9#member@user:9",

"top:10#parent@middle:10#...",
"middle:10#parent@bottom:10#...",
"bottom:10#member@user:10",
}

var checks []check
for i := 1; i <= 10; i++ {
entity := fmt.Sprintf("top:%d", i)
subject := fmt.Sprintf("user:%d", i)
checks = append(checks, check{
entity: entity,
subject: subject,
depth: 3,
assertions: map[string]base.CheckResult{
"check": base.CheckResult_CHECK_RESULT_ALLOWED,
},
})
}

schemaReader := factories.SchemaReaderFactory(db)
dataReader := factories.DataReaderFactory(db)
dataWriter := factories.DataWriterFactory(db)

checkEngine := NewCheckEngine(schemaReader, dataReader)

invoker := invoke.NewDirectInvoker(
schemaReader,
dataReader,
checkEngine,
nil,
nil,
nil,
)

checkEngine.SetInvoker(invoker)

var tuples []*base.Tuple

for _, relationship := range relationships {
t, err := tuple.Tuple(relationship)
Expect(err).ShouldNot(HaveOccurred())
tuples = append(tuples, t)
}

_, err = dataWriter.Write(context.Background(), "t1", database.NewTupleCollection(tuples...), database.NewAttributeCollection())
Expect(err).ShouldNot(HaveOccurred())

for _, check := range checks {
entity, err := tuple.E(check.entity)
Expect(err).ShouldNot(HaveOccurred())

ear, err := tuple.EAR(check.subject)
Expect(err).ShouldNot(HaveOccurred())

subject := &base.Subject{
Type: ear.GetEntity().GetType(),
Id: ear.GetEntity().GetId(),
Relation: ear.GetRelation(),
}

for permission, res := range check.assertions {
response, err := invoker.Check(context.Background(), &base.PermissionCheckRequest{
TenantId: "t1",
Entity: entity,
Subject: subject,
Permission: permission,
Metadata: &base.PermissionCheckRequestMetadata{
SnapToken: token.NewNoopToken().Encode().String(),
SchemaVersion: "",
Depth: check.depth,
},
})

Expect(err).ShouldNot(HaveOccurred())
Expect(res).Should(Equal(response.GetCan()))
}
}
})

It("Depth Check Sample: Case 2 - Depth 2 should fail for 3-level deep check", func() {
db, err := factories.DatabaseFactory(
config.Database{
Engine: "memory",
},
)

Expect(err).ShouldNot(HaveOccurred())

conf, err := newSchema(depthCheckSchema)
Expect(err).ShouldNot(HaveOccurred())

schemaWriter := factories.SchemaWriterFactory(db)
err = schemaWriter.WriteSchema(context.Background(), conf)

Expect(err).ShouldNot(HaveOccurred())

type check struct {
entity string
subject string
depth int32
assertions map[string]base.CheckResult
}

relationships := []string{
"top:1#parent@middle:1#...",
"middle:1#parent@bottom:1#...",
"bottom:1#member@user:1",

"top:2#parent@middle:2#...",
"middle:2#parent@bottom:2#...",
"bottom:2#member@user:2",

"top:3#parent@middle:3#...",
"middle:3#parent@bottom:3#...",
"bottom:3#member@user:3",

"top:4#parent@middle:4#...",
"middle:4#parent@bottom:4#...",
"bottom:4#member@user:4",

"top:5#parent@middle:5#...",
"middle:5#parent@bottom:5#...",
"bottom:5#member@user:5",

"top:6#parent@middle:6#...",
"middle:6#parent@bottom:6#...",
"bottom:6#member@user:6",

"top:7#parent@middle:7#...",
"middle:7#parent@bottom:7#...",
"bottom:7#member@user:7",

"top:8#parent@middle:8#...",
"middle:8#parent@bottom:8#...",
"bottom:8#member@user:8",

"top:9#parent@middle:9#...",
"middle:9#parent@bottom:9#...",
"bottom:9#member@user:9",

"top:10#parent@middle:10#...",
"middle:10#parent@bottom:10#...",
"bottom:10#member@user:10",
}

var checks []check
for i := 1; i <= 10; i++ {
entity := fmt.Sprintf("top:%d", i)
subject := fmt.Sprintf("user:%d", i)
checks = append(checks, check{
entity: entity,
subject: subject,
depth: 2,
assertions: map[string]base.CheckResult{
"check": base.CheckResult_CHECK_RESULT_DENIED,
},
})
}

schemaReader := factories.SchemaReaderFactory(db)
dataReader := factories.DataReaderFactory(db)
dataWriter := factories.DataWriterFactory(db)

checkEngine := NewCheckEngine(schemaReader, dataReader)

invoker := invoke.NewDirectInvoker(
schemaReader,
dataReader,
checkEngine,
nil,
nil,
nil,
)

checkEngine.SetInvoker(invoker)

var tuples []*base.Tuple

for _, relationship := range relationships {
t, err := tuple.Tuple(relationship)
Expect(err).ShouldNot(HaveOccurred())
tuples = append(tuples, t)
}

_, err = dataWriter.Write(context.Background(), "t1", database.NewTupleCollection(tuples...), database.NewAttributeCollection())
Expect(err).ShouldNot(HaveOccurred())

for _, check := range checks {
entity, err := tuple.E(check.entity)
Expect(err).ShouldNot(HaveOccurred())

ear, err := tuple.EAR(check.subject)
Expect(err).ShouldNot(HaveOccurred())

subject := &base.Subject{
Type: ear.GetEntity().GetType(),
Id: ear.GetEntity().GetId(),
Relation: ear.GetRelation(),
}

for permission := range check.assertions {
response, err := invoker.Check(context.Background(), &base.PermissionCheckRequest{
TenantId: "t1",
Entity: entity,
Subject: subject,
Permission: permission,
Metadata: &base.PermissionCheckRequestMetadata{
SnapToken: token.NewNoopToken().Encode().String(),
SchemaVersion: "",
Depth: check.depth,
},
})

Expect(err).Should(HaveOccurred())
Expect(err.Error()).Should(ContainSubstring("ERROR_CODE_DEPTH_NOT_ENOUGH"))
Expect(response.GetCan()).Should(Equal(base.CheckResult_CHECK_RESULT_DENIED))
}
}
})
})
})
6 changes: 4 additions & 2 deletions internal/invoke/invoke.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,12 @@ func (invoker *DirectInvoker) Check(ctx context.Context, request *base.Permissio
}
}

atomic.AddInt32(&request.GetMetadata().Depth, -1)
// Create a copy of the request to safely decrement depth without mutating the original.
nextRequest := request.CloneVT()
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what happens if we do mutate the original?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the depth is decremented on every deeper call. previously if a lower level had multiple leaves, it ended up decreasing the depth multiple times, when it should have only decreased once

nextRequest.Metadata.Depth = request.GetMetadata().Depth - 1

// Perform the actual permission check using the provided request.
response, err = invoker.cc.Check(ctx, request)
response, err = invoker.cc.Check(ctx, nextRequest)
if err != nil {
span.RecordError(err)
span.SetStatus(otelCodes.Error, err.Error())
Expand Down
2 changes: 1 addition & 1 deletion internal/invoke/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (

// checkDepth validates that the request has sufficient depth for permission checks
func checkDepth(request *base.PermissionCheckRequest) error {
if atomic.LoadInt32(&request.GetMetadata().Depth) <= 0 { // Check depth is positive
if atomic.LoadInt32(&request.GetMetadata().Depth) < 0 { // Check depth is not negative
return errors.New(base.ErrorCode_ERROR_CODE_DEPTH_NOT_ENOUGH.String())
}
return nil
Expand Down