Skip to content

Commit 91ea6e1

Browse files
authored
chore: improve entitlement lookup performance (#4149)
1 parent 69baee2 commit 91ea6e1

4 files changed

Lines changed: 237 additions & 20 deletions

File tree

openmeter/entitlement/adapter/entitlement.go

Lines changed: 47 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"time"
88

99
"entgo.io/ent/dialect/sql"
10+
"github.com/oklog/ulid/v2"
1011
"github.com/samber/lo"
1112

1213
customeradapter "github.com/openmeterio/openmeter/openmeter/customer/adapter"
@@ -86,10 +87,9 @@ func (a *entitlementDBAdapter) GetActiveEntitlementOfCustomerAt(ctx context.Cont
8687
Where(EntitlementActiveAt(at)...).
8788
Where(
8889
db_entitlement.Or(db_entitlement.DeletedAtGT(at), db_entitlement.DeletedAtIsNil()),
90+
db_entitlement.CustomerID(customerID),
8991
db_entitlement.HasCustomerWith(
90-
customerdb.Namespace(namespace),
9192
customerNotDeletedAt(at),
92-
customerdb.ID(customerID),
9393
),
9494
db_entitlement.Namespace(namespace),
9595
db_entitlement.FeatureKey(featureKey),
@@ -313,6 +313,9 @@ func (a *entitlementDBAdapter) ListEntitlements(ctx context.Context, params enti
313313
a,
314314
func(ctx context.Context, repo *entitlementDBAdapter) (pagination.Result[entitlement.Entitlement], error) {
315315
now := clock.Now().UTC()
316+
response := pagination.Result[entitlement.Entitlement]{
317+
Page: params.Page,
318+
}
316319

317320
query := repo.db.Entitlement.Query()
318321

@@ -335,15 +338,29 @@ func (a *entitlementDBAdapter) ListEntitlements(ctx context.Context, params enti
335338
}
336339

337340
if len(params.CustomerKeys) > 0 {
338-
query = query.Where(db_entitlement.HasCustomerWith(
339-
customerdb.KeyIn(params.CustomerKeys...),
340-
))
341+
customerQuery := repo.db.Customer.Query().
342+
Where(customerdb.KeyIn(params.CustomerKeys...))
343+
344+
if len(params.Namespaces) > 0 {
345+
customerQuery = customerQuery.Where(customerdb.NamespaceIn(params.Namespaces...))
346+
}
347+
348+
customerIDs, err := customerQuery.IDs(ctx)
349+
if err != nil {
350+
return response, err
351+
}
352+
353+
if len(customerIDs) == 0 {
354+
response.Items = []entitlement.Entitlement{}
355+
response.TotalCount = 0
356+
return response, nil
357+
}
358+
359+
query = query.Where(db_entitlement.CustomerIDIn(customerIDs...))
341360
}
342361

343362
if len(params.CustomerIDs) > 0 {
344-
query = query.Where(db_entitlement.HasCustomerWith(
345-
customerdb.IDIn(params.CustomerIDs...),
346-
))
363+
query = query.Where(db_entitlement.CustomerIDIn(params.CustomerIDs...))
347364
}
348365

349366
if len(params.EntitlementTypes) > 0 {
@@ -357,16 +374,30 @@ func (a *entitlementDBAdapter) ListEntitlements(ctx context.Context, params enti
357374
}
358375

359376
if len(params.FeatureIDsOrKeys) > 0 {
360-
var ep predicate.Entitlement
361-
for i, idOrKey := range params.FeatureIDsOrKeys {
362-
p := db_entitlement.Or(db_entitlement.FeatureID(idOrKey), db_entitlement.FeatureKey(idOrKey))
363-
if i == 0 {
364-
ep = p
365-
continue
377+
featureIDs := make([]string, 0, len(params.FeatureIDsOrKeys))
378+
featureKeys := make([]string, 0, len(params.FeatureIDsOrKeys))
379+
380+
for _, idOrKey := range params.FeatureIDsOrKeys {
381+
if _, err := ulid.Parse(idOrKey); err == nil {
382+
featureIDs = append(featureIDs, idOrKey)
383+
} else {
384+
featureKeys = append(featureKeys, idOrKey)
366385
}
367-
ep = db_entitlement.Or(ep, p)
368386
}
369-
query = query.Where(ep)
387+
388+
switch {
389+
case len(featureIDs) > 0 && len(featureKeys) > 0:
390+
query = query.Where(
391+
db_entitlement.Or(
392+
db_entitlement.FeatureIDIn(featureIDs...),
393+
db_entitlement.FeatureKeyIn(featureKeys...),
394+
),
395+
)
396+
case len(featureIDs) > 0:
397+
query = query.Where(db_entitlement.FeatureIDIn(featureIDs...))
398+
case len(featureKeys) > 0:
399+
query = query.Where(db_entitlement.FeatureKeyIn(featureKeys...))
400+
}
370401
}
371402

372403
if len(params.FeatureIDs) > 0 {
@@ -406,10 +437,6 @@ func (a *entitlementDBAdapter) ListEntitlements(ctx context.Context, params enti
406437
}
407438
}
408439

409-
response := pagination.Result[entitlement.Entitlement]{
410-
Page: params.Page,
411-
}
412-
413440
// we're using limit and offset
414441
if params.Page.IsZero() {
415442
if params.Limit > 0 {

openmeter/entitlement/adapter/entitlement_test.go

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,33 @@ func createCustomerWithSubject(t *testing.T, subjectRepo subject.Service, custom
117117
return cust
118118
}
119119

120+
func createCustomerWithSubjectAndKey(t *testing.T, subjectRepo subject.Service, customerRepo customer.Adapter, namespace string, subjectKey string, customerKey string) *customer.Customer {
121+
t.Helper()
122+
123+
ctx, cancel := context.WithCancel(context.Background())
124+
defer cancel()
125+
126+
_, err := subjectRepo.Create(ctx, subject.CreateInput{
127+
Namespace: namespace,
128+
Key: subjectKey,
129+
})
130+
require.NoError(t, err)
131+
132+
cust, err := customerRepo.CreateCustomer(ctx, customer.CreateCustomerInput{
133+
Namespace: namespace,
134+
CustomerMutate: customer.CustomerMutate{
135+
Key: lo.ToPtr(customerKey),
136+
Name: "Customer 1",
137+
UsageAttribution: &customer.CustomerUsageAttribution{
138+
SubjectKeys: []string{subjectKey},
139+
},
140+
},
141+
})
142+
require.NoError(t, err)
143+
144+
return cust
145+
}
146+
120147
func TestUpsertEntitlementCurrentPeriods(t *testing.T) {
121148
ns := "ns1"
122149
featureKey := "feature1"
@@ -518,3 +545,78 @@ func TestEntitlementLoadsSubjectAndCustomerAndPreservesAcrossTypedMapping(t *tes
518545
}
519546
}
520547
}
548+
549+
func TestListEntitlementsFiltersByCustomerKeysAndFeatureIDsOrKeys(t *testing.T) {
550+
ctx := context.Background()
551+
ns := "ns-list-filters"
552+
553+
repo, cleanup := setup(t)
554+
defer cleanup()
555+
556+
featureByKey, err := repo.featureRepo.CreateFeature(ctx, feature.CreateFeatureInputs{
557+
Namespace: ns,
558+
Key: "free_plan_usage",
559+
Name: "Free plan usage",
560+
})
561+
require.NoError(t, err)
562+
563+
featureByID, err := repo.featureRepo.CreateFeature(ctx, feature.CreateFeatureInputs{
564+
Namespace: ns,
565+
Key: "pro_plan_usage",
566+
Name: "Pro plan usage",
567+
})
568+
require.NoError(t, err)
569+
570+
customerA := createCustomerWithSubjectAndKey(t, repo.subjectRepo, repo.customerRepo, ns, "subject-a", "customer-a")
571+
customerB := createCustomerWithSubjectAndKey(t, repo.subjectRepo, repo.customerRepo, ns, "subject-b", "customer-b")
572+
573+
entA, err := repo.entRepo.CreateEntitlement(ctx, entitlement.CreateEntitlementRepoInputs{
574+
Namespace: ns,
575+
FeatureID: featureByKey.ID,
576+
FeatureKey: featureByKey.Key,
577+
UsageAttribution: customerA.GetUsageAttribution(),
578+
EntitlementType: entitlement.EntitlementTypeBoolean,
579+
})
580+
require.NoError(t, err)
581+
582+
entB, err := repo.entRepo.CreateEntitlement(ctx, entitlement.CreateEntitlementRepoInputs{
583+
Namespace: ns,
584+
FeatureID: featureByID.ID,
585+
FeatureKey: featureByID.Key,
586+
UsageAttribution: customerB.GetUsageAttribution(),
587+
EntitlementType: entitlement.EntitlementTypeBoolean,
588+
})
589+
require.NoError(t, err)
590+
591+
t.Run("Should filter by customer key and feature key", func(t *testing.T) {
592+
res, err := repo.entRepo.ListEntitlements(ctx, entitlement.ListEntitlementsParams{
593+
Namespaces: []string{ns},
594+
CustomerKeys: []string{"customer-a"},
595+
FeatureIDsOrKeys: []string{featureByKey.Key},
596+
})
597+
require.NoError(t, err)
598+
require.Len(t, res.Items, 1)
599+
require.Equal(t, entA.ID, res.Items[0].ID)
600+
})
601+
602+
t.Run("Should filter by customer key and feature ID", func(t *testing.T) {
603+
res, err := repo.entRepo.ListEntitlements(ctx, entitlement.ListEntitlementsParams{
604+
Namespaces: []string{ns},
605+
CustomerKeys: []string{"customer-b"},
606+
FeatureIDsOrKeys: []string{featureByID.ID},
607+
})
608+
require.NoError(t, err)
609+
require.Len(t, res.Items, 1)
610+
require.Equal(t, entB.ID, res.Items[0].ID)
611+
})
612+
613+
t.Run("Should return empty result without querying entitlements when customer key does not exist", func(t *testing.T) {
614+
res, err := repo.entRepo.ListEntitlements(ctx, entitlement.ListEntitlementsParams{
615+
Namespaces: []string{ns},
616+
CustomerKeys: []string{"missing-customer"},
617+
})
618+
require.NoError(t, err)
619+
require.Empty(t, res.Items)
620+
require.Zero(t, res.TotalCount)
621+
})
622+
}

openmeter/entitlement/service/service.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"sync"
77
"time"
88

9+
"github.com/oklog/ulid/v2"
910
"github.com/samber/lo"
1011
"golang.org/x/sync/errgroup"
1112
"golang.org/x/sync/semaphore"
@@ -226,10 +227,17 @@ func (c *service) GetEntitlementsOfCustomer(ctx context.Context, namespace strin
226227
}
227228

228229
func (c *service) GetEntitlementOfCustomerAt(ctx context.Context, namespace string, customerID string, idOrFeatureKey string, at time.Time) (*entitlement.Entitlement, error) {
230+
// Feature keys are forbidden from being valid ULIDs, so non-ULID inputs can
231+
// skip the guaranteed-miss entitlement-ID lookup.
232+
if _, err := ulid.Parse(idOrFeatureKey); err != nil {
233+
return c.entitlementRepo.GetActiveEntitlementOfCustomerAt(ctx, namespace, customerID, idOrFeatureKey, at)
234+
}
235+
229236
ent, err := c.entitlementRepo.GetEntitlement(ctx, models.NamespacedID{Namespace: namespace, ID: idOrFeatureKey})
230237
if _, ok := lo.ErrorsAs[*entitlement.NotFoundError](err); ok {
231238
ent, err = c.entitlementRepo.GetActiveEntitlementOfCustomerAt(ctx, namespace, customerID, idOrFeatureKey, at)
232239
}
240+
233241
return ent, err
234242
}
235243

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
package service_test
2+
3+
import (
4+
"testing"
5+
"time"
6+
7+
"github.com/samber/lo"
8+
"github.com/stretchr/testify/require"
9+
10+
"github.com/openmeterio/openmeter/openmeter/entitlement"
11+
"github.com/openmeterio/openmeter/openmeter/meter"
12+
"github.com/openmeterio/openmeter/openmeter/productcatalog/feature"
13+
"github.com/openmeterio/openmeter/openmeter/testutils"
14+
"github.com/openmeterio/openmeter/pkg/clock"
15+
)
16+
17+
func TestGetEntitlementOfCustomerAt(t *testing.T) {
18+
conn, deps := setupDependecies(t)
19+
defer deps.Teardown()
20+
21+
namespace := "ns-get-entitlement-of-customer-at"
22+
now := testutils.GetRFC3339Time(t, "2025-01-01T00:00:00Z")
23+
24+
clock.SetTime(now)
25+
defer clock.ResetTime()
26+
27+
mtr, err := deps.meterService.CreateMeter(t.Context(), meter.CreateMeterInput{
28+
Namespace: namespace,
29+
Name: "Meter 1",
30+
Key: "meter1",
31+
Description: nil,
32+
Aggregation: meter.MeterAggregationSum,
33+
EventType: "test",
34+
EventFrom: nil,
35+
ValueProperty: lo.ToPtr("$.value"),
36+
GroupBy: nil,
37+
})
38+
require.NoError(t, err)
39+
require.NotNil(t, mtr)
40+
createMeterInPG(t, deps.dbClient, mtr)
41+
42+
cust := createCustomerAndSubject(t, deps.subjectService, deps.customerService, namespace, "cust-1", "Customer 1")
43+
44+
feat, err := deps.featureRepo.CreateFeature(t.Context(), feature.CreateFeatureInputs{
45+
Key: "free_plan_usage",
46+
Name: "Free plan usage",
47+
Namespace: namespace,
48+
MeterID: &mtr.ID,
49+
})
50+
require.NoError(t, err)
51+
require.NotNil(t, feat)
52+
53+
ent, err := conn.CreateEntitlement(t.Context(), entitlement.CreateEntitlementInputs{
54+
Namespace: namespace,
55+
UsageAttribution: cust.GetUsageAttribution(),
56+
FeatureKey: &feat.Key,
57+
FeatureID: &feat.ID,
58+
EntitlementType: entitlement.EntitlementTypeBoolean,
59+
}, nil)
60+
require.NoError(t, err)
61+
require.NotNil(t, ent)
62+
63+
t.Run("Should resolve entitlement by feature key", func(t *testing.T) {
64+
res, err := conn.GetEntitlementOfCustomerAt(t.Context(), namespace, cust.ID, feat.Key, clock.Now().Add(time.Hour))
65+
require.NoError(t, err)
66+
require.NotNil(t, res)
67+
require.Equal(t, ent.ID, res.ID)
68+
require.Equal(t, feat.Key, res.FeatureKey)
69+
require.Equal(t, cust.ID, res.CustomerID)
70+
})
71+
72+
t.Run("Should resolve entitlement by entitlement ID", func(t *testing.T) {
73+
res, err := conn.GetEntitlementOfCustomerAt(t.Context(), namespace, cust.ID, ent.ID, clock.Now().Add(time.Hour))
74+
require.NoError(t, err)
75+
require.NotNil(t, res)
76+
require.Equal(t, ent.ID, res.ID)
77+
require.Equal(t, feat.Key, res.FeatureKey)
78+
require.Equal(t, cust.ID, res.CustomerID)
79+
})
80+
}

0 commit comments

Comments
 (0)