Skip to content

Commit 975fa38

Browse files
authored
Merge pull request #2740 from Permify/ufuk/fix-depth-decrement-error
fix: depth decrementing was passed by reference and all same depth ch…
2 parents 8411364 + 0911720 commit 975fa38

3 files changed

Lines changed: 308 additions & 3 deletions

File tree

internal/engines/check_test.go

Lines changed: 303 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2069,4 +2069,307 @@ var _ = Describe("check-engine", func() {
20692069
}
20702070
})
20712071
})
2072+
2073+
// DEPTH CHECK SAMPLE (3-level deep check)
2074+
depthCheckSchema := `
2075+
entity user {}
2076+
2077+
entity bottom {
2078+
relation member @user
2079+
permission check = member
2080+
}
2081+
2082+
entity middle {
2083+
relation parent @bottom
2084+
permission check = parent.check
2085+
}
2086+
2087+
entity top {
2088+
relation parent @middle
2089+
permission check = parent.check
2090+
}
2091+
`
2092+
2093+
Context("Depth Check Sample: Check", func() {
2094+
It("Depth Check Sample: Case 1 - Depth 3 should pass for 3-level deep check", func() {
2095+
db, err := factories.DatabaseFactory(
2096+
config.Database{
2097+
Engine: "memory",
2098+
},
2099+
)
2100+
2101+
Expect(err).ShouldNot(HaveOccurred())
2102+
2103+
conf, err := newSchema(depthCheckSchema)
2104+
Expect(err).ShouldNot(HaveOccurred())
2105+
2106+
schemaWriter := factories.SchemaWriterFactory(db)
2107+
err = schemaWriter.WriteSchema(context.Background(), conf)
2108+
2109+
Expect(err).ShouldNot(HaveOccurred())
2110+
2111+
type check struct {
2112+
entity string
2113+
subject string
2114+
depth int32
2115+
assertions map[string]base.CheckResult
2116+
}
2117+
2118+
relationships := []string{
2119+
"top:1#parent@middle:1#...",
2120+
"middle:1#parent@bottom:1#...",
2121+
"bottom:1#member@user:1",
2122+
2123+
"top:2#parent@middle:2#...",
2124+
"middle:2#parent@bottom:2#...",
2125+
"bottom:2#member@user:2",
2126+
2127+
"top:3#parent@middle:3#...",
2128+
"middle:3#parent@bottom:3#...",
2129+
"bottom:3#member@user:3",
2130+
2131+
"top:4#parent@middle:4#...",
2132+
"middle:4#parent@bottom:4#...",
2133+
"bottom:4#member@user:4",
2134+
2135+
"top:5#parent@middle:5#...",
2136+
"middle:5#parent@bottom:5#...",
2137+
"bottom:5#member@user:5",
2138+
2139+
"top:6#parent@middle:6#...",
2140+
"middle:6#parent@bottom:6#...",
2141+
"bottom:6#member@user:6",
2142+
2143+
"top:7#parent@middle:7#...",
2144+
"middle:7#parent@bottom:7#...",
2145+
"bottom:7#member@user:7",
2146+
2147+
"top:8#parent@middle:8#...",
2148+
"middle:8#parent@bottom:8#...",
2149+
"bottom:8#member@user:8",
2150+
2151+
"top:9#parent@middle:9#...",
2152+
"middle:9#parent@bottom:9#...",
2153+
"bottom:9#member@user:9",
2154+
2155+
"top:10#parent@middle:10#...",
2156+
"middle:10#parent@bottom:10#...",
2157+
"bottom:10#member@user:10",
2158+
}
2159+
2160+
var checks []check
2161+
for i := 1; i <= 10; i++ {
2162+
entity := fmt.Sprintf("top:%d", i)
2163+
subject := fmt.Sprintf("user:%d", i)
2164+
checks = append(checks, check{
2165+
entity: entity,
2166+
subject: subject,
2167+
depth: 3,
2168+
assertions: map[string]base.CheckResult{
2169+
"check": base.CheckResult_CHECK_RESULT_ALLOWED,
2170+
},
2171+
})
2172+
}
2173+
2174+
schemaReader := factories.SchemaReaderFactory(db)
2175+
dataReader := factories.DataReaderFactory(db)
2176+
dataWriter := factories.DataWriterFactory(db)
2177+
2178+
checkEngine := NewCheckEngine(schemaReader, dataReader)
2179+
2180+
invoker := invoke.NewDirectInvoker(
2181+
schemaReader,
2182+
dataReader,
2183+
checkEngine,
2184+
nil,
2185+
nil,
2186+
nil,
2187+
)
2188+
2189+
checkEngine.SetInvoker(invoker)
2190+
2191+
var tuples []*base.Tuple
2192+
2193+
for _, relationship := range relationships {
2194+
t, err := tuple.Tuple(relationship)
2195+
Expect(err).ShouldNot(HaveOccurred())
2196+
tuples = append(tuples, t)
2197+
}
2198+
2199+
_, err = dataWriter.Write(context.Background(), "t1", database.NewTupleCollection(tuples...), database.NewAttributeCollection())
2200+
Expect(err).ShouldNot(HaveOccurred())
2201+
2202+
for _, check := range checks {
2203+
entity, err := tuple.E(check.entity)
2204+
Expect(err).ShouldNot(HaveOccurred())
2205+
2206+
ear, err := tuple.EAR(check.subject)
2207+
Expect(err).ShouldNot(HaveOccurred())
2208+
2209+
subject := &base.Subject{
2210+
Type: ear.GetEntity().GetType(),
2211+
Id: ear.GetEntity().GetId(),
2212+
Relation: ear.GetRelation(),
2213+
}
2214+
2215+
for permission, res := range check.assertions {
2216+
response, err := invoker.Check(context.Background(), &base.PermissionCheckRequest{
2217+
TenantId: "t1",
2218+
Entity: entity,
2219+
Subject: subject,
2220+
Permission: permission,
2221+
Metadata: &base.PermissionCheckRequestMetadata{
2222+
SnapToken: token.NewNoopToken().Encode().String(),
2223+
SchemaVersion: "",
2224+
Depth: check.depth,
2225+
},
2226+
})
2227+
2228+
Expect(err).ShouldNot(HaveOccurred())
2229+
Expect(res).Should(Equal(response.GetCan()))
2230+
}
2231+
}
2232+
})
2233+
2234+
It("Depth Check Sample: Case 2 - Depth 2 should fail for 3-level deep check", func() {
2235+
db, err := factories.DatabaseFactory(
2236+
config.Database{
2237+
Engine: "memory",
2238+
},
2239+
)
2240+
2241+
Expect(err).ShouldNot(HaveOccurred())
2242+
2243+
conf, err := newSchema(depthCheckSchema)
2244+
Expect(err).ShouldNot(HaveOccurred())
2245+
2246+
schemaWriter := factories.SchemaWriterFactory(db)
2247+
err = schemaWriter.WriteSchema(context.Background(), conf)
2248+
2249+
Expect(err).ShouldNot(HaveOccurred())
2250+
2251+
type check struct {
2252+
entity string
2253+
subject string
2254+
depth int32
2255+
assertions map[string]base.CheckResult
2256+
}
2257+
2258+
relationships := []string{
2259+
"top:1#parent@middle:1#...",
2260+
"middle:1#parent@bottom:1#...",
2261+
"bottom:1#member@user:1",
2262+
2263+
"top:2#parent@middle:2#...",
2264+
"middle:2#parent@bottom:2#...",
2265+
"bottom:2#member@user:2",
2266+
2267+
"top:3#parent@middle:3#...",
2268+
"middle:3#parent@bottom:3#...",
2269+
"bottom:3#member@user:3",
2270+
2271+
"top:4#parent@middle:4#...",
2272+
"middle:4#parent@bottom:4#...",
2273+
"bottom:4#member@user:4",
2274+
2275+
"top:5#parent@middle:5#...",
2276+
"middle:5#parent@bottom:5#...",
2277+
"bottom:5#member@user:5",
2278+
2279+
"top:6#parent@middle:6#...",
2280+
"middle:6#parent@bottom:6#...",
2281+
"bottom:6#member@user:6",
2282+
2283+
"top:7#parent@middle:7#...",
2284+
"middle:7#parent@bottom:7#...",
2285+
"bottom:7#member@user:7",
2286+
2287+
"top:8#parent@middle:8#...",
2288+
"middle:8#parent@bottom:8#...",
2289+
"bottom:8#member@user:8",
2290+
2291+
"top:9#parent@middle:9#...",
2292+
"middle:9#parent@bottom:9#...",
2293+
"bottom:9#member@user:9",
2294+
2295+
"top:10#parent@middle:10#...",
2296+
"middle:10#parent@bottom:10#...",
2297+
"bottom:10#member@user:10",
2298+
}
2299+
2300+
var checks []check
2301+
for i := 1; i <= 10; i++ {
2302+
entity := fmt.Sprintf("top:%d", i)
2303+
subject := fmt.Sprintf("user:%d", i)
2304+
checks = append(checks, check{
2305+
entity: entity,
2306+
subject: subject,
2307+
depth: 2,
2308+
assertions: map[string]base.CheckResult{
2309+
"check": base.CheckResult_CHECK_RESULT_DENIED,
2310+
},
2311+
})
2312+
}
2313+
2314+
schemaReader := factories.SchemaReaderFactory(db)
2315+
dataReader := factories.DataReaderFactory(db)
2316+
dataWriter := factories.DataWriterFactory(db)
2317+
2318+
checkEngine := NewCheckEngine(schemaReader, dataReader)
2319+
2320+
invoker := invoke.NewDirectInvoker(
2321+
schemaReader,
2322+
dataReader,
2323+
checkEngine,
2324+
nil,
2325+
nil,
2326+
nil,
2327+
)
2328+
2329+
checkEngine.SetInvoker(invoker)
2330+
2331+
var tuples []*base.Tuple
2332+
2333+
for _, relationship := range relationships {
2334+
t, err := tuple.Tuple(relationship)
2335+
Expect(err).ShouldNot(HaveOccurred())
2336+
tuples = append(tuples, t)
2337+
}
2338+
2339+
_, err = dataWriter.Write(context.Background(), "t1", database.NewTupleCollection(tuples...), database.NewAttributeCollection())
2340+
Expect(err).ShouldNot(HaveOccurred())
2341+
2342+
for _, check := range checks {
2343+
entity, err := tuple.E(check.entity)
2344+
Expect(err).ShouldNot(HaveOccurred())
2345+
2346+
ear, err := tuple.EAR(check.subject)
2347+
Expect(err).ShouldNot(HaveOccurred())
2348+
2349+
subject := &base.Subject{
2350+
Type: ear.GetEntity().GetType(),
2351+
Id: ear.GetEntity().GetId(),
2352+
Relation: ear.GetRelation(),
2353+
}
2354+
2355+
for permission := range check.assertions {
2356+
response, err := invoker.Check(context.Background(), &base.PermissionCheckRequest{
2357+
TenantId: "t1",
2358+
Entity: entity,
2359+
Subject: subject,
2360+
Permission: permission,
2361+
Metadata: &base.PermissionCheckRequestMetadata{
2362+
SnapToken: token.NewNoopToken().Encode().String(),
2363+
SchemaVersion: "",
2364+
Depth: check.depth,
2365+
},
2366+
})
2367+
2368+
Expect(err).Should(HaveOccurred())
2369+
Expect(err.Error()).Should(ContainSubstring("ERROR_CODE_DEPTH_NOT_ENOUGH"))
2370+
Expect(response.GetCan()).Should(Equal(base.CheckResult_CHECK_RESULT_DENIED))
2371+
}
2372+
}
2373+
})
2374+
})
20722375
})

internal/invoke/invoke.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,10 +166,12 @@ func (invoker *DirectInvoker) Check(ctx context.Context, request *base.Permissio
166166
}
167167
}
168168

169-
atomic.AddInt32(&request.GetMetadata().Depth, -1)
169+
// Create a copy of the request to safely decrement depth without mutating the original.
170+
nextRequest := request.CloneVT()
171+
nextRequest.Metadata.Depth = request.GetMetadata().Depth - 1
170172

171173
// Perform the actual permission check using the provided request.
172-
response, err = invoker.cc.Check(ctx, request)
174+
response, err = invoker.cc.Check(ctx, nextRequest)
173175
if err != nil {
174176
span.RecordError(err)
175177
span.SetStatus(otelCodes.Error, err.Error())

internal/invoke/utils.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import (
99

1010
// checkDepth validates that the request has sufficient depth for permission checks
1111
func checkDepth(request *base.PermissionCheckRequest) error {
12-
if atomic.LoadInt32(&request.GetMetadata().Depth) <= 0 { // Check depth is positive
12+
if atomic.LoadInt32(&request.GetMetadata().Depth) < 0 { // Check depth is not negative
1313
return errors.New(base.ErrorCode_ERROR_CODE_DEPTH_NOT_ENOUGH.String())
1414
}
1515
return nil

0 commit comments

Comments
 (0)