@@ -2,9 +2,27 @@ package service
22
33import (
44 "context"
5+ "errors"
6+ "fmt"
7+ "sync"
58
9+ "github.com/samber/lo"
10+ "golang.org/x/sync/semaphore"
11+
12+ "github.com/openmeterio/openmeter/openmeter/billing"
13+ "github.com/openmeterio/openmeter/openmeter/billing/charges/meta"
614 "github.com/openmeterio/openmeter/openmeter/billing/charges/usagebased"
15+ usagebasedrating "github.com/openmeterio/openmeter/openmeter/billing/charges/usagebased/service/rating"
16+ "github.com/openmeterio/openmeter/openmeter/customer"
17+ "github.com/openmeterio/openmeter/pkg/clock"
718 "github.com/openmeterio/openmeter/pkg/framework/transaction"
19+ "github.com/openmeterio/openmeter/pkg/ref"
20+ "github.com/openmeterio/openmeter/pkg/slicesx"
21+ )
22+
23+ const (
24+ // defaultMaxParallelRatingsPerRequest is the number of workers to use for the rating (fetching from CH).
25+ defaultMaxParallelRatingsPerRequest = 5
826)
927
1028func (s * service ) GetByIDs (ctx context.Context , input usagebased.GetByIDsInput ) ([]usagebased.Charge , error ) {
@@ -13,7 +31,19 @@ func (s *service) GetByIDs(ctx context.Context, input usagebased.GetByIDsInput)
1331 }
1432
1533 return transaction .Run (ctx , s .adapter , func (ctx context.Context ) ([]usagebased.Charge , error ) {
16- return s .adapter .GetByIDs (ctx , input )
34+ charges , err := s .adapter .GetByIDs (ctx , input )
35+ if err != nil {
36+ return nil , err
37+ }
38+
39+ if input .Expands .Has (meta .ExpandRealtimeUsage ) {
40+ charges , err = s .expandChargesUsage (ctx , input .Namespace , charges )
41+ if err != nil {
42+ return nil , err
43+ }
44+ }
45+
46+ return charges , nil
1747 })
1848}
1949
@@ -23,6 +53,138 @@ func (s *service) GetByID(ctx context.Context, input usagebased.GetByIDInput) (u
2353 }
2454
2555 return transaction .Run (ctx , s .adapter , func (ctx context.Context ) (usagebased.Charge , error ) {
26- return s .adapter .GetByID (ctx , input )
56+ charge , err := s .adapter .GetByID (ctx , input )
57+ if err != nil {
58+ return usagebased.Charge {}, err
59+ }
60+
61+ if input .Expands .Has (meta .ExpandRealtimeUsage ) {
62+ totals , err := s .GetCurrentTotals (ctx , usagebased.GetCurrentTotalsInput {
63+ ChargeID : charge .GetChargeID (),
64+ })
65+ if err != nil {
66+ return usagebased.Charge {}, err
67+ }
68+
69+ charge .Expands .RealtimeUsage = & totals .DueTotals
70+ }
71+
72+ return charge , nil
73+ })
74+ }
75+
76+ func (s * service ) expandChargesUsage (ctx context.Context , namespace string , charges usagebased.Charges ) (usagebased.Charges , error ) {
77+ // Fetch unique customers from the charges to avoid duplicate calls to the customer override service.
78+ uniqueCustomers := lo .Uniq (lo .Map (charges , func (charge usagebased.Charge , _ int ) customer.CustomerID {
79+ return charge .GetCustomerID ()
80+ }))
81+
82+ customerOverridesById := make (map [customer.CustomerID ]billing.CustomerOverrideWithDetails )
83+ for _ , customerID := range uniqueCustomers {
84+ customerOverride , err := s .customerOverrideService .GetCustomerOverride (ctx , billing.GetCustomerOverrideInput {
85+ Customer : customerID ,
86+ Expand : billing.CustomerOverrideExpand {
87+ Customer : true ,
88+ },
89+ })
90+ if err != nil {
91+ return nil , err
92+ }
93+ customerOverridesById [customerID ] = customerOverride
94+ }
95+
96+ // Fetch all references featureMeters in bulk
97+ referencedFeatureMeters := lo .Uniq (lo .Map (charges , func (charge usagebased.Charge , _ int ) ref.IDOrKey {
98+ return charge .GetFeatureKeyOrID ()
99+ }))
100+
101+ featureMeters , err := s .featureService .ResolveFeatureMeters (ctx , namespace , referencedFeatureMeters ... )
102+ if err != nil {
103+ return nil , err
104+ }
105+
106+ // Let's do the rating for each charge
107+ sem := semaphore .NewWeighted (int64 (defaultMaxParallelRatingsPerRequest ))
108+ storedAt := clock .Now ()
109+
110+ errCh := make (chan error , len (charges ))
111+ ratingResults := sync.Map {}
112+
113+ var wg sync.WaitGroup
114+
115+ for _ , charge := range charges {
116+ featureMeter , err := charge .ResolveFeatureMeter (featureMeters )
117+ if err != nil {
118+ errCh <- fmt .Errorf ("resolving feature meter: %w" , err )
119+ break
120+ }
121+
122+ err = sem .Acquire (ctx , 1 )
123+ if err != nil {
124+ // Clean up and stop the loop
125+ errCh <- fmt .Errorf ("acquiring worker slot: %w" , err )
126+ break
127+ }
128+
129+ wg .Go (func () {
130+ defer sem .Release (1 )
131+ var err error
132+ defer func () {
133+ if err != nil {
134+ errCh <- err
135+ }
136+ }()
137+
138+ defer func () {
139+ if r := recover (); r != nil {
140+ err = fmt .Errorf ("rating charge %s: %v" , charge .ID , r )
141+ }
142+ }()
143+
144+ var ratingResult usagebasedrating.GetRatingForUsageResult
145+ ratingResult , err = s .rater .GetRatingForUsage (ctx , usagebasedrating.GetRatingForUsageInput {
146+ Charge : charge ,
147+ Customer : customerOverridesById [charge .GetCustomerID ()],
148+ FeatureMeter : featureMeter ,
149+ StoredAtOffset : storedAt ,
150+ })
151+ if err != nil {
152+ err = fmt .Errorf ("rating charge %s: %w" , charge .ID , err )
153+ return
154+ }
155+
156+ ratingResults .Store (charge .GetChargeID (), ratingResult )
157+ })
158+ }
159+
160+ wg .Wait ()
161+
162+ close (errCh )
163+
164+ var errs []error
165+
166+ for err := range errCh {
167+ if err != nil {
168+ errs = append (errs , err )
169+ }
170+ }
171+
172+ if len (errs ) > 0 {
173+ return nil , errors .Join (errs ... )
174+ }
175+
176+ return slicesx .MapWithErr (charges , func (charge usagebased.Charge ) (usagebased.Charge , error ) {
177+ ratingResultAny , ok := ratingResults .Load (charge .GetChargeID ())
178+ if ! ok {
179+ return charge , fmt .Errorf ("rating result not found for charge %s" , charge .ID )
180+ }
181+
182+ ratingResult , ok := ratingResultAny .(usagebasedrating.GetRatingForUsageResult )
183+ if ! ok {
184+ return charge , fmt .Errorf ("rating result not found for charge %s" , charge .ID )
185+ }
186+
187+ charge .Expands .RealtimeUsage = & ratingResult .Totals
188+ return charge , nil
27189 })
28190}
0 commit comments