From a40fe3f2401fbd82a9cd9e8c45e6555f5c222abb Mon Sep 17 00:00:00 2001 From: Alex Goth <64845621+GAlexIHU@users.noreply.github.com> Date: Wed, 15 Apr 2026 16:49:00 +0200 Subject: [PATCH 1/3] feat(ledger): transaction listing API --- api/v3/handlers/customers/credits/handler.go | 11 + .../customers/credits/list_transactions.go | 347 ++++++++++++++++++ .../credits/list_transactions_test.go | 195 ++++++++++ api/v3/server/routes.go | 10 +- api/v3/server/server.go | 5 +- cmd/server/main.go | 2 + cmd/server/wire.go | 3 + cmd/server/wire_gen.go | 5 + openmeter/ledger/account/account.go | 3 +- openmeter/ledger/account/subaccount.go | 2 +- openmeter/ledger/annotations.go | 39 +- openmeter/ledger/chargeadapter/annotations.go | 71 ++++ .../ledger/chargeadapter/creditpurchase.go | 34 +- .../chargeadapter/creditpurchase_test.go | 59 +++ openmeter/ledger/chargeadapter/flatfee.go | 37 +- openmeter/ledger/chargeadapter/usagebased.go | 2 + .../ledger/chargeadapter/usagebased_test.go | 55 +++ openmeter/ledger/collector/collect.go | 19 +- openmeter/ledger/collector/correct.go | 19 +- openmeter/ledger/collector/service.go | 3 + openmeter/ledger/customerbalance/facade.go | 49 ++- .../ledger/customerbalance/facade_test.go | 50 +++ openmeter/ledger/customerbalance/noop.go | 2 +- openmeter/ledger/customerbalance/service.go | 6 +- .../ledger/customerbalance/service_test.go | 6 +- openmeter/ledger/historical/adapter/ledger.go | 170 +++++++++ .../ledger/historical/adapter/ledger_test.go | 236 +++++++++++- .../historical/adapter/sumentries_query.go | 19 + openmeter/ledger/historical/ledger.go | 16 + openmeter/ledger/historical/repo.go | 4 + openmeter/ledger/historical/transaction.go | 11 + openmeter/ledger/ledger_test.go | 2 +- openmeter/ledger/noop/noop.go | 9 +- openmeter/ledger/primitives.go | 58 ++- openmeter/ledger/query.go | 11 + .../ledger/transactions/correction_test.go | 8 + openmeter/server/router/router.go | 3 + openmeter/server/server.go | 2 + test/credits/sanity_test.go | 10 +- 39 files changed, 1534 insertions(+), 59 deletions(-) create mode 100644 api/v3/handlers/customers/credits/list_transactions.go create mode 100644 api/v3/handlers/customers/credits/list_transactions_test.go create mode 100644 openmeter/ledger/chargeadapter/annotations.go diff --git a/api/v3/handlers/customers/credits/handler.go b/api/v3/handlers/customers/credits/handler.go index 32f4107155..b478dbdb4e 100644 --- a/api/v3/handlers/customers/credits/handler.go +++ b/api/v3/handlers/customers/credits/handler.go @@ -3,13 +3,17 @@ package customerscredits import ( "context" + "github.com/alpacahq/alpacadecimal" + "github.com/openmeterio/openmeter/openmeter/billing/creditgrant" "github.com/openmeterio/openmeter/openmeter/customer" + "github.com/openmeterio/openmeter/openmeter/ledger" "github.com/openmeterio/openmeter/openmeter/ledger/customerbalance" "github.com/openmeterio/openmeter/pkg/framework/transport/httptransport" ) type customerBalanceFacade interface { + GetBalance(ctx context.Context, input customerbalance.GetBalanceInput) (alpacadecimal.Decimal, error) GetBalances(ctx context.Context, input customerbalance.GetBalancesInput) ([]customerbalance.BalanceByCurrency, error) } @@ -18,6 +22,7 @@ type Handler interface { ListCreditGrants() ListCreditGrantsHandler CreateCreditGrant() CreateCreditGrantHandler GetCreditGrant() GetCreditGrantHandler + ListCreditTransactions() ListCreditTransactionsHandler } type handler struct { @@ -25,6 +30,8 @@ type handler struct { customerService customer.Service balanceFacade customerBalanceFacade creditGrantService creditgrant.Service + ledger ledger.Ledger + accountResolver ledger.AccountResolver options []httptransport.HandlerOption } @@ -33,6 +40,8 @@ func New( customerService customer.Service, balanceFacade customerBalanceFacade, creditGrantService creditgrant.Service, + ledger ledger.Ledger, + accountResolver ledger.AccountResolver, options ...httptransport.HandlerOption, ) Handler { return &handler{ @@ -40,6 +49,8 @@ func New( customerService: customerService, balanceFacade: balanceFacade, creditGrantService: creditGrantService, + ledger: ledger, + accountResolver: accountResolver, options: options, } } diff --git a/api/v3/handlers/customers/credits/list_transactions.go b/api/v3/handlers/customers/credits/list_transactions.go new file mode 100644 index 0000000000..763c9d8f07 --- /dev/null +++ b/api/v3/handlers/customers/credits/list_transactions.go @@ -0,0 +1,347 @@ +package customerscredits + +import ( + "context" + "fmt" + "net/http" + + "github.com/alpacahq/alpacadecimal" + "github.com/samber/lo" + + api "github.com/openmeterio/openmeter/api/v3" + "github.com/openmeterio/openmeter/api/v3/apierrors" + "github.com/openmeterio/openmeter/api/v3/response" + "github.com/openmeterio/openmeter/openmeter/customer" + "github.com/openmeterio/openmeter/openmeter/ledger" + ledgeraccount "github.com/openmeterio/openmeter/openmeter/ledger/account" + "github.com/openmeterio/openmeter/openmeter/ledger/customerbalance" + "github.com/openmeterio/openmeter/pkg/currencyx" + "github.com/openmeterio/openmeter/pkg/framework/commonhttp" + "github.com/openmeterio/openmeter/pkg/framework/transport/httptransport" + "github.com/openmeterio/openmeter/pkg/pagination" +) + +type ( + ListCreditTransactionsRequest struct { + Namespace string + CustomerID string + Page pagination.Page + + TypeFilter *api.BillingCreditTransactionType + CurrencyFilter *currencyx.Code + } + ListCreditTransactionsResponse = response.PagePaginationResponse[api.BillingCreditTransaction] + ListCreditTransactionsParams struct { + CustomerID api.ULID + Params api.ListCreditTransactionsParams + } + ListCreditTransactionsHandler httptransport.HandlerWithArgs[ListCreditTransactionsRequest, ListCreditTransactionsResponse, ListCreditTransactionsParams] + mappedCreditTransaction struct { + API api.BillingCreditTransaction + Amount alpacadecimal.Decimal + Currency currencyx.Code + Cursor ledger.TransactionCursor + } +) + +func (h *handler) ListCreditTransactions() ListCreditTransactionsHandler { + return httptransport.NewHandlerWithArgs( + func(ctx context.Context, r *http.Request, args ListCreditTransactionsParams) (ListCreditTransactionsRequest, error) { + ns, err := h.resolveNamespace(ctx) + if err != nil { + return ListCreditTransactionsRequest{}, err + } + + page := pagination.NewPage(1, 20) + if args.Params.Page != nil { + page = pagination.NewPage( + lo.FromPtrOr(args.Params.Page.Number, 1), + lo.FromPtrOr(args.Params.Page.Size, 20), + ) + } + + if err := page.Validate(); err != nil { + return ListCreditTransactionsRequest{}, apierrors.NewBadRequestError(ctx, err, apierrors.InvalidParameters{ + { + Field: "page", + Reason: err.Error(), + Source: apierrors.InvalidParamSourceQuery, + }, + }) + } + + req := ListCreditTransactionsRequest{ + Namespace: ns, + CustomerID: args.CustomerID, + Page: page, + } + + if args.Params.Filter != nil { + req.TypeFilter = args.Params.Filter.Type + + if args.Params.Filter.Currency != nil { + currency := currencyx.Code(*args.Params.Filter.Currency) + req.CurrencyFilter = ¤cy + } + } + + return req, nil + }, + func(ctx context.Context, request ListCreditTransactionsRequest) (ListCreditTransactionsResponse, error) { + creditMovement, empty := creditMovementFromTypeFilter(request.TypeFilter) + if empty { + return emptyCreditTransactionPage(request.Page), nil + } + + accountID, err := h.customerFBOAccountID(ctx, customer.CustomerID{ + Namespace: request.Namespace, + ID: request.CustomerID, + }) + if err != nil { + return ListCreditTransactionsResponse{}, fmt.Errorf("resolve customer FBO account: %w", err) + } + + if accountID == "" { + return emptyCreditTransactionPage(request.Page), nil + } + + listIn := ledger.ListTransactionsByPageInput{ + Page: request.Page, + Namespace: request.Namespace, + AccountIDs: []string{accountID}, + Currency: request.CurrencyFilter, + CreditMovement: creditMovement, + } + + result, err := h.ledger.ListTransactionsByPage(ctx, listIn) + if err != nil { + return ListCreditTransactionsResponse{}, fmt.Errorf("list transactions: %w", err) + } + + items, err := mapCreditTransactions(result.Items) + if err != nil { + return ListCreditTransactionsResponse{}, err + } + + if len(items) > 0 { + runningBalance, err := h.customerFBOBalance(ctx, request, items[0].Currency, &items[0].Cursor) + if err != nil { + return ListCreditTransactionsResponse{}, fmt.Errorf("get FBO balance after transaction %s: %w", items[0].Cursor.ID.ID, err) + } + + applyCreditTransactionBalances(items, runningBalance) + } + + return response.NewPagePaginationResponse(apiCreditTransactions(items), response.PageMetaPage{ + Size: request.Page.PageSize, + Number: request.Page.PageNumber, + Total: lo.ToPtr(result.TotalCount), + }), nil + }, + commonhttp.JSONResponseEncoderWithStatus[ListCreditTransactionsResponse](http.StatusOK), + httptransport.AppendOptions( + h.options, + httptransport.WithOperationName("list-credit-transactions"), + httptransport.WithErrorEncoder(apierrors.GenericErrorEncoder()), + )..., + ) +} + +func emptyCreditTransactionPage(page pagination.Page) ListCreditTransactionsResponse { + return response.NewPagePaginationResponse([]api.BillingCreditTransaction{}, response.PageMetaPage{ + Size: page.PageSize, + Number: page.PageNumber, + Total: lo.ToPtr(0), + }) +} + +func creditMovementFromTypeFilter(filter *api.BillingCreditTransactionType) (ledger.ListTransactionsCreditMovement, bool) { + if filter == nil { + return ledger.ListTransactionsCreditMovementUnspecified, false + } + + switch *filter { + case api.BillingCreditTransactionTypeFunded: + return ledger.ListTransactionsCreditMovementPositive, false + case api.BillingCreditTransactionTypeConsumed: + return ledger.ListTransactionsCreditMovementNegative, false + case api.BillingCreditTransactionTypeAdjusted: + return ledger.ListTransactionsCreditMovementUnspecified, true + default: + return ledger.ListTransactionsCreditMovementUnspecified, false + } +} + +func (h *handler) customerFBOAccountID(ctx context.Context, customerID customer.CustomerID) (string, error) { + accounts, err := h.accountResolver.GetCustomerAccounts(ctx, customerID) + if err != nil { + return "", err + } + + return fboAccountIDFromCustomerAccounts(accounts), nil +} + +func fboAccountIDFromCustomerAccounts(accounts ledger.CustomerAccounts) string { + if fbo, ok := accounts.FBOAccount.(*ledgeraccount.CustomerFBOAccount); ok { + return fbo.ID().ID + } + + return "" +} + +func (h *handler) customerFBOBalance( + ctx context.Context, + req ListCreditTransactionsRequest, + currency currencyx.Code, + after *ledger.TransactionCursor, +) (alpacadecimal.Decimal, error) { + input := customerbalance.GetBalanceInput{ + CustomerID: customer.CustomerID{ + Namespace: req.Namespace, + ID: req.CustomerID, + }, + Currency: currency, + After: after, + } + + return h.balanceFacade.GetBalance(ctx, input) +} + +func applyCreditTransactionBalances(items []mappedCreditTransaction, after alpacadecimal.Decimal) { + runningBalance := after + + for i := range items { + items[i].API.AvailableBalance.After = runningBalance.String() + items[i].API.AvailableBalance.Before = runningBalance.Sub(items[i].Amount).String() + runningBalance = runningBalance.Sub(items[i].Amount) + } +} + +func mapCreditTransactions(txs []ledger.Transaction) ([]mappedCreditTransaction, error) { + items := make([]mappedCreditTransaction, 0, len(txs)) + + for _, tx := range txs { + item, err := mapCreditTransaction(tx) + if err != nil { + return nil, fmt.Errorf("convert ledger transaction %s: %w", tx.ID().ID, err) + } + + items = append(items, item) + } + + return items, nil +} + +func apiCreditTransactions(items []mappedCreditTransaction) []api.BillingCreditTransaction { + out := make([]api.BillingCreditTransaction, 0, len(items)) + for _, item := range items { + out = append(out, item.API) + } + + return out +} + +// mapCreditTransaction maps a ledger.Transaction to the API BillingCreditTransaction type plus its scoped FBO metadata. +func mapCreditTransaction(tx ledger.Transaction) (mappedCreditTransaction, error) { + entry, err := creditTransactionEntry(tx) + if err != nil { + return mappedCreditTransaction{}, err + } + + createdAt := tx.Cursor().CreatedAt + amount := entry.Amount() + currency := entry.PostingAddress().Route().Route().Currency + txType := creditTransactionType(amount) + + apiTx := api.BillingCreditTransaction{ + Id: tx.ID().ID, + CreatedAt: &createdAt, + BookedAt: tx.BookedAt(), + Type: txType, + Currency: api.BillingCurrencyCode(currency), + Amount: amount.String(), + Name: creditTransactionName(tx), + } + + labels := creditTransactionLabels(tx) + if len(labels) > 0 { + apiLabels := api.Labels(labels) + apiTx.Labels = &apiLabels + } + + return mappedCreditTransaction{ + API: apiTx, + Amount: amount, + Currency: currency, + Cursor: tx.Cursor(), + }, nil +} + +func creditTransactionEntry(tx ledger.Transaction) (ledger.Entry, error) { + for _, entry := range tx.Entries() { + if entry.PostingAddress().AccountType() != ledger.AccountTypeCustomerFBO { + continue + } + + return entry, nil + } + + return nil, fmt.Errorf("no customer FBO entry found in transaction %s", tx.ID().ID) +} + +// creditTransactionType determines the type based on the FBO impact sign. +// Positive = funded (balance went up), negative = consumed (balance went down). +func creditTransactionType(fboImpact alpacadecimal.Decimal) api.BillingCreditTransactionType { + if fboImpact.IsPositive() { + return api.BillingCreditTransactionTypeFunded + } + + if fboImpact.IsNegative() { + return api.BillingCreditTransactionTypeConsumed + } + + return api.BillingCreditTransactionTypeAdjusted +} + +func creditTransactionName(tx ledger.Transaction) string { + templateName, _ := ledger.TransactionTemplateNameFromAnnotations(tx.Annotations()) + if templateName != "" { + return templateName + } + + return "credit_transaction" +} + +func creditTransactionLabels(tx ledger.Transaction) map[string]string { + annotations := tx.Annotations() + labels := make(map[string]string) + + setLabel := func(key, annotationKey string) { + value := stringAnnotation(annotations, annotationKey) + if value != "" { + labels[key] = value + } + } + + setLabel("charge_id", ledger.AnnotationChargeID) + setLabel("subscription_id", ledger.AnnotationSubscriptionID) + setLabel("subscription_phase_id", ledger.AnnotationSubscriptionPhaseID) + setLabel("subscription_item_id", ledger.AnnotationSubscriptionItemID) + setLabel("feature_id", ledger.AnnotationFeatureID) + + return labels +} + +func stringAnnotation(annotations map[string]any, key string) string { + raw, ok := annotations[key] + if !ok { + return "" + } + + value, ok := raw.(string) + if !ok { + return "" + } + + return value +} diff --git a/api/v3/handlers/customers/credits/list_transactions_test.go b/api/v3/handlers/customers/credits/list_transactions_test.go new file mode 100644 index 0000000000..17b280f074 --- /dev/null +++ b/api/v3/handlers/customers/credits/list_transactions_test.go @@ -0,0 +1,195 @@ +package customerscredits + +import ( + "context" + "testing" + "time" + + "github.com/alpacahq/alpacadecimal" + "github.com/stretchr/testify/require" + + api "github.com/openmeterio/openmeter/api/v3" + "github.com/openmeterio/openmeter/openmeter/ledger" + ledgeraccount "github.com/openmeterio/openmeter/openmeter/ledger/account" + "github.com/openmeterio/openmeter/openmeter/ledger/customerbalance" + ledgerhistorical "github.com/openmeterio/openmeter/openmeter/ledger/historical" + "github.com/openmeterio/openmeter/pkg/currencyx" + "github.com/openmeterio/openmeter/pkg/models" +) + +func TestCreditMovementFromTypeFilter_AdjustedReturnsEmpty(t *testing.T) { + filter := api.BillingCreditTransactionTypeAdjusted + + movement, empty := creditMovementFromTypeFilter(&filter) + + require.Equal(t, ledger.ListTransactionsCreditMovementUnspecified, movement) + require.True(t, empty) +} + +func TestFBOAccountIDFromCustomerAccounts_ReturnsOnlyFBO(t *testing.T) { + fbo := mustCustomerFBOAccount(t, "ns", "fbo-account") + receivable := mustCustomerReceivableAccount(t, "ns", "receivable-account") + accrued := mustCustomerAccruedAccount(t, "ns", "accrued-account") + + accountID := fboAccountIDFromCustomerAccounts(ledger.CustomerAccounts{ + FBOAccount: fbo, + ReceivableAccount: receivable, + AccruedAccount: accrued, + }) + + require.Equal(t, "fbo-account", accountID) +} + +func TestCustomerFBOBalance_UsesCurrencyAndCursor(t *testing.T) { + usd := currencyx.Code("USD") + cursor := &ledger.TransactionCursor{ + BookedAt: time.Date(2026, 4, 10, 9, 0, 0, 0, time.UTC), + CreatedAt: time.Date(2026, 4, 10, 9, 0, 1, 0, time.UTC), + ID: models.NamespacedID{ + Namespace: "ns", + ID: "tx-1", + }, + } + facade := &capturingBalanceFacade{ + balance: alpacadecimal.NewFromInt(42), + } + h := handler{ + balanceFacade: facade, + } + + total, err := h.customerFBOBalance(t.Context(), ListCreditTransactionsRequest{ + Namespace: "ns", + CustomerID: "customer-1", + CurrencyFilter: &usd, + }, usd, cursor) + require.NoError(t, err) + require.True(t, total.Equal(alpacadecimal.NewFromInt(42))) + require.Equal(t, usd, facade.lastBalanceInput.Currency) + require.Equal(t, cursor, facade.lastBalanceInput.After) +} + +func TestMapCreditTransaction_UsesFBOEntry(t *testing.T) { + usd := currencyx.Code("USD") + tx := mustHistoricalTransaction(t, []ledgerhistorical.EntryData{ + mustEntryData(t, "entry-usd", ledger.AccountTypeCustomerFBO, usd, alpacadecimal.NewFromInt(-10)), + mustEntryData(t, "entry-accrued", ledger.AccountTypeCustomerAccrued, usd, alpacadecimal.NewFromInt(10)), + }) + + item, err := mapCreditTransaction(tx) + require.NoError(t, err) + require.Equal(t, api.BillingCreditTransactionTypeConsumed, item.API.Type) + require.Equal(t, api.BillingCurrencyCode("USD"), item.API.Currency) + require.Equal(t, api.Numeric("-10"), item.API.Amount) + require.True(t, item.Amount.Equal(alpacadecimal.NewFromInt(-10))) +} + +func TestApplyCreditTransactionBalances(t *testing.T) { + items := []mappedCreditTransaction{ + { + API: api.BillingCreditTransaction{ + Amount: api.Numeric("-10"), + }, + Amount: alpacadecimal.NewFromInt(-10), + }, + } + + applyCreditTransactionBalances(items, alpacadecimal.NewFromInt(42)) + + require.Equal(t, api.Numeric("42"), items[0].API.AvailableBalance.After) + require.Equal(t, api.Numeric("52"), items[0].API.AvailableBalance.Before) +} + +type capturingBalanceFacade struct { + lastBalanceInput customerbalance.GetBalanceInput + balance alpacadecimal.Decimal +} + +func (c *capturingBalanceFacade) GetBalance(_ context.Context, input customerbalance.GetBalanceInput) (alpacadecimal.Decimal, error) { + c.lastBalanceInput = input + return c.balance, nil +} + +func (c *capturingBalanceFacade) GetBalances(_ context.Context, _ customerbalance.GetBalancesInput) ([]customerbalance.BalanceByCurrency, error) { + return nil, nil +} + +func mustCustomerFBOAccount(t *testing.T, namespace, id string) *ledgeraccount.CustomerFBOAccount { + t.Helper() + + account := mustAccount(t, namespace, id, ledger.AccountTypeCustomerFBO) + fbo, err := account.AsCustomerFBOAccount() + require.NoError(t, err) + + return fbo +} + +func mustCustomerReceivableAccount(t *testing.T, namespace, id string) *ledgeraccount.CustomerReceivableAccount { + t.Helper() + + account := mustAccount(t, namespace, id, ledger.AccountTypeCustomerReceivable) + receivable, err := account.AsCustomerReceivableAccount() + require.NoError(t, err) + + return receivable +} + +func mustCustomerAccruedAccount(t *testing.T, namespace, id string) *ledgeraccount.CustomerAccruedAccount { + t.Helper() + + account := mustAccount(t, namespace, id, ledger.AccountTypeCustomerAccrued) + accrued, err := account.AsCustomerAccruedAccount() + require.NoError(t, err) + + return accrued +} + +func mustAccount(t *testing.T, namespace, id string, accountType ledger.AccountType) *ledgeraccount.Account { + t.Helper() + + account, err := ledgeraccount.NewAccountFromData(ledgeraccount.AccountData{ + ID: models.NamespacedID{ + Namespace: namespace, + ID: id, + }, + AccountType: accountType, + }, ledgeraccount.AccountLiveServices{}) + require.NoError(t, err) + + return account +} + +func mustHistoricalTransaction(t *testing.T, entries []ledgerhistorical.EntryData) ledger.Transaction { + t.Helper() + + tx, err := ledgerhistorical.NewTransactionFromData(ledgerhistorical.TransactionData{ + ID: "tx-1", + Namespace: "ns", + CreatedAt: time.Now().UTC(), + BookedAt: time.Now().UTC(), + }, entries) + require.NoError(t, err) + + return tx +} + +func mustEntryData(t *testing.T, id string, accountType ledger.AccountType, currency currencyx.Code, amount alpacadecimal.Decimal) ledgerhistorical.EntryData { + t.Helper() + + route := ledger.Route{Currency: currency} + key, err := ledger.BuildRoutingKey(ledger.RoutingKeyVersionV1, route) + require.NoError(t, err) + + return ledgerhistorical.EntryData{ + ID: id, + Namespace: "ns", + CreatedAt: time.Now().UTC(), + SubAccountID: id + "-subaccount", + AccountType: accountType, + Route: route, + RouteID: id + "-route", + RouteKey: key.Value(), + RouteKeyVer: key.Version(), + Amount: amount, + TransactionID: "tx-1", + } +} diff --git a/api/v3/server/routes.go b/api/v3/server/routes.go index 8d8f57a963..85c7514ad2 100644 --- a/api/v3/server/routes.go +++ b/api/v3/server/routes.go @@ -378,7 +378,15 @@ func (s *Server) UpdateCreditGrantExternalSettlement(w http.ResponseWriter, r *h } func (s *Server) ListCreditTransactions(w http.ResponseWriter, r *http.Request, customerId api.ULID, params api.ListCreditTransactionsParams) { - unimplemented.ListCreditTransactions(w, r, customerId, params) + if s.customersCreditsHandler == nil || s.Ledger == nil { + unimplemented.ListCreditTransactions(w, r, customerId, params) + return + } + + s.customersCreditsHandler.ListCreditTransactions().With(customerscreditshandler.ListCreditTransactionsParams{ + CustomerID: customerId, + Params: params, + }).ServeHTTP(w, r) } // Charges diff --git a/api/v3/server/server.go b/api/v3/server/server.go index d4a305b4a8..b4131e97b4 100644 --- a/api/v3/server/server.go +++ b/api/v3/server/server.go @@ -40,6 +40,7 @@ import ( "github.com/openmeterio/openmeter/openmeter/customer" "github.com/openmeterio/openmeter/openmeter/entitlement" "github.com/openmeterio/openmeter/openmeter/ingest" + "github.com/openmeterio/openmeter/openmeter/ledger" "github.com/openmeterio/openmeter/openmeter/ledger/customerbalance" "github.com/openmeterio/openmeter/openmeter/llmcost" "github.com/openmeterio/openmeter/openmeter/meter" @@ -72,6 +73,8 @@ type Config struct { IngestService ingest.Service CustomerService customer.Service CreditGrantService creditgrant.Service + Ledger ledger.Ledger + AccountResolver ledger.AccountResolver CustomerBalanceFacade *customerbalance.Facade EntitlementService entitlement.Service PlanService plan.Service @@ -217,7 +220,7 @@ func NewServer(config *Config) (*Server, error) { customersBillingHandler := customersbillinghandler.New(resolveNamespace, config.BillingService, config.CustomerService, config.StripeService, httptransport.WithErrorHandler(config.ErrorHandler)) var customersCreditsHandler customerscreditshandler.Handler if config.CustomerBalanceFacade != nil && config.Credits.Enabled { - customersCreditsHandler = customerscreditshandler.New(resolveNamespace, config.CustomerService, config.CustomerBalanceFacade, config.CreditGrantService, httptransport.WithErrorHandler(config.ErrorHandler)) + customersCreditsHandler = customerscreditshandler.New(resolveNamespace, config.CustomerService, config.CustomerBalanceFacade, config.CreditGrantService, config.Ledger, config.AccountResolver, httptransport.WithErrorHandler(config.ErrorHandler)) } customersEntitlementHandler := customersentitlementhandler.New(resolveNamespace, config.CustomerService, config.EntitlementService, httptransport.WithErrorHandler(config.ErrorHandler)) metersHandler := metershandler.New(resolveNamespace, config.MeterService, config.StreamingConnector, config.CustomerService, httptransport.WithErrorHandler(config.ErrorHandler)) diff --git a/cmd/server/main.go b/cmd/server/main.go index 2d0a1c5351..3d3ab12186 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -158,6 +158,8 @@ func main() { CurrencyService: app.CurrencyService, CostService: app.CostService, CreditGrantService: app.CreditGrantService, + Ledger: app.Ledger, + AccountResolver: app.AccountResolver, Customer: app.Customer, CustomerBalanceFacade: app.CustomerBalanceFacade, DebugConnector: debugConnector, diff --git a/cmd/server/wire.go b/cmd/server/wire.go index a073bc5821..e4544e7c90 100644 --- a/cmd/server/wire.go +++ b/cmd/server/wire.go @@ -21,6 +21,7 @@ import ( "github.com/openmeterio/openmeter/openmeter/ent/db" "github.com/openmeterio/openmeter/openmeter/ingest" "github.com/openmeterio/openmeter/openmeter/ingest/kafkaingest" + "github.com/openmeterio/openmeter/openmeter/ledger" "github.com/openmeterio/openmeter/openmeter/ledger/customerbalance" "github.com/openmeterio/openmeter/openmeter/llmcost" "github.com/openmeterio/openmeter/openmeter/meter" @@ -59,6 +60,8 @@ type Application struct { CurrencyService currencies.CurrencyService CostService cost.Service CreditGrantService creditgrant.Service + Ledger ledger.Ledger + AccountResolver ledger.AccountResolver CustomerBalanceFacade *customerbalance.Facade EntClient *db.Client EventPublisher eventbus.Publisher diff --git a/cmd/server/wire_gen.go b/cmd/server/wire_gen.go index 24672563ed..12ac68e152 100644 --- a/cmd/server/wire_gen.go +++ b/cmd/server/wire_gen.go @@ -18,6 +18,7 @@ import ( "github.com/openmeterio/openmeter/openmeter/ent/db" "github.com/openmeterio/openmeter/openmeter/ingest" "github.com/openmeterio/openmeter/openmeter/ingest/kafkaingest" + "github.com/openmeterio/openmeter/openmeter/ledger" "github.com/openmeterio/openmeter/openmeter/ledger/customerbalance" "github.com/openmeterio/openmeter/openmeter/llmcost" "github.com/openmeterio/openmeter/openmeter/meter" @@ -744,6 +745,8 @@ func initializeApplication(ctx context.Context, conf config.Configuration) (Appl CurrencyService: currencyService, CostService: costService, CreditGrantService: creditgrantService, + Ledger: ledger, + AccountResolver: accountResolver, CustomerBalanceFacade: facade, EntClient: client, EventPublisher: eventbusPublisher, @@ -811,6 +814,8 @@ type Application struct { CurrencyService currencies.CurrencyService CostService cost.Service CreditGrantService creditgrant.Service + Ledger ledger.Ledger + AccountResolver ledger.AccountResolver CustomerBalanceFacade *customerbalance.Facade EntClient *db.Client EventPublisher eventbus.Publisher diff --git a/openmeter/ledger/account/account.go b/openmeter/ledger/account/account.go index 93fc9c4347..ce5b45098c 100644 --- a/openmeter/ledger/account/account.go +++ b/openmeter/ledger/account/account.go @@ -71,7 +71,7 @@ type Account struct { var _ ledger.Account = (*Account)(nil) -func (a *Account) GetBalance(ctx context.Context, query ledger.RouteFilter) (ledger.Balance, error) { +func (a *Account) GetBalance(ctx context.Context, query ledger.RouteFilter, after *ledger.TransactionCursor) (ledger.Balance, error) { // TODO: this is a hack // package boundary between account and historical ledger is incorrect, dependency resolution is broken if a.services.Querier == nil { @@ -88,6 +88,7 @@ func (a *Account) GetBalance(ctx context.Context, query ledger.RouteFilter) (led Cursor: lastClosingCursor, Filters: ledger.Filters{ BookedAtPeriod: periodSinceListClosing, + After: after, AccountID: lo.ToPtr(a.data.ID.ID), Route: query, }, diff --git a/openmeter/ledger/account/subaccount.go b/openmeter/ledger/account/subaccount.go index 6747d08d7d..c3480ccc3b 100644 --- a/openmeter/ledger/account/subaccount.go +++ b/openmeter/ledger/account/subaccount.go @@ -77,7 +77,7 @@ func (s *SubAccount) GetBalance(ctx context.Context) (ledger.Balance, error) { return nil, fmt.Errorf("parent account is required") } - res, err := s.account.GetBalance(ctx, s.data.Route.Filter()) + res, err := s.account.GetBalance(ctx, s.data.Route.Filter(), nil) if err != nil { return nil, fmt.Errorf("failed to get balance for sub-account %s: %w", s.data.ID, err) } diff --git a/openmeter/ledger/annotations.go b/openmeter/ledger/annotations.go index fc93b6e81b..fdf3ffb354 100644 --- a/openmeter/ledger/annotations.go +++ b/openmeter/ledger/annotations.go @@ -7,13 +7,26 @@ import ( ) const ( - AnnotationChargeNamespace = "ledger.charge.namespace" - AnnotationChargeID = "ledger.charge.id" + AnnotationChargeNamespace = "ledger.charge.namespace" + AnnotationChargeID = "ledger.charge.id" + AnnotationSubscriptionID = "ledger.subscription.id" + AnnotationSubscriptionPhaseID = "ledger.subscription.phase.id" + AnnotationSubscriptionItemID = "ledger.subscription.item.id" + AnnotationFeatureID = "ledger.feature.id" AnnotationTransactionTemplateName = "ledger.transaction.template_name" AnnotationTransactionDirection = "ledger.transaction.direction" ) +type ChargeTransactionAnnotationsInput struct { + ChargeID models.NamespacedID + + SubscriptionID *string + SubscriptionPhaseID *string + SubscriptionItemID *string + FeatureID *string +} + type TransactionDirection string const ( @@ -28,6 +41,28 @@ func ChargeAnnotations(chargeID models.NamespacedID) models.Annotations { } } +func ChargeTransactionAnnotations(input ChargeTransactionAnnotationsInput) models.Annotations { + annotations := ChargeAnnotations(input.ChargeID) + + if input.SubscriptionID != nil && *input.SubscriptionID != "" { + annotations[AnnotationSubscriptionID] = *input.SubscriptionID + } + + if input.SubscriptionPhaseID != nil && *input.SubscriptionPhaseID != "" { + annotations[AnnotationSubscriptionPhaseID] = *input.SubscriptionPhaseID + } + + if input.SubscriptionItemID != nil && *input.SubscriptionItemID != "" { + annotations[AnnotationSubscriptionItemID] = *input.SubscriptionItemID + } + + if input.FeatureID != nil && *input.FeatureID != "" { + annotations[AnnotationFeatureID] = *input.FeatureID + } + + return annotations +} + func TransactionAnnotations(templateName string, direction TransactionDirection) models.Annotations { return models.Annotations{ AnnotationTransactionTemplateName: templateName, diff --git a/openmeter/ledger/chargeadapter/annotations.go b/openmeter/ledger/chargeadapter/annotations.go new file mode 100644 index 0000000000..4fcb01f267 --- /dev/null +++ b/openmeter/ledger/chargeadapter/annotations.go @@ -0,0 +1,71 @@ +package chargeadapter + +import ( + chargecreditpurchase "github.com/openmeterio/openmeter/openmeter/billing/charges/creditpurchase" + chargeflatfee "github.com/openmeterio/openmeter/openmeter/billing/charges/flatfee" + "github.com/openmeterio/openmeter/openmeter/billing/charges/meta" + chargeusagebased "github.com/openmeterio/openmeter/openmeter/billing/charges/usagebased" + "github.com/openmeterio/openmeter/openmeter/ledger" + "github.com/openmeterio/openmeter/pkg/models" +) + +func chargeAnnotationsForCreditPurchaseCharge(charge chargecreditpurchase.Charge) models.Annotations { + return chargeTransactionAnnotations( + models.NamespacedID{ + Namespace: charge.Namespace, + ID: charge.ID, + }, + charge.Intent.Subscription, + nil, + ) +} + +func chargeAnnotationsForFlatFeeCharge(charge chargeflatfee.Charge) models.Annotations { + return chargeTransactionAnnotations( + models.NamespacedID{ + Namespace: charge.Namespace, + ID: charge.ID, + }, + charge.Intent.Subscription, + charge.State.FeatureID, + ) +} + +func chargeAnnotationsForUsageBasedCharge(charge chargeusagebased.Charge) models.Annotations { + return chargeTransactionAnnotations( + models.NamespacedID{ + Namespace: charge.Namespace, + ID: charge.ID, + }, + charge.Intent.Subscription, + ptrIfNotEmpty(charge.State.FeatureID), + ) +} + +func chargeTransactionAnnotations(chargeID models.NamespacedID, subscription *meta.SubscriptionReference, featureID *string) models.Annotations { + var subscriptionID *string + var subscriptionPhaseID *string + var subscriptionItemID *string + + if subscription != nil { + subscriptionID = &subscription.SubscriptionID + subscriptionPhaseID = &subscription.PhaseID + subscriptionItemID = &subscription.ItemID + } + + return ledger.ChargeTransactionAnnotations(ledger.ChargeTransactionAnnotationsInput{ + ChargeID: chargeID, + SubscriptionID: subscriptionID, + SubscriptionPhaseID: subscriptionPhaseID, + SubscriptionItemID: subscriptionItemID, + FeatureID: featureID, + }) +} + +func ptrIfNotEmpty(value string) *string { + if value == "" { + return nil + } + + return &value +} diff --git a/openmeter/ledger/chargeadapter/creditpurchase.go b/openmeter/ledger/chargeadapter/creditpurchase.go index d4d38b3891..a88019bb35 100644 --- a/openmeter/ledger/chargeadapter/creditpurchase.go +++ b/openmeter/ledger/chargeadapter/creditpurchase.go @@ -13,7 +13,6 @@ import ( ledgeraccount "github.com/openmeterio/openmeter/openmeter/ledger/account" "github.com/openmeterio/openmeter/openmeter/ledger/transactions" "github.com/openmeterio/openmeter/pkg/currencyx" - "github.com/openmeterio/openmeter/pkg/models" ) // creditPurchaseHandler maps credit purchase lifecycle events to ledger transaction templates. @@ -59,10 +58,7 @@ func (h *creditPurchaseHandler) OnCreditPurchasePaymentAuthorized(ctx context.Co Namespace: charge.Namespace, ID: charge.Intent.CustomerID, } - annotations := ledger.ChargeAnnotations(models.NamespacedID{ - Namespace: charge.Namespace, - ID: charge.ID, - }) + annotations := chargeAnnotationsForCreditPurchaseCharge(charge) inputs, err := transactions.ResolveTransactions( ctx, @@ -82,6 +78,12 @@ func (h *creditPurchaseHandler) OnCreditPurchasePaymentAuthorized(ctx context.Co return ledgertransaction.GroupReference{}, fmt.Errorf("resolve transactions: %w", err) } + for i, input := range inputs { + if input != nil { + inputs[i] = transactions.WithAnnotations(input, annotations) + } + } + transactionGroup, err := h.ledger.CommitGroup(ctx, transactions.GroupInputs( charge.Namespace, annotations, @@ -110,10 +112,7 @@ func (h *creditPurchaseHandler) OnCreditPurchasePaymentSettled(ctx context.Conte Namespace: charge.Namespace, ID: charge.Intent.CustomerID, } - annotations := ledger.ChargeAnnotations(models.NamespacedID{ - Namespace: charge.Namespace, - ID: charge.ID, - }) + annotations := chargeAnnotationsForCreditPurchaseCharge(charge) inputs, err := transactions.ResolveTransactions( ctx, @@ -133,6 +132,12 @@ func (h *creditPurchaseHandler) OnCreditPurchasePaymentSettled(ctx context.Conte return ledgertransaction.GroupReference{}, fmt.Errorf("resolve transactions: %w", err) } + for i, input := range inputs { + if input != nil { + inputs[i] = transactions.WithAnnotations(input, annotations) + } + } + transactionGroup, err := h.ledger.CommitGroup(ctx, transactions.GroupInputs( charge.Namespace, annotations, @@ -168,10 +173,7 @@ func (h *creditPurchaseHandler) issueCreditPurchase(ctx context.Context, charge Namespace: charge.Namespace, ID: charge.Intent.CustomerID, } - annotations := ledger.ChargeAnnotations(models.NamespacedID{ - Namespace: charge.Namespace, - ID: charge.ID, - }) + annotations := chargeAnnotationsForCreditPurchaseCharge(charge) advanceOutstanding, err := h.outstandingAdvanceBalance(ctx, customerID, charge.Intent.Currency) if err != nil { @@ -246,6 +248,12 @@ func (h *creditPurchaseHandler) issueCreditPurchase(ctx context.Context, charge return ledgertransaction.GroupReference{}, fmt.Errorf("resolve transactions: %w", err) } + for i, input := range inputs { + if input != nil { + inputs[i] = transactions.WithAnnotations(input, annotations) + } + } + transactionGroup, err := h.ledger.CommitGroup(ctx, transactions.GroupInputs( charge.Namespace, annotations, diff --git a/openmeter/ledger/chargeadapter/creditpurchase_test.go b/openmeter/ledger/chargeadapter/creditpurchase_test.go index 6076d58cbf..05afbd36fd 100644 --- a/openmeter/ledger/chargeadapter/creditpurchase_test.go +++ b/openmeter/ledger/chargeadapter/creditpurchase_test.go @@ -10,6 +10,7 @@ import ( "github.com/openmeterio/openmeter/openmeter/billing" chargecreditpurchase "github.com/openmeterio/openmeter/openmeter/billing/charges/creditpurchase" "github.com/openmeterio/openmeter/openmeter/billing/charges/meta" + ledgertransactiondb "github.com/openmeterio/openmeter/openmeter/ent/db/ledgertransaction" ledgertransactiongroupdb "github.com/openmeterio/openmeter/openmeter/ent/db/ledgertransactiongroup" "github.com/openmeterio/openmeter/openmeter/ledger" "github.com/openmeterio/openmeter/openmeter/ledger/chargeadapter" @@ -50,6 +51,41 @@ func TestOnCreditPurchaseInitiated(t *testing.T) { require.True(t, env.sumBalance(t, env.receivableSubAccount(t, costBasis)).Equal(alpacadecimal.NewFromInt(-100))) } +func TestOnCreditPurchaseInitiated_TracksChargeReferencesOnTransactions(t *testing.T) { + env := newCreditPurchaseHandlerTestEnv(t) + + costBasis := mustDecimal(t, "0.5") + charge := env.newExternalCharge(alpacadecimal.NewFromInt(100), costBasis) + charge.Intent.Subscription = &meta.SubscriptionReference{ + SubscriptionID: "subscription-01JABCDEF0123456789ABCDEF", + PhaseID: "phase-01JABCDEF0123456789ABCDEF", + ItemID: "item-01JABCDEF0123456789ABCDEF", + } + + ref, err := env.handler.OnCreditPurchaseInitiated(t.Context(), charge) + require.NoError(t, err) + require.NotEmpty(t, ref.TransactionGroupID) + + expected := ledger.ChargeTransactionAnnotations(ledger.ChargeTransactionAnnotationsInput{ + ChargeID: models.NamespacedID{ + Namespace: env.Namespace, + ID: charge.ID, + }, + SubscriptionID: &charge.Intent.Subscription.SubscriptionID, + SubscriptionPhaseID: &charge.Intent.Subscription.PhaseID, + SubscriptionItemID: &charge.Intent.Subscription.ItemID, + }) + require.Equal(t, expected, env.transactionGroupAnnotations(t, ref.TransactionGroupID)) + + for _, annotations := range env.transactionAnnotations(t, ref.TransactionGroupID) { + require.Equal(t, charge.ID, annotations[ledger.AnnotationChargeID]) + require.Equal(t, env.Namespace, annotations[ledger.AnnotationChargeNamespace]) + require.Equal(t, charge.Intent.Subscription.SubscriptionID, annotations[ledger.AnnotationSubscriptionID]) + require.Equal(t, charge.Intent.Subscription.PhaseID, annotations[ledger.AnnotationSubscriptionPhaseID]) + require.Equal(t, charge.Intent.Subscription.ItemID, annotations[ledger.AnnotationSubscriptionItemID]) + } +} + func TestOnCreditPurchaseInitiated_OnlyIssuesExcessBeyondAdvance(t *testing.T) { env := newCreditPurchaseHandlerTestEnv(t) env.createAdvanceExposure(t, alpacadecimal.NewFromInt(40)) @@ -366,6 +402,29 @@ func (e *creditPurchaseHandlerTestEnv) transactionGroupAnnotations(t *testing.T, return group.Annotations } +func (e *creditPurchaseHandlerTestEnv) transactionAnnotations(t *testing.T, groupID string) []models.Annotations { + t.Helper() + + transactions, err := e.DB.LedgerTransaction.Query(). + Where( + ledgertransactiondb.Namespace(e.Namespace), + ledgertransactiondb.GroupID(groupID), + ). + Order( + ledgertransactiondb.ByCreatedAt(), + ledgertransactiondb.ByID(), + ). + All(t.Context()) + require.NoError(t, err) + + out := make([]models.Annotations, 0, len(transactions)) + for _, tx := range transactions { + out = append(out, tx.Annotations) + } + + return out +} + func mustDecimal(t *testing.T, raw string) alpacadecimal.Decimal { t.Helper() diff --git a/openmeter/ledger/chargeadapter/flatfee.go b/openmeter/ledger/chargeadapter/flatfee.go index 6011c0c876..1e7d5865c0 100644 --- a/openmeter/ledger/chargeadapter/flatfee.go +++ b/openmeter/ledger/chargeadapter/flatfee.go @@ -15,7 +15,6 @@ import ( "github.com/openmeterio/openmeter/openmeter/ledger/collector" "github.com/openmeterio/openmeter/openmeter/ledger/transactions" "github.com/openmeterio/openmeter/openmeter/productcatalog" - "github.com/openmeterio/openmeter/pkg/models" ) // flatFeeHandler maps charge lifecycle events to ledger transaction templates @@ -67,6 +66,7 @@ func (h *flatFeeHandler) OnAssignedToInvoice(ctx context.Context, input flatfee. Namespace: input.Charge.Namespace, ChargeID: input.Charge.ID, CustomerID: input.Charge.Intent.CustomerID, + Annotations: chargeAnnotationsForFlatFeeCharge(input.Charge), At: input.Charge.Intent.InvoiceAt, Currency: input.Charge.Intent.Currency, SettlementMode: input.Charge.Intent.SettlementMode, @@ -108,10 +108,7 @@ func (h *flatFeeHandler) OnInvoiceUsageAccrued(ctx context.Context, input flatfe Namespace: input.Charge.Namespace, ID: input.Charge.Intent.CustomerID, } - annotations := ledger.ChargeAnnotations(models.NamespacedID{ - Namespace: input.Charge.Namespace, - ID: input.Charge.ID, - }) + annotations := chargeAnnotationsForFlatFeeCharge(input.Charge) inputs, err := transactions.ResolveTransactions( ctx, @@ -131,6 +128,12 @@ func (h *flatFeeHandler) OnInvoiceUsageAccrued(ctx context.Context, input flatfe return ledgertransaction.GroupReference{}, fmt.Errorf("resolve transactions: %w", err) } + for i, txInput := range inputs { + if txInput != nil { + inputs[i] = transactions.WithAnnotations(txInput, annotations) + } + } + transactionGroup, err := h.ledger.CommitGroup(ctx, transactions.GroupInputs( input.Charge.Namespace, annotations, @@ -164,6 +167,7 @@ func (h *flatFeeHandler) OnCreditsOnlyUsageAccrued(ctx context.Context, input fl Namespace: input.Charge.Namespace, ChargeID: input.Charge.ID, CustomerID: input.Charge.Intent.CustomerID, + Annotations: chargeAnnotationsForFlatFeeCharge(input.Charge), At: input.Charge.Intent.InvoiceAt, Currency: input.Charge.Intent.Currency, SettlementMode: input.Charge.Intent.SettlementMode, @@ -194,6 +198,7 @@ func (h *flatFeeHandler) OnCreditsOnlyUsageAccruedCorrection(ctx context.Context Namespace: input.Charge.Namespace, ChargeID: input.Charge.ID, CustomerID: input.Charge.Intent.CustomerID, + Annotations: chargeAnnotationsForFlatFeeCharge(input.Charge), AllocateAt: input.AllocateAt, Corrections: input.Corrections, LineageSegmentsByRealization: input.LineageSegmentsByRealization, @@ -220,10 +225,7 @@ func (h *flatFeeHandler) OnPaymentAuthorized(ctx context.Context, charge flatfee Namespace: charge.Namespace, ID: charge.Intent.CustomerID, } - annotations := ledger.ChargeAnnotations(models.NamespacedID{ - Namespace: charge.Namespace, - ID: charge.ID, - }) + annotations := chargeAnnotationsForFlatFeeCharge(charge) inputs, err := transactions.ResolveTransactions( ctx, @@ -243,6 +245,12 @@ func (h *flatFeeHandler) OnPaymentAuthorized(ctx context.Context, charge flatfee return ledgertransaction.GroupReference{}, fmt.Errorf("resolve transactions: %w", err) } + for i, txInput := range inputs { + if txInput != nil { + inputs[i] = transactions.WithAnnotations(txInput, annotations) + } + } + transactionGroup, err := h.ledger.CommitGroup(ctx, transactions.GroupInputs( charge.Namespace, annotations, @@ -270,10 +278,7 @@ func (h *flatFeeHandler) OnPaymentSettled(ctx context.Context, charge flatfee.Ch Namespace: charge.Namespace, ID: charge.Intent.CustomerID, } - annotations := ledger.ChargeAnnotations(models.NamespacedID{ - Namespace: charge.Namespace, - ID: charge.ID, - }) + annotations := chargeAnnotationsForFlatFeeCharge(charge) inputs, err := transactions.ResolveTransactions( ctx, @@ -293,6 +298,12 @@ func (h *flatFeeHandler) OnPaymentSettled(ctx context.Context, charge flatfee.Ch return ledgertransaction.GroupReference{}, fmt.Errorf("resolve transactions: %w", err) } + for i, txInput := range inputs { + if txInput != nil { + inputs[i] = transactions.WithAnnotations(txInput, annotations) + } + } + transactionGroup, err := h.ledger.CommitGroup(ctx, transactions.GroupInputs( charge.Namespace, annotations, diff --git a/openmeter/ledger/chargeadapter/usagebased.go b/openmeter/ledger/chargeadapter/usagebased.go index 992f15bd7d..001cda6839 100644 --- a/openmeter/ledger/chargeadapter/usagebased.go +++ b/openmeter/ledger/chargeadapter/usagebased.go @@ -116,6 +116,7 @@ func (h *usageBasedHandler) OnCreditsOnlyUsageAccrued(ctx context.Context, input Namespace: input.Charge.Namespace, ChargeID: input.Charge.ID, CustomerID: input.Charge.Intent.CustomerID, + Annotations: chargeAnnotationsForUsageBasedCharge(input.Charge), At: input.AllocateAt, Currency: input.Charge.Intent.Currency, SettlementMode: input.Charge.Intent.SettlementMode, @@ -166,6 +167,7 @@ func (h *usageBasedHandler) OnCreditsOnlyUsageAccruedCorrection(ctx context.Cont Namespace: input.Charge.Namespace, ChargeID: input.Charge.ID, CustomerID: input.Charge.Intent.CustomerID, + Annotations: chargeAnnotationsForUsageBasedCharge(input.Charge), AllocateAt: input.AllocateAt, Corrections: input.Corrections, LineageSegmentsByRealization: input.LineageSegmentsByRealization, diff --git a/openmeter/ledger/chargeadapter/usagebased_test.go b/openmeter/ledger/chargeadapter/usagebased_test.go index 512544b101..d94bf91cfc 100644 --- a/openmeter/ledger/chargeadapter/usagebased_test.go +++ b/openmeter/ledger/chargeadapter/usagebased_test.go @@ -13,6 +13,7 @@ import ( "github.com/openmeterio/openmeter/openmeter/billing/charges/models/creditrealization" chargeusagebased "github.com/openmeterio/openmeter/openmeter/billing/charges/usagebased" "github.com/openmeterio/openmeter/openmeter/billing/models/totals" + ledgertransactiondb "github.com/openmeterio/openmeter/openmeter/ent/db/ledgertransaction" "github.com/openmeterio/openmeter/openmeter/ledger" "github.com/openmeterio/openmeter/openmeter/ledger/chargeadapter" ledgercollector "github.com/openmeterio/openmeter/openmeter/ledger/collector" @@ -86,6 +87,37 @@ func TestOnUsageBasedCreditsOnlyUsageAccrued(t *testing.T) { require.True(t, env.sumBalance(t, env.unknownAccruedSubAccount(t)).Equal(alpacadecimal.Zero)) }) + t.Run("tracks charge references on transactions", func(t *testing.T) { + env := newUsageBasedHandlerTestEnv(t) + + charge := env.newCreditsOnlyCharge() + charge.Intent.Subscription = &meta.SubscriptionReference{ + SubscriptionID: "subscription-01JABCDEF0123456789ABCDEF", + PhaseID: "phase-01JABCDEF0123456789ABCDEF", + ItemID: "item-01JABCDEF0123456789ABCDEF", + } + + realizations, err := env.handler.OnCreditsOnlyUsageAccrued(t.Context(), chargeusagebased.CreditsOnlyUsageAccruedInput{ + Charge: charge, + Run: env.newRun(), + AllocateAt: env.Now(), + AmountToAllocate: alpacadecimal.NewFromInt(30), + }) + require.NoError(t, err) + require.Len(t, realizations, 1) + + transactionAnnotations := env.transactionAnnotations(t, realizations[0].LedgerTransaction.TransactionGroupID) + require.NotEmpty(t, transactionAnnotations) + for _, annotations := range transactionAnnotations { + require.Equal(t, charge.ID, annotations[ledger.AnnotationChargeID]) + require.Equal(t, env.Namespace, annotations[ledger.AnnotationChargeNamespace]) + require.Equal(t, charge.Intent.Subscription.SubscriptionID, annotations[ledger.AnnotationSubscriptionID]) + require.Equal(t, charge.Intent.Subscription.PhaseID, annotations[ledger.AnnotationSubscriptionPhaseID]) + require.Equal(t, charge.Intent.Subscription.ItemID, annotations[ledger.AnnotationSubscriptionItemID]) + require.Equal(t, charge.State.FeatureID, annotations[ledger.AnnotationFeatureID]) + } + }) + t.Run("zero amount is rejected by input validation", func(t *testing.T) { env := newUsageBasedHandlerTestEnv(t) @@ -369,6 +401,29 @@ func (e *usageBasedHandlerTestEnv) sumBalance(t *testing.T, subAccount ledger.Su return e.SumBalance(t, subAccount) } +func (e *usageBasedHandlerTestEnv) transactionAnnotations(t *testing.T, groupID string) []models.Annotations { + t.Helper() + + transactions, err := e.DB.LedgerTransaction.Query(). + Where( + ledgertransactiondb.Namespace(e.Namespace), + ledgertransactiondb.GroupID(groupID), + ). + Order( + ledgertransactiondb.ByCreatedAt(), + ledgertransactiondb.ByID(), + ). + All(t.Context()) + require.NoError(t, err) + + out := make([]models.Annotations, 0, len(transactions)) + for _, tx := range transactions { + out = append(out, tx.Annotations) + } + + return out +} + func (e *usageBasedHandlerTestEnv) realizationsFromAllocations(allocations creditrealization.CreateAllocationInputs) creditrealization.Realizations { now := time.Now().UTC() diff --git a/openmeter/ledger/collector/collect.go b/openmeter/ledger/collector/collect.go index e22a2ed1e2..645e633659 100644 --- a/openmeter/ledger/collector/collect.go +++ b/openmeter/ledger/collector/collect.go @@ -48,12 +48,23 @@ func (c *accrualCollector) collect(ctx context.Context, input CollectToAccruedIn return nil, nil } - transactionGroup, err := c.ledger.CommitGroup(ctx, transactions.GroupInputs( - input.Namespace, - ledger.ChargeAnnotations(models.NamespacedID{ + groupAnnotations := input.Annotations + if groupAnnotations == nil { + groupAnnotations = ledger.ChargeAnnotations(models.NamespacedID{ Namespace: input.Namespace, ID: input.ChargeID, - }), + }) + } + + for i, txInput := range inputs { + if txInput != nil { + inputs[i] = transactions.WithAnnotations(txInput, groupAnnotations) + } + } + + transactionGroup, err := c.ledger.CommitGroup(ctx, transactions.GroupInputs( + input.Namespace, + groupAnnotations, inputs..., )) if err != nil { diff --git a/openmeter/ledger/collector/correct.go b/openmeter/ledger/collector/correct.go index ffd33c440e..c1fe4e63be 100644 --- a/openmeter/ledger/collector/correct.go +++ b/openmeter/ledger/collector/correct.go @@ -79,14 +79,25 @@ func (c *accrualCorrector) correct(ctx context.Context, input CorrectCollectedAc return nil, nil } + groupAnnotations := input.Annotations + if groupAnnotations == nil { + groupAnnotations = ledger.ChargeAnnotations(models.NamespacedID{ + Namespace: input.Namespace, + ID: input.ChargeID, + }) + } + // Write the whole correction batch as one group and point every new correction // realization at that group. + for i, txInput := range resolvedInputs { + if txInput != nil { + resolvedInputs[i] = transactions.WithAnnotations(txInput, groupAnnotations) + } + } + transactionGroup, err := c.ledger.CommitGroup(ctx, transactions.GroupInputs( input.Namespace, - ledger.ChargeAnnotations(models.NamespacedID{ - Namespace: input.Namespace, - ID: input.ChargeID, - }), + groupAnnotations, resolvedInputs..., )) if err != nil { diff --git a/openmeter/ledger/collector/service.go b/openmeter/ledger/collector/service.go index 3864e35278..36f8410d50 100644 --- a/openmeter/ledger/collector/service.go +++ b/openmeter/ledger/collector/service.go @@ -12,6 +12,7 @@ import ( "github.com/openmeterio/openmeter/openmeter/ledger/transactions" "github.com/openmeterio/openmeter/openmeter/productcatalog" "github.com/openmeterio/openmeter/pkg/currencyx" + "github.com/openmeterio/openmeter/pkg/models" "github.com/openmeterio/openmeter/pkg/timeutil" ) @@ -29,6 +30,7 @@ type CollectToAccruedInput struct { Namespace string ChargeID string CustomerID string + Annotations models.Annotations At time.Time Currency currencyx.Code SettlementMode productcatalog.SettlementMode @@ -40,6 +42,7 @@ type CorrectCollectedAccruedInput struct { Namespace string ChargeID string CustomerID string + Annotations models.Annotations AllocateAt time.Time Corrections creditrealization.CorrectionRequest LineageSegmentsByRealization lineage.ActiveSegmentsByRealizationID diff --git a/openmeter/ledger/customerbalance/facade.go b/openmeter/ledger/customerbalance/facade.go index 35cff1af76..dbbb296f0c 100644 --- a/openmeter/ledger/customerbalance/facade.go +++ b/openmeter/ledger/customerbalance/facade.go @@ -5,6 +5,8 @@ import ( "errors" "fmt" + "github.com/alpacahq/alpacadecimal" + "github.com/openmeterio/openmeter/openmeter/customer" "github.com/openmeterio/openmeter/openmeter/ledger" "github.com/openmeterio/openmeter/pkg/currencyx" @@ -29,6 +31,12 @@ type GetBalancesInput struct { Currencies CurrencyFilter } +type GetBalanceInput struct { + CustomerID customer.CustomerID + Currency currencyx.Code + After *ledger.TransactionCursor +} + func (i GetBalancesInput) Validate() error { var errs []error @@ -43,13 +51,33 @@ func (i GetBalancesInput) Validate() error { return errors.Join(errs...) } +func (i GetBalanceInput) Validate() error { + var errs []error + + if err := i.CustomerID.Validate(); err != nil { + errs = append(errs, fmt.Errorf("customer ID: %w", err)) + } + + if err := i.Currency.Validate(); err != nil { + errs = append(errs, fmt.Errorf("currency: %w", err)) + } + + if i.After != nil { + if err := i.After.Validate(); err != nil { + errs = append(errs, fmt.Errorf("after: %w", err)) + } + } + + return errors.Join(errs...) +} + type BalanceByCurrency struct { Currency currencyx.Code Balance ledger.Balance } type FacadeService interface { - GetBalance(ctx context.Context, customerID customer.CustomerID, filters ledger.RouteFilter) (ledger.Balance, error) + GetBalance(ctx context.Context, customerID customer.CustomerID, filters ledger.RouteFilter, after *ledger.TransactionCursor) (ledger.Balance, error) getFBOCurrencies(ctx context.Context, customerID customer.CustomerID) ([]currencyx.Code, error) } @@ -96,7 +124,7 @@ func (f *Facade) GetBalances(ctx context.Context, input GetBalancesInput) ([]Bal balances := make([]BalanceByCurrency, 0, len(codes)) for _, code := range codes { - balance, err := f.service.GetBalance(ctx, input.CustomerID, routeFilter(code)) + balance, err := f.service.GetBalance(ctx, input.CustomerID, routeFilter(code), nil) if err != nil { return nil, err } @@ -110,6 +138,23 @@ func (f *Facade) GetBalances(ctx context.Context, input GetBalancesInput) ([]Bal return balances, nil } +func (f *Facade) GetBalance(ctx context.Context, input GetBalanceInput) (alpacadecimal.Decimal, error) { + if f == nil { + return alpacadecimal.Zero, errors.New("facade is required") + } + + if err := input.Validate(); err != nil { + return alpacadecimal.Zero, err + } + + balance, err := f.service.GetBalance(ctx, input.CustomerID, routeFilter(input.Currency), input.After) + if err != nil { + return alpacadecimal.Zero, err + } + + return balance.Settled(), nil +} + func routeFilter(currency currencyx.Code) ledger.RouteFilter { return ledger.RouteFilter{ Currency: currency, diff --git a/openmeter/ledger/customerbalance/facade_test.go b/openmeter/ledger/customerbalance/facade_test.go index 4834986dcd..8f96e9ffbd 100644 --- a/openmeter/ledger/customerbalance/facade_test.go +++ b/openmeter/ledger/customerbalance/facade_test.go @@ -2,12 +2,18 @@ package customerbalance import ( "testing" + "time" "github.com/alpacahq/alpacadecimal" + "github.com/samber/lo" "github.com/stretchr/testify/require" + "github.com/openmeterio/openmeter/openmeter/ledger" + ledgeraccount "github.com/openmeterio/openmeter/openmeter/ledger/account" "github.com/openmeterio/openmeter/openmeter/productcatalog" + "github.com/openmeterio/openmeter/pkg/clock" "github.com/openmeterio/openmeter/pkg/currencyx" + pagepagination "github.com/openmeterio/openmeter/pkg/pagination" ) func TestFacadeGetBalancesWithExplicitCurrencies(t *testing.T) { @@ -89,3 +95,47 @@ func TestFacadeGetBalancesWithUnsupportedExplicitCurrency(t *testing.T) { require.ErrorContains(t, err, "CUSTOM") require.ErrorContains(t, err, "not supported by ledger") } + +func TestFacadeGetBalanceAfterTransactionCursor(t *testing.T) { + env := newTestEnv(t) + facade, err := NewFacade(env.Service) + require.NoError(t, err) + + firstBookedAt := time.Date(2026, 4, 10, 9, 0, 0, 0, time.UTC) + secondBookedAt := firstBookedAt.Add(time.Minute) + + clock.SetTime(firstBookedAt) + defer clock.ResetTime() + env.bookFBOBalance(t, alpacadecimal.NewFromInt(100)) + + clock.SetTime(secondBookedAt) + env.bookFBOBalance(t, alpacadecimal.NewFromInt(20)) + + fboAccount, ok := env.CustomerAccounts.FBOAccount.(*ledgeraccount.CustomerFBOAccount) + require.True(t, ok) + + paged, err := env.Deps.HistoricalLedger.ListTransactionsByPage(t.Context(), ledger.ListTransactionsByPageInput{ + Page: pagepagination.NewPage(1, 10), + Namespace: env.Namespace, + AccountIDs: []string{fboAccount.ID().ID}, + Currency: &env.Currency, + }) + require.NoError(t, err) + require.Len(t, paged.Items, 2) + + olderTx := paged.Items[1] + balanceAfterOlderTx, err := facade.GetBalance(t.Context(), GetBalanceInput{ + CustomerID: env.CustomerID, + Currency: env.Currency, + After: lo.ToPtr(olderTx.Cursor()), + }) + require.NoError(t, err) + require.True(t, balanceAfterOlderTx.Equal(alpacadecimal.NewFromInt(100))) + + currentBalance, err := facade.GetBalance(t.Context(), GetBalanceInput{ + CustomerID: env.CustomerID, + Currency: env.Currency, + }) + require.NoError(t, err) + require.True(t, currentBalance.Equal(alpacadecimal.NewFromInt(120))) +} diff --git a/openmeter/ledger/customerbalance/noop.go b/openmeter/ledger/customerbalance/noop.go index a7628b3ae1..692c7351d6 100644 --- a/openmeter/ledger/customerbalance/noop.go +++ b/openmeter/ledger/customerbalance/noop.go @@ -24,7 +24,7 @@ type NoopService struct{} var _ FacadeService = NoopService{} -func (NoopService) GetBalance(context.Context, customer.CustomerID, ledger.RouteFilter) (ledger.Balance, error) { +func (NoopService) GetBalance(context.Context, customer.CustomerID, ledger.RouteFilter, *ledger.TransactionCursor) (ledger.Balance, error) { return noopBalance{}, nil } diff --git a/openmeter/ledger/customerbalance/service.go b/openmeter/ledger/customerbalance/service.go index ccdce62119..254ab25127 100644 --- a/openmeter/ledger/customerbalance/service.go +++ b/openmeter/ledger/customerbalance/service.go @@ -95,7 +95,7 @@ func New(config Config) (*Service, error) { }, nil } -func (s *Service) GetBalance(ctx context.Context, customerID customer.CustomerID, filters ledger.RouteFilter) (ledger.Balance, error) { +func (s *Service) GetBalance(ctx context.Context, customerID customer.CustomerID, filters ledger.RouteFilter, after *ledger.TransactionCursor) (ledger.Balance, error) { if err := s.validate(customerID, filters); err != nil { return nil, err } @@ -105,11 +105,13 @@ func (s *Service) GetBalance(ctx context.Context, customerID customer.CustomerID return nil, fmt.Errorf("get customer accounts: %w", err) } - bookedBalance, err := customerAccounts.FBOAccount.GetBalance(ctx, filters) + bookedBalance, err := customerAccounts.FBOAccount.GetBalance(ctx, filters, after) if err != nil { return nil, fmt.Errorf("get booked balance: %w", err) } + // Pending balance remains a current projection from open charges. + // Historical cursoring only affects the booked/settled side for now. impacts, err := s.getChargePendingBalanceImpacts(ctx, customerID, filters.Currency) if err != nil { return nil, fmt.Errorf("get charge pending balance impacts: %w", err) diff --git a/openmeter/ledger/customerbalance/service_test.go b/openmeter/ledger/customerbalance/service_test.go index 625fecc2af..ee763d82e3 100644 --- a/openmeter/ledger/customerbalance/service_test.go +++ b/openmeter/ledger/customerbalance/service_test.go @@ -113,7 +113,7 @@ func TestGetBalance(t *testing.T) { balance, err := env.Service.GetBalance(t.Context(), env.CustomerID, ledger.RouteFilter{ Currency: env.Currency, CreditPriority: &priority, - }) + }, nil) require.NoError(t, err) require.True(t, balance.Settled().Equal(alpacadecimal.NewFromInt(tt.wantSettled))) require.True(t, balance.Pending().Equal(alpacadecimal.NewFromInt(tt.wantPending))) @@ -133,7 +133,7 @@ func TestGetBalanceWithDifferentCurrency(t *testing.T) { usdBalance, err := env.Service.GetBalance(t.Context(), env.CustomerID, ledger.RouteFilter{ Currency: currencyx.Code("USD"), CreditPriority: &usdPriority, - }) + }, nil) require.NoError(t, err) require.True(t, usdBalance.Settled().Equal(alpacadecimal.NewFromInt(100))) require.True(t, usdBalance.Pending().Equal(alpacadecimal.NewFromInt(70))) @@ -142,7 +142,7 @@ func TestGetBalanceWithDifferentCurrency(t *testing.T) { eurBalance, err := env.Service.GetBalance(t.Context(), env.CustomerID, ledger.RouteFilter{ Currency: currencyx.Code("EUR"), CreditPriority: &eurPriority, - }) + }, nil) require.NoError(t, err) require.True(t, eurBalance.Settled().Equal(alpacadecimal.NewFromInt(200))) require.True(t, eurBalance.Pending().Equal(alpacadecimal.NewFromInt(130))) diff --git a/openmeter/ledger/historical/adapter/ledger.go b/openmeter/ledger/historical/adapter/ledger.go index 892896598a..d5685afd2b 100644 --- a/openmeter/ledger/historical/adapter/ledger.go +++ b/openmeter/ledger/historical/adapter/ledger.go @@ -5,17 +5,24 @@ import ( stdsql "database/sql" "fmt" + entsql "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqljson" "github.com/alpacahq/alpacadecimal" "github.com/samber/lo" "github.com/openmeterio/openmeter/openmeter/ent/db" + ledgeraccountdb "github.com/openmeterio/openmeter/openmeter/ent/db/ledgeraccount" ledgerentrydb "github.com/openmeterio/openmeter/openmeter/ent/db/ledgerentry" + ledgersubaccountdb "github.com/openmeterio/openmeter/openmeter/ent/db/ledgersubaccount" + ledgersubaccountroutedb "github.com/openmeterio/openmeter/openmeter/ent/db/ledgersubaccountroute" ledgertransactiondb "github.com/openmeterio/openmeter/openmeter/ent/db/ledgertransaction" ledgertransactiongroupdb "github.com/openmeterio/openmeter/openmeter/ent/db/ledgertransactiongroup" + "github.com/openmeterio/openmeter/openmeter/ent/db/predicate" "github.com/openmeterio/openmeter/openmeter/ledger" ledgerhistorical "github.com/openmeterio/openmeter/openmeter/ledger/historical" "github.com/openmeterio/openmeter/pkg/currencyx" "github.com/openmeterio/openmeter/pkg/models" + pagepagination "github.com/openmeterio/openmeter/pkg/pagination" "github.com/openmeterio/openmeter/pkg/pagination/v2" "github.com/openmeterio/openmeter/pkg/slicesx" ) @@ -268,6 +275,24 @@ func (r *repo) ListTransactions(ctx context.Context, input ledger.ListTransactio query = query.Where(ledgertransactiondb.ID(input.TransactionID.ID)) } + // Scope to specific accounts. + if len(input.AccountIDs) > 0 { + query = query.Where( + ledgertransactiondb.HasEntriesWith( + ledgerentrydb.HasSubAccountWith( + ledgersubaccountdb.AccountIDIn(input.AccountIDs...), + ), + ), + ) + } + + // Filter by annotation key-value matches. + for key, value := range input.AnnotationFilters { + query = query.Where(func(s *entsql.Selector) { + s.Where(sqljson.ValueEQ(ledgertransactiondb.FieldAnnotations, value, sqljson.Path(key))) + }) + } + if input.Limit > 0 { query = query.Limit(input.Limit) } @@ -294,3 +319,148 @@ func (r *repo) ListTransactions(ctx context.Context, input ledger.ListTransactio NextCursor: paged.NextCursor, }, nil } + +func (r *repo) ListTransactionsByPage(ctx context.Context, input ledger.ListTransactionsByPageInput) (pagepagination.Result[*ledgerhistorical.Transaction], error) { + entryPredicates := listTransactionsEntryPredicates(input) + + query := r.db.LedgerTransaction.Query(). + Where(ledgertransactiondb.Namespace(input.Namespace)). + WithEntries(func(q *db.LedgerEntryQuery) { + if len(entryPredicates) > 0 { + q.Where(entryPredicates...) + } + q.Order( + ledgerentrydb.ByCreatedAt(), + ledgerentrydb.ByID(), + ) + q.WithSubAccount(func(sq *db.LedgerSubAccountQuery) { + sq.WithAccount() + sq.WithRoute() + }) + }). + Order( + ledgertransactiondb.ByBookedAt(entsql.OrderDesc()), + ledgertransactiondb.ByCreatedAt(entsql.OrderDesc()), + ledgertransactiondb.ByID(entsql.OrderDesc()), + ) + + if len(input.AccountIDs) > 0 { + query = query.Where( + ledgertransactiondb.HasEntriesWith( + ledgerentrydb.HasSubAccountWith( + ledgersubaccountdb.AccountIDIn(input.AccountIDs...), + ), + ), + ) + } + + if input.Currency != nil { + query = query.Where( + ledgertransactiondb.HasEntriesWith( + ledgerentrydb.HasSubAccountWith( + ledgersubaccountdb.HasRouteWith( + ledgersubaccountroutedb.Currency(string(*input.Currency)), + ), + ), + ), + ) + } + + for key, value := range input.AnnotationFilters { + query = query.Where(func(s *entsql.Selector) { + s.Where(sqljson.ValueEQ(ledgertransactiondb.FieldAnnotations, value, sqljson.Path(key))) + }) + } + + if input.CreditMovement != ledger.ListTransactionsCreditMovementUnspecified { + pred, err := ledgerTransactionCreditMovementPredicate(input.AccountIDs, input.Currency, input.CreditMovement) + if err != nil { + return pagepagination.Result[*ledgerhistorical.Transaction]{}, err + } + if pred != nil { + query = query.Where(pred) + } + } + + paged, err := query.Paginate(ctx, input.Page) + if err != nil { + return pagepagination.Result[*ledgerhistorical.Transaction]{}, fmt.Errorf("list transactions by page: %w", err) + } + + items, err := slicesx.MapWithErr(paged.Items, func(tx *db.LedgerTransaction) (*ledgerhistorical.Transaction, error) { + return hydrateHistoricalTransaction(tx) + }) + if err != nil { + return pagepagination.Result[*ledgerhistorical.Transaction]{}, fmt.Errorf("hydrate listed transactions: %w", err) + } + + return pagepagination.Result[*ledgerhistorical.Transaction]{ + Page: paged.Page, + TotalCount: paged.TotalCount, + Items: items, + }, nil +} + +func listTransactionsEntryPredicates(input ledger.ListTransactionsByPageInput) []predicate.LedgerEntry { + entryPredicates := make([]predicate.LedgerEntry, 0, 2) + subAccountPredicates := make([]predicate.LedgerSubAccount, 0, 2) + + if len(input.AccountIDs) > 0 { + subAccountPredicates = append(subAccountPredicates, ledgersubaccountdb.AccountIDIn(input.AccountIDs...)) + } + + if input.Currency != nil { + subAccountPredicates = append(subAccountPredicates, + ledgersubaccountdb.HasRouteWith( + ledgersubaccountroutedb.Currency(string(*input.Currency)), + ), + ) + } + + if len(subAccountPredicates) > 0 { + entryPredicates = append(entryPredicates, ledgerentrydb.HasSubAccountWith(subAccountPredicates...)) + } + + return entryPredicates +} + +func ledgerTransactionCreditMovementPredicate( + accountIDs []string, + currency *currencyx.Code, + movement ledger.ListTransactionsCreditMovement, +) (predicate.LedgerTransaction, error) { + subAccountPredicates := []predicate.LedgerSubAccount{ + ledgersubaccountdb.HasAccountWith( + ledgeraccountdb.AccountType(ledger.AccountTypeCustomerFBO), + ), + } + + if len(accountIDs) > 0 { + subAccountPredicates = append(subAccountPredicates, ledgersubaccountdb.AccountIDIn(accountIDs...)) + } + + if currency != nil { + subAccountPredicates = append(subAccountPredicates, + ledgersubaccountdb.HasRouteWith( + ledgersubaccountroutedb.Currency(string(*currency)), + ), + ) + } + + entryPredicates := []predicate.LedgerEntry{ + ledgerentrydb.HasSubAccountWith(subAccountPredicates...), + } + + switch movement { + case ledger.ListTransactionsCreditMovementPositive: + entryPredicates = append(entryPredicates, ledgerentrydb.AmountGT(alpacadecimal.Zero)) + case ledger.ListTransactionsCreditMovementNegative: + entryPredicates = append(entryPredicates, ledgerentrydb.AmountLT(alpacadecimal.Zero)) + case ledger.ListTransactionsCreditMovementUnspecified: + return nil, nil + default: + return nil, fmt.Errorf("unsupported credit movement filter: %d", movement) + } + + return ledgertransactiondb.HasEntriesWith(entryPredicates...), nil +} diff --git a/openmeter/ledger/historical/adapter/ledger_test.go b/openmeter/ledger/historical/adapter/ledger_test.go index a43bf8ae70..5223a1d4fd 100644 --- a/openmeter/ledger/historical/adapter/ledger_test.go +++ b/openmeter/ledger/historical/adapter/ledger_test.go @@ -22,6 +22,7 @@ import ( "github.com/openmeterio/openmeter/openmeter/testutils" "github.com/openmeterio/openmeter/pkg/currencyx" "github.com/openmeterio/openmeter/pkg/models" + pagepagination "github.com/openmeterio/openmeter/pkg/pagination" "github.com/openmeterio/openmeter/pkg/timeutil" "github.com/openmeterio/openmeter/tools/migrate" ) @@ -224,6 +225,207 @@ func TestRepo_ListTransactions_PaginatesAndFilters(t *testing.T) { require.Equal(t, tx2.ID(), filtered.Items[0].ID()) } +func TestRepo_ListTransactionsByPage_FiltersCreditMovementByScopedFBOEntry(t *testing.T) { + env := NewTestEnv(t) + t.Cleanup(func() { + env.Close(t) + }) + env.DBSchemaMigrate(t) + + ctx := t.Context() + namespace := testNamespace() + usdSubAccount := env.createSubAccount(t, namespace, ledger.Route{Currency: currencyx.Code("USD")}) + eurSubAccount := env.createSubAccount(t, namespace, ledger.Route{Currency: currencyx.Code("EUR")}) + + group, err := env.repo.CreateTransactionGroup(ctx, ledgerhistorical.CreateTransactionGroupInput{ + Namespace: namespace, + }) + require.NoError(t, err) + + txInput := mustSetUpHistoricalTransactionInput(t, time.Now().UTC(), []*transactionstestutils.AnyEntryInput{ + { + Address: testAddress(t, usdSubAccount), + AmountValue: alpacadecimal.NewFromInt(-10), + }, + { + Address: testAddress(t, eurSubAccount), + AmountValue: alpacadecimal.NewFromInt(10), + }, + }) + tx, err := env.repo.BookTransaction(ctx, models.NamespacedID{Namespace: namespace, ID: group.ID}, txInput) + require.NoError(t, err) + + page := pagepagination.NewPage(1, 20) + usd := currencyx.Code("USD") + eur := currencyx.Code("EUR") + accountIDs := []string{usdSubAccount.AccountID, eurSubAccount.AccountID} + + usdConsumed, err := env.repo.ListTransactionsByPage(ctx, ledger.ListTransactionsByPageInput{ + Page: page, + Namespace: namespace, + AccountIDs: accountIDs, + Currency: &usd, + CreditMovement: ledger.ListTransactionsCreditMovementNegative, + }) + require.NoError(t, err) + require.Len(t, usdConsumed.Items, 1) + require.Equal(t, tx.ID(), usdConsumed.Items[0].ID()) + require.Len(t, usdConsumed.Items[0].Entries(), 1) + require.Equal(t, currencyx.Code("USD"), usdConsumed.Items[0].Entries()[0].PostingAddress().Route().Route().Currency) + + usdFunded, err := env.repo.ListTransactionsByPage(ctx, ledger.ListTransactionsByPageInput{ + Page: page, + Namespace: namespace, + AccountIDs: accountIDs, + Currency: &usd, + CreditMovement: ledger.ListTransactionsCreditMovementPositive, + }) + require.NoError(t, err) + require.Len(t, usdFunded.Items, 0) + + eurFunded, err := env.repo.ListTransactionsByPage(ctx, ledger.ListTransactionsByPageInput{ + Page: page, + Namespace: namespace, + AccountIDs: accountIDs, + Currency: &eur, + CreditMovement: ledger.ListTransactionsCreditMovementPositive, + }) + require.NoError(t, err) + require.Len(t, eurFunded.Items, 1) + require.Equal(t, tx.ID(), eurFunded.Items[0].ID()) + require.Len(t, eurFunded.Items[0].Entries(), 1) + require.Equal(t, currencyx.Code("EUR"), eurFunded.Items[0].Entries()[0].PostingAddress().Route().Route().Currency) + + eurConsumed, err := env.repo.ListTransactionsByPage(ctx, ledger.ListTransactionsByPageInput{ + Page: page, + Namespace: namespace, + AccountIDs: accountIDs, + Currency: &eur, + CreditMovement: ledger.ListTransactionsCreditMovementNegative, + }) + require.NoError(t, err) + require.Len(t, eurConsumed.Items, 0) +} + +func TestRepo_ListTransactionsByPage_PaginatesAndFiltersByAccountAndAnnotation(t *testing.T) { + env := NewTestEnv(t) + t.Cleanup(func() { + env.Close(t) + }) + env.DBSchemaMigrate(t) + + ctx := t.Context() + namespace := testNamespace() + usdSubAccountA := env.createSubAccount(t, namespace, ledger.Route{Currency: currencyx.Code("USD")}) + eurSubAccount := env.createSubAccount(t, namespace, ledger.Route{Currency: currencyx.Code("EUR")}) + usdSubAccountC := env.createSubAccount(t, namespace, ledger.Route{Currency: currencyx.Code("USD")}) + + group, err := env.repo.CreateTransactionGroup(ctx, ledgerhistorical.CreateTransactionGroupInput{ + Namespace: namespace, + }) + require.NoError(t, err) + + now := time.Now().UTC() + + txOld, err := env.repo.BookTransaction(ctx, models.NamespacedID{Namespace: namespace, ID: group.ID}, &transactionstestutils.AnyTransactionInput{ + BookedAtValue: now.Add(-2 * time.Hour), + AnnotationsValue: models.Annotations{ + "kind": "keep", + }, + EntryInputsValues: []*transactionstestutils.AnyEntryInput{ + { + Address: testAddress(t, usdSubAccountA), + AmountValue: alpacadecimal.NewFromInt(-10), + }, + { + Address: testAddress(t, eurSubAccount), + AmountValue: alpacadecimal.NewFromInt(10), + }, + }, + }) + require.NoError(t, err) + + _, err = env.repo.BookTransaction(ctx, models.NamespacedID{Namespace: namespace, ID: group.ID}, &transactionstestutils.AnyTransactionInput{ + BookedAtValue: now.Add(-90 * time.Minute), + AnnotationsValue: models.Annotations{ + "kind": "skip", + }, + EntryInputsValues: []*transactionstestutils.AnyEntryInput{ + { + Address: testAddress(t, usdSubAccountA), + AmountValue: alpacadecimal.NewFromInt(-15), + }, + { + Address: testAddress(t, eurSubAccount), + AmountValue: alpacadecimal.NewFromInt(15), + }, + }, + }) + require.NoError(t, err) + + _, err = env.repo.BookTransaction(ctx, models.NamespacedID{Namespace: namespace, ID: group.ID}, &transactionstestutils.AnyTransactionInput{ + BookedAtValue: now.Add(-1 * time.Hour), + AnnotationsValue: models.Annotations{ + "kind": "keep", + }, + EntryInputsValues: []*transactionstestutils.AnyEntryInput{ + { + Address: testAddress(t, usdSubAccountC), + AmountValue: alpacadecimal.NewFromInt(-20), + }, + { + Address: testAddress(t, eurSubAccount), + AmountValue: alpacadecimal.NewFromInt(20), + }, + }, + }) + require.NoError(t, err) + + txNew, err := env.repo.BookTransaction(ctx, models.NamespacedID{Namespace: namespace, ID: group.ID}, &transactionstestutils.AnyTransactionInput{ + BookedAtValue: now.Add(-30 * time.Minute), + AnnotationsValue: models.Annotations{ + "kind": "keep", + }, + EntryInputsValues: []*transactionstestutils.AnyEntryInput{ + { + Address: testAddress(t, usdSubAccountA), + AmountValue: alpacadecimal.NewFromInt(-30), + }, + { + Address: testAddress(t, eurSubAccount), + AmountValue: alpacadecimal.NewFromInt(30), + }, + }, + }) + require.NoError(t, err) + + page1, err := env.repo.ListTransactionsByPage(ctx, ledger.ListTransactionsByPageInput{ + Page: pagepagination.NewPage(1, 1), + Namespace: namespace, + AccountIDs: []string{usdSubAccountA.AccountID}, + AnnotationFilters: map[string]string{ + "kind": "keep", + }, + }) + require.NoError(t, err) + require.Equal(t, 2, page1.TotalCount) + require.Len(t, page1.Items, 1) + require.Equal(t, txNew.ID(), page1.Items[0].ID()) + + page2, err := env.repo.ListTransactionsByPage(ctx, ledger.ListTransactionsByPageInput{ + Page: pagepagination.NewPage(2, 1), + Namespace: namespace, + AccountIDs: []string{usdSubAccountA.AccountID}, + AnnotationFilters: map[string]string{ + "kind": "keep", + }, + }) + require.NoError(t, err) + require.Equal(t, 2, page2.TotalCount) + require.Len(t, page2.Items, 1) + require.Equal(t, txOld.ID(), page2.Items[0].ID()) +} + func TestRepo_SumEntries_Filters(t *testing.T) { env := NewTestEnv(t) t.Cleanup(func() { @@ -280,7 +482,7 @@ func TestRepo_SumEntries_Filters(t *testing.T) { AmountValue: alpacadecimal.NewFromInt(-50), }, }) - _, err = env.repo.BookTransaction(ctx, models.NamespacedID{Namespace: namespace, ID: group.ID}, txInputLate) + txLate, err := env.repo.BookTransaction(ctx, models.NamespacedID{Namespace: namespace, ID: group.ID}, txInputLate) require.NoError(t, err) txInputCostBasis := mustSetUpHistoricalTransactionInput(t, time.Now().UTC().Add(-15*time.Minute), []*transactionstestutils.AnyEntryInput{ @@ -293,7 +495,7 @@ func TestRepo_SumEntries_Filters(t *testing.T) { AmountValue: alpacadecimal.NewFromInt(-25), }, }) - _, err = env.repo.BookTransaction(ctx, models.NamespacedID{Namespace: namespace, ID: group.ID}, txInputCostBasis) + txCostBasis, err := env.repo.BookTransaction(ctx, models.NamespacedID{Namespace: namespace, ID: group.ID}, txInputCostBasis) require.NoError(t, err) // Sum by currency @@ -360,6 +562,36 @@ func TestRepo_SumEntries_Filters(t *testing.T) { }) require.NoError(t, err) require.True(t, sumCostBasis.Equal(alpacadecimal.NewFromInt(25))) + + sumAfterLate, err := env.repo.SumEntries(ctx, ledger.Query{ + Namespace: namespace, + Filters: ledger.Filters{ + After: lo.ToPtr(txLate.Cursor()), + Route: ledger.RouteFilter{Currency: currencyx.Code("USD")}, + }, + }) + require.NoError(t, err) + require.True(t, sumAfterLate.Equal(alpacadecimal.NewFromInt(50))) + + sumAfterCostBasis, err := env.repo.SumEntries(ctx, ledger.Query{ + Namespace: namespace, + Filters: ledger.Filters{ + After: lo.ToPtr(txCostBasis.Cursor()), + Route: ledger.RouteFilter{Currency: currencyx.Code("USD")}, + }, + }) + require.NoError(t, err) + require.True(t, sumAfterCostBasis.Equal(alpacadecimal.NewFromInt(75))) + + sumAfterEarly, err := env.repo.SumEntries(ctx, ledger.Query{ + Namespace: namespace, + Filters: ledger.Filters{ + After: lo.ToPtr(txEarly.Cursor()), + Route: ledger.RouteFilter{Currency: currencyx.Code("USD")}, + }, + }) + require.NoError(t, err) + require.True(t, sumAfterEarly.Equal(alpacadecimal.NewFromInt(0))) } func TestSumEntriesQuery_SQL(t *testing.T) { diff --git a/openmeter/ledger/historical/adapter/sumentries_query.go b/openmeter/ledger/historical/adapter/sumentries_query.go index 8b3a9181b0..ef9154d634 100644 --- a/openmeter/ledger/historical/adapter/sumentries_query.go +++ b/openmeter/ledger/historical/adapter/sumentries_query.go @@ -68,6 +68,25 @@ func (b *sumEntriesQuery) entryPredicates() ([]predicate.LedgerEntry, error) { } } + if b.query.Filters.After != nil { + after := b.query.Filters.After + entryPredicates = append(entryPredicates, ledgerentrydb.HasTransactionWith(func(s *sql.Selector) { + s.Where(sql.Or( + sql.LT(s.C(ledgertransactiondb.FieldBookedAt), after.BookedAt), + sql.And( + sql.EQ(s.C(ledgertransactiondb.FieldBookedAt), after.BookedAt), + sql.Or( + sql.LT(s.C(ledgertransactiondb.FieldCreatedAt), after.CreatedAt), + sql.And( + sql.EQ(s.C(ledgertransactiondb.FieldCreatedAt), after.CreatedAt), + sql.LTE(s.C(ledgertransactiondb.FieldID), after.ID.ID), + ), + ), + ), + )) + })) + } + subAccountPredicates, err := b.subAccountPredicates() if err != nil { return nil, err diff --git a/openmeter/ledger/historical/ledger.go b/openmeter/ledger/historical/ledger.go index 2a2f0a0b34..1615f8fb92 100644 --- a/openmeter/ledger/historical/ledger.go +++ b/openmeter/ledger/historical/ledger.go @@ -11,6 +11,7 @@ import ( "github.com/openmeterio/openmeter/pkg/framework/lockr" "github.com/openmeterio/openmeter/pkg/framework/transaction" "github.com/openmeterio/openmeter/pkg/models" + pagepagination "github.com/openmeterio/openmeter/pkg/pagination" "github.com/openmeterio/openmeter/pkg/pagination/v2" ) @@ -57,6 +58,21 @@ func (l *Ledger) ListTransactions(ctx context.Context, params ledger.ListTransac }, nil } +func (l *Ledger) ListTransactionsByPage(ctx context.Context, params ledger.ListTransactionsByPageInput) (pagepagination.Result[ledger.Transaction], error) { + res, err := l.repo.ListTransactionsByPage(ctx, params) + if err != nil { + return pagepagination.Result[ledger.Transaction]{}, fmt.Errorf("list transactions by page: %w", err) + } + + return pagepagination.Result[ledger.Transaction]{ + Page: res.Page, + TotalCount: res.TotalCount, + Items: lo.Map(res.Items, func(item *Transaction, _ int) ledger.Transaction { + return item + }), + }, nil +} + func (l *Ledger) GetTransactionGroup(ctx context.Context, id models.NamespacedID) (ledger.TransactionGroup, error) { group, err := l.repo.GetTransactionGroup(ctx, id) if err != nil { diff --git a/openmeter/ledger/historical/repo.go b/openmeter/ledger/historical/repo.go index 12b15732a0..750ab2b584 100644 --- a/openmeter/ledger/historical/repo.go +++ b/openmeter/ledger/historical/repo.go @@ -9,6 +9,7 @@ import ( "github.com/openmeterio/openmeter/openmeter/ledger" "github.com/openmeterio/openmeter/pkg/framework/entutils" "github.com/openmeterio/openmeter/pkg/models" + pagepagination "github.com/openmeterio/openmeter/pkg/pagination" "github.com/openmeterio/openmeter/pkg/pagination/v2" ) @@ -29,6 +30,9 @@ type Repo interface { // List transactions with pagination ListTransactions(ctx context.Context, input ledger.ListTransactionsInput) (pagination.Result[*Transaction], error) + + // ListTransactionsByPage lists transactions using page-based pagination. + ListTransactionsByPage(ctx context.Context, input ledger.ListTransactionsByPageInput) (pagepagination.Result[*Transaction], error) } // ---------------------------------------------------------------------------- diff --git a/openmeter/ledger/historical/transaction.go b/openmeter/ledger/historical/transaction.go index c10d566259..c68a582441 100644 --- a/openmeter/ledger/historical/transaction.go +++ b/openmeter/ledger/historical/transaction.go @@ -60,6 +60,17 @@ func (t *Transaction) Annotations() models.Annotations { return t.data.Annotations } +func (t *Transaction) Cursor() ledger.TransactionCursor { + return ledger.TransactionCursor{ + BookedAt: t.data.BookedAt, + CreatedAt: t.data.CreatedAt, + ID: models.NamespacedID{ + Namespace: t.data.Namespace, + ID: t.data.ID, + }, + } +} + type TransactionGroup struct { data TransactionGroupData transactions []*Transaction diff --git a/openmeter/ledger/ledger_test.go b/openmeter/ledger/ledger_test.go index 8ed4ffafd2..8d65fad31e 100644 --- a/openmeter/ledger/ledger_test.go +++ b/openmeter/ledger/ledger_test.go @@ -102,7 +102,7 @@ func TestGetAccountBalance(t *testing.T) { balance, err := acc.GetBalance(t.Context(), ledger.RouteFilter{ Currency: currencyx.Code("USD"), - }) + }, nil) require.NoError(t, err) require.NotNil(t, balance) } diff --git a/openmeter/ledger/noop/noop.go b/openmeter/ledger/noop/noop.go index 21442249d3..57a5b3eb10 100644 --- a/openmeter/ledger/noop/noop.go +++ b/openmeter/ledger/noop/noop.go @@ -12,6 +12,7 @@ import ( "github.com/openmeterio/openmeter/openmeter/namespace" "github.com/openmeterio/openmeter/pkg/currencyx" "github.com/openmeterio/openmeter/pkg/models" + pagepagination "github.com/openmeterio/openmeter/pkg/pagination" pagination "github.com/openmeterio/openmeter/pkg/pagination/v2" ) @@ -83,6 +84,10 @@ func (Ledger) ListTransactions(context.Context, ledger.ListTransactionsInput) (p return pagination.Result[ledger.Transaction]{}, nil } +func (Ledger) ListTransactionsByPage(context.Context, ledger.ListTransactionsByPageInput) (pagepagination.Result[ledger.Transaction], error) { + return pagepagination.Result[ledger.Transaction]{}, nil +} + func (Ledger) SumEntries(context.Context, ledger.Query) (ledger.QuerySummedResult, error) { return ledger.QuerySummedResult{}, nil } @@ -155,7 +160,7 @@ type customerAccount struct { accountType ledger.AccountType } -func (customerAccount) GetBalance(context.Context, ledger.RouteFilter) (ledger.Balance, error) { +func (customerAccount) GetBalance(context.Context, ledger.RouteFilter, *ledger.TransactionCursor) (ledger.Balance, error) { return balance{}, nil } @@ -191,7 +196,7 @@ type businessAccount struct { accountType ledger.AccountType } -func (businessAccount) GetBalance(context.Context, ledger.RouteFilter) (ledger.Balance, error) { +func (businessAccount) GetBalance(context.Context, ledger.RouteFilter, *ledger.TransactionCursor) (ledger.Balance, error) { return balance{}, nil } diff --git a/openmeter/ledger/primitives.go b/openmeter/ledger/primitives.go index 12c952354f..98b81b2508 100644 --- a/openmeter/ledger/primitives.go +++ b/openmeter/ledger/primitives.go @@ -2,6 +2,8 @@ package ledger import ( "context" + "errors" + "fmt" "time" "github.com/alpacahq/alpacadecimal" @@ -9,6 +11,7 @@ import ( "github.com/openmeterio/openmeter/pkg/currencyx" "github.com/openmeterio/openmeter/pkg/models" + pagepagination "github.com/openmeterio/openmeter/pkg/pagination" "github.com/openmeterio/openmeter/pkg/pagination/v2" ) @@ -65,7 +68,7 @@ type RouteFilter struct { // Accounts describe ownership and purpose while SubAccounts parameterize the actual posting address. type Account interface { // Balance can be queried across sub-accounts according to RouteFilter - GetBalance(ctx context.Context, query RouteFilter) (Balance, error) + GetBalance(ctx context.Context, query RouteFilter, after *TransactionCursor) (Balance, error) } // ---------------------------------------------------------------------------- @@ -91,12 +94,37 @@ type TransactionInput interface { // Transaction represents a list of entries booked at the same time type Transaction interface { + Cursor() TransactionCursor BookedAt() time.Time Entries() []Entry ID() models.NamespacedID Annotations() models.Annotations } +type TransactionCursor struct { + BookedAt time.Time + CreatedAt time.Time + ID models.NamespacedID +} + +func (c TransactionCursor) Validate() error { + var errs []error + + if c.BookedAt.IsZero() { + errs = append(errs, fmt.Errorf("booked_at is zero")) + } + + if c.CreatedAt.IsZero() { + errs = append(errs, fmt.Errorf("created_at is zero")) + } + + if err := c.ID.Validate(); err != nil { + errs = append(errs, fmt.Errorf("id is invalid: %w", err)) + } + + return errors.Join(errs...) +} + type TransactionGroupInput interface { Namespace() string Transactions() []TransactionInput @@ -125,6 +153,9 @@ type Ledger interface { // // TODO: Cursoring gets problematic due to diff between wall_clock and booked_at. It would be convenient to return in order of booked_at as that simplifies parsing. This API will likely change. ListTransactions(ctx context.Context, params ListTransactionsInput) (pagination.Result[Transaction], error) + + // ListTransactionsByPage lists transactions using page-based pagination. + ListTransactionsByPage(ctx context.Context, params ListTransactionsByPageInput) (pagepagination.Result[Transaction], error) } type ListTransactionsInput struct { @@ -133,6 +164,12 @@ type ListTransactionsInput struct { Limit int TransactionID *models.NamespacedID + + // AccountIDs scopes the query to transactions with entries on these accounts. + AccountIDs []string + + // AnnotationFilters matches transactions whose annotations contain all the given key-value pairs. + AnnotationFilters map[string]string } func (i ListTransactionsInput) Validate() error { @@ -155,3 +192,22 @@ func (i ListTransactionsInput) Validate() error { return nil } + +type ListTransactionsCreditMovement uint8 + +const ( + ListTransactionsCreditMovementUnspecified ListTransactionsCreditMovement = iota + ListTransactionsCreditMovementPositive + ListTransactionsCreditMovementNegative +) + +type ListTransactionsByPageInput struct { + pagepagination.Page + + Namespace string + AccountIDs []string + Currency *currencyx.Code + + CreditMovement ListTransactionsCreditMovement + AnnotationFilters map[string]string +} diff --git a/openmeter/ledger/query.go b/openmeter/ledger/query.go index 4ce75845f9..794b202e11 100644 --- a/openmeter/ledger/query.go +++ b/openmeter/ledger/query.go @@ -68,6 +68,16 @@ func (p Query) Validate() error { } } + if p.Filters.After != nil { + if err := p.Filters.After.Validate(); err != nil { + return ErrLedgerQueryInvalid.WithAttrs(models.Attributes{ + "reason": "after_invalid", + "after": p.Filters.After, + "error": err, + }) + } + } + if _, err := p.Filters.Route.Normalize(); err != nil { return ErrLedgerQueryInvalid.WithAttrs(models.Attributes{ "reason": "route_invalid", @@ -82,6 +92,7 @@ func (p Query) Validate() error { type Filters struct { // BookedAtPeriod is inclusive-exclusive... should it be? Maybe finally add period inclusivity params? BookedAtPeriod *timeutil.OpenPeriod + After *TransactionCursor TransactionID *string // AccountID narrows the query to a single account via its sub-accounts. AccountID *string diff --git a/openmeter/ledger/transactions/correction_test.go b/openmeter/ledger/transactions/correction_test.go index 8d394fadfb..92b96c633c 100644 --- a/openmeter/ledger/transactions/correction_test.go +++ b/openmeter/ledger/transactions/correction_test.go @@ -35,6 +35,14 @@ func (t *correctionTestTransaction) Annotations() models.Annotations { return t.annotations } +func (t *correctionTestTransaction) Cursor() ledger.TransactionCursor { + return ledger.TransactionCursor{ + BookedAt: t.bookedAt, + CreatedAt: t.bookedAt, + ID: t.id, + } +} + func TestCorrectTransactionRejectsCorrectionDirection(t *testing.T) { t.Parallel() diff --git a/openmeter/server/router/router.go b/openmeter/server/router/router.go index f1f7dbd225..09d36b91fe 100644 --- a/openmeter/server/router/router.go +++ b/openmeter/server/router/router.go @@ -39,6 +39,7 @@ import ( infohttpdriver "github.com/openmeterio/openmeter/openmeter/info/httpdriver" "github.com/openmeterio/openmeter/openmeter/ingest" ingesthttpdriver "github.com/openmeterio/openmeter/openmeter/ingest/httpdriver" + "github.com/openmeterio/openmeter/openmeter/ledger" "github.com/openmeterio/openmeter/openmeter/ledger/customerbalance" "github.com/openmeterio/openmeter/openmeter/llmcost" "github.com/openmeterio/openmeter/openmeter/meter" @@ -104,6 +105,8 @@ type Config struct { CurrencyService currencies.CurrencyService CostService cost.Service CreditGrantService creditgrant.Service + Ledger ledger.Ledger + AccountResolver ledger.AccountResolver Customer customer.Service CustomerBalanceFacade *customerbalance.Facade DebugConnector debug.DebugConnector diff --git a/openmeter/server/server.go b/openmeter/server/server.go index 2fbdf042c6..3842ffb1a4 100644 --- a/openmeter/server/server.go +++ b/openmeter/server/server.go @@ -114,6 +114,8 @@ func NewServer(config *Config) (*Server, error) { BillingService: config.RouterConfig.Billing, CustomerService: config.RouterConfig.Customer, CreditGrantService: config.RouterConfig.CreditGrantService, + Ledger: config.RouterConfig.Ledger, + AccountResolver: config.RouterConfig.AccountResolver, CustomerBalanceFacade: config.RouterConfig.CustomerBalanceFacade, CurrencyService: config.RouterConfig.CurrencyService, EntitlementService: config.RouterConfig.EntitlementConnector, diff --git a/test/credits/sanity_test.go b/test/credits/sanity_test.go index 08d08308ce..92769ab1b0 100644 --- a/test/credits/sanity_test.go +++ b/test/credits/sanity_test.go @@ -1258,7 +1258,7 @@ func (s *CreditsTestSuite) mustCustomerFBOBalanceWithPriority(customerID custome Currency: code, CostBasis: costBasis, CreditPriority: lo.ToPtr(priority), - }) + }, nil) s.NoError(err) return balance.Settled() @@ -1277,7 +1277,7 @@ func (s *CreditsTestSuite) mustCustomerReceivableBalance(customerID customer.Cus Currency: code, CostBasis: costBasis, TransactionAuthorizationStatus: lo.ToPtr(status), - }) + }, nil) s.NoError(err) return balance.Settled() @@ -1295,7 +1295,7 @@ func (s *CreditsTestSuite) mustCustomerAccruedBalance(customerID customer.Custom balance, err := customerAccounts.AccruedAccount.GetBalance(s.T().Context(), ledger.RouteFilter{ Currency: code, CostBasis: costBasis, - }) + }, nil) s.NoError(err) return balance.Settled() @@ -1313,7 +1313,7 @@ func (s *CreditsTestSuite) mustWashBalance(namespace string, code currencyx.Code balance, err := businessAccounts.WashAccount.GetBalance(s.T().Context(), ledger.RouteFilter{ Currency: code, CostBasis: costBasis, - }) + }, nil) s.NoError(err) return balance.Settled() @@ -1327,7 +1327,7 @@ func (s *CreditsTestSuite) mustEarningsBalance(namespace string, code currencyx. balance, err := businessAccounts.EarningsAccount.GetBalance(s.T().Context(), ledger.RouteFilter{ Currency: code, - }) + }, nil) s.NoError(err) return balance.Settled() From 5f0eed5a90e0909ce57fbe822b63abd1abe880d8 Mon Sep 17 00:00:00 2001 From: Alex Goth <64845621+GAlexIHU@users.noreply.github.com> Date: Thu, 16 Apr 2026 11:09:26 +0200 Subject: [PATCH 2/3] refactor(ledger): move logic to facade --- api/v3/handlers/customers/credits/convert.go | 102 +++++++ api/v3/handlers/customers/credits/handler.go | 1 + .../customers/credits/list_transactions.go | 276 +----------------- .../credits/list_transactions_test.go | 193 ++---------- api/v3/server/routes.go | 10 +- api/v3/server/server.go | 36 ++- app/common/customerbalance.go | 14 +- cmd/server/wire_gen.go | 2 +- openmeter/billing/creditgrant/noop.go | 28 ++ openmeter/ledger/chargeadapter/annotations.go | 12 +- openmeter/ledger/customerbalance/facade.go | 13 + openmeter/ledger/customerbalance/noop.go | 4 + openmeter/ledger/customerbalance/service.go | 7 + .../ledger/customerbalance/testenv_test.go | 1 + .../ledger/customerbalance/transactions.go | 261 +++++++++++++++++ .../customerbalance/transactions_test.go | 147 ++++++++++ openmeter/ledger/historical/ledger.go | 4 + openmeter/ledger/primitives.go | 12 + 18 files changed, 663 insertions(+), 460 deletions(-) create mode 100644 openmeter/billing/creditgrant/noop.go create mode 100644 openmeter/ledger/customerbalance/transactions.go create mode 100644 openmeter/ledger/customerbalance/transactions_test.go diff --git a/api/v3/handlers/customers/credits/convert.go b/api/v3/handlers/customers/credits/convert.go index 702badccd4..e987e04f17 100644 --- a/api/v3/handlers/customers/credits/convert.go +++ b/api/v3/handlers/customers/credits/convert.go @@ -13,6 +13,7 @@ import ( "github.com/openmeterio/openmeter/openmeter/billing/charges/models/payment" "github.com/openmeterio/openmeter/openmeter/billing/creditgrant" "github.com/openmeterio/openmeter/openmeter/ledger" + "github.com/openmeterio/openmeter/openmeter/ledger/customerbalance" "github.com/openmeterio/openmeter/openmeter/productcatalog" "github.com/openmeterio/openmeter/pkg/currencyx" "github.com/openmeterio/openmeter/pkg/models" @@ -307,3 +308,104 @@ func convertBalance(currency currencyx.Code, balance ledger.Balance) api.CreditB Pending: balance.Pending().String(), } } + +func fromAPIBillingCreditTransactionType(filter *api.BillingCreditTransactionType) *customerbalance.CreditTransactionType { + if filter == nil { + return nil + } + + var txType customerbalance.CreditTransactionType + switch *filter { + case api.BillingCreditTransactionTypeFunded: + txType = customerbalance.CreditTransactionTypeFunded + case api.BillingCreditTransactionTypeConsumed: + txType = customerbalance.CreditTransactionTypeConsumed + case api.BillingCreditTransactionTypeAdjusted: + txType = customerbalance.CreditTransactionTypeAdjusted + default: + return nil + } + + return &txType +} + +func toAPIBillingCreditTransactions(items []customerbalance.CreditTransaction) []api.BillingCreditTransaction { + out := make([]api.BillingCreditTransaction, 0, len(items)) + + for _, item := range items { + out = append(out, toAPIBillingCreditTransaction(item)) + } + + return out +} + +func toAPIBillingCreditTransaction(tx customerbalance.CreditTransaction) api.BillingCreditTransaction { + apiTx := api.BillingCreditTransaction{ + Id: tx.ID.ID, + CreatedAt: &tx.CreatedAt, + BookedAt: tx.BookedAt, + Type: toAPIBillingCreditTransactionType(tx.Type), + Currency: api.BillingCurrencyCode(tx.Currency), + Amount: tx.Amount.String(), + Name: tx.Name, + AvailableBalance: struct { + After api.Numeric `json:"after"` + Before api.Numeric `json:"before"` + }{ + Before: api.Numeric(tx.Balance.Before.String()), + After: api.Numeric(tx.Balance.After.String()), + }, + } + + labels := creditTransactionLabels(tx.Annotations) + if len(labels) > 0 { + apiLabels := api.Labels(labels) + apiTx.Labels = &apiLabels + } + + return apiTx +} + +func toAPIBillingCreditTransactionType(txType customerbalance.CreditTransactionType) api.BillingCreditTransactionType { + switch txType { + case customerbalance.CreditTransactionTypeFunded: + return api.BillingCreditTransactionTypeFunded + case customerbalance.CreditTransactionTypeConsumed: + return api.BillingCreditTransactionTypeConsumed + default: + return api.BillingCreditTransactionTypeAdjusted + } +} + +func creditTransactionLabels(annotations models.Annotations) map[string]string { + labels := make(map[string]string) + + setLabel := func(key, annotationKey string) { + value := stringAnnotation(annotations, annotationKey) + if value != "" { + labels[key] = value + } + } + + setLabel("charge_id", ledger.AnnotationChargeID) + setLabel("subscription_id", ledger.AnnotationSubscriptionID) + setLabel("subscription_phase_id", ledger.AnnotationSubscriptionPhaseID) + setLabel("subscription_item_id", ledger.AnnotationSubscriptionItemID) + setLabel("feature_id", ledger.AnnotationFeatureID) + + return labels +} + +func stringAnnotation(annotations models.Annotations, key string) string { + raw, ok := annotations[key] + if !ok { + return "" + } + + value, ok := raw.(string) + if !ok { + return "" + } + + return value +} diff --git a/api/v3/handlers/customers/credits/handler.go b/api/v3/handlers/customers/credits/handler.go index b478dbdb4e..befb35ec8c 100644 --- a/api/v3/handlers/customers/credits/handler.go +++ b/api/v3/handlers/customers/credits/handler.go @@ -15,6 +15,7 @@ import ( type customerBalanceFacade interface { GetBalance(ctx context.Context, input customerbalance.GetBalanceInput) (alpacadecimal.Decimal, error) GetBalances(ctx context.Context, input customerbalance.GetBalancesInput) ([]customerbalance.BalanceByCurrency, error) + ListCreditTransactions(ctx context.Context, input customerbalance.ListCreditTransactionsInput) (customerbalance.ListCreditTransactionsResult, error) } type Handler interface { diff --git a/api/v3/handlers/customers/credits/list_transactions.go b/api/v3/handlers/customers/credits/list_transactions.go index 763c9d8f07..8f695cb2de 100644 --- a/api/v3/handlers/customers/credits/list_transactions.go +++ b/api/v3/handlers/customers/credits/list_transactions.go @@ -5,15 +5,12 @@ import ( "fmt" "net/http" - "github.com/alpacahq/alpacadecimal" "github.com/samber/lo" api "github.com/openmeterio/openmeter/api/v3" "github.com/openmeterio/openmeter/api/v3/apierrors" "github.com/openmeterio/openmeter/api/v3/response" "github.com/openmeterio/openmeter/openmeter/customer" - "github.com/openmeterio/openmeter/openmeter/ledger" - ledgeraccount "github.com/openmeterio/openmeter/openmeter/ledger/account" "github.com/openmeterio/openmeter/openmeter/ledger/customerbalance" "github.com/openmeterio/openmeter/pkg/currencyx" "github.com/openmeterio/openmeter/pkg/framework/commonhttp" @@ -22,26 +19,13 @@ import ( ) type ( - ListCreditTransactionsRequest struct { - Namespace string - CustomerID string - Page pagination.Page - - TypeFilter *api.BillingCreditTransactionType - CurrencyFilter *currencyx.Code - } + ListCreditTransactionsRequest = customerbalance.ListCreditTransactionsInput ListCreditTransactionsResponse = response.PagePaginationResponse[api.BillingCreditTransaction] ListCreditTransactionsParams struct { CustomerID api.ULID Params api.ListCreditTransactionsParams } ListCreditTransactionsHandler httptransport.HandlerWithArgs[ListCreditTransactionsRequest, ListCreditTransactionsResponse, ListCreditTransactionsParams] - mappedCreditTransaction struct { - API api.BillingCreditTransaction - Amount alpacadecimal.Decimal - Currency currencyx.Code - Cursor ledger.TransactionCursor - } ) func (h *handler) ListCreditTransactions() ListCreditTransactionsHandler { @@ -70,69 +54,32 @@ func (h *handler) ListCreditTransactions() ListCreditTransactionsHandler { }) } - req := ListCreditTransactionsRequest{ - Namespace: ns, - CustomerID: args.CustomerID, - Page: page, + req := customerbalance.ListCreditTransactionsInput{ + CustomerID: customer.CustomerID{ + Namespace: ns, + ID: args.CustomerID, + }, + Page: page, } if args.Params.Filter != nil { - req.TypeFilter = args.Params.Filter.Type + req.Type = fromAPIBillingCreditTransactionType(args.Params.Filter.Type) if args.Params.Filter.Currency != nil { currency := currencyx.Code(*args.Params.Filter.Currency) - req.CurrencyFilter = ¤cy + req.Currency = ¤cy } } return req, nil }, func(ctx context.Context, request ListCreditTransactionsRequest) (ListCreditTransactionsResponse, error) { - creditMovement, empty := creditMovementFromTypeFilter(request.TypeFilter) - if empty { - return emptyCreditTransactionPage(request.Page), nil - } - - accountID, err := h.customerFBOAccountID(ctx, customer.CustomerID{ - Namespace: request.Namespace, - ID: request.CustomerID, - }) - if err != nil { - return ListCreditTransactionsResponse{}, fmt.Errorf("resolve customer FBO account: %w", err) - } - - if accountID == "" { - return emptyCreditTransactionPage(request.Page), nil - } - - listIn := ledger.ListTransactionsByPageInput{ - Page: request.Page, - Namespace: request.Namespace, - AccountIDs: []string{accountID}, - Currency: request.CurrencyFilter, - CreditMovement: creditMovement, - } - - result, err := h.ledger.ListTransactionsByPage(ctx, listIn) + result, err := h.balanceFacade.ListCreditTransactions(ctx, request) if err != nil { - return ListCreditTransactionsResponse{}, fmt.Errorf("list transactions: %w", err) + return ListCreditTransactionsResponse{}, fmt.Errorf("list credit transactions: %w", err) } - items, err := mapCreditTransactions(result.Items) - if err != nil { - return ListCreditTransactionsResponse{}, err - } - - if len(items) > 0 { - runningBalance, err := h.customerFBOBalance(ctx, request, items[0].Currency, &items[0].Cursor) - if err != nil { - return ListCreditTransactionsResponse{}, fmt.Errorf("get FBO balance after transaction %s: %w", items[0].Cursor.ID.ID, err) - } - - applyCreditTransactionBalances(items, runningBalance) - } - - return response.NewPagePaginationResponse(apiCreditTransactions(items), response.PageMetaPage{ + return response.NewPagePaginationResponse(toAPIBillingCreditTransactions(result.Items), response.PageMetaPage{ Size: request.Page.PageSize, Number: request.Page.PageNumber, Total: lo.ToPtr(result.TotalCount), @@ -146,202 +93,3 @@ func (h *handler) ListCreditTransactions() ListCreditTransactionsHandler { )..., ) } - -func emptyCreditTransactionPage(page pagination.Page) ListCreditTransactionsResponse { - return response.NewPagePaginationResponse([]api.BillingCreditTransaction{}, response.PageMetaPage{ - Size: page.PageSize, - Number: page.PageNumber, - Total: lo.ToPtr(0), - }) -} - -func creditMovementFromTypeFilter(filter *api.BillingCreditTransactionType) (ledger.ListTransactionsCreditMovement, bool) { - if filter == nil { - return ledger.ListTransactionsCreditMovementUnspecified, false - } - - switch *filter { - case api.BillingCreditTransactionTypeFunded: - return ledger.ListTransactionsCreditMovementPositive, false - case api.BillingCreditTransactionTypeConsumed: - return ledger.ListTransactionsCreditMovementNegative, false - case api.BillingCreditTransactionTypeAdjusted: - return ledger.ListTransactionsCreditMovementUnspecified, true - default: - return ledger.ListTransactionsCreditMovementUnspecified, false - } -} - -func (h *handler) customerFBOAccountID(ctx context.Context, customerID customer.CustomerID) (string, error) { - accounts, err := h.accountResolver.GetCustomerAccounts(ctx, customerID) - if err != nil { - return "", err - } - - return fboAccountIDFromCustomerAccounts(accounts), nil -} - -func fboAccountIDFromCustomerAccounts(accounts ledger.CustomerAccounts) string { - if fbo, ok := accounts.FBOAccount.(*ledgeraccount.CustomerFBOAccount); ok { - return fbo.ID().ID - } - - return "" -} - -func (h *handler) customerFBOBalance( - ctx context.Context, - req ListCreditTransactionsRequest, - currency currencyx.Code, - after *ledger.TransactionCursor, -) (alpacadecimal.Decimal, error) { - input := customerbalance.GetBalanceInput{ - CustomerID: customer.CustomerID{ - Namespace: req.Namespace, - ID: req.CustomerID, - }, - Currency: currency, - After: after, - } - - return h.balanceFacade.GetBalance(ctx, input) -} - -func applyCreditTransactionBalances(items []mappedCreditTransaction, after alpacadecimal.Decimal) { - runningBalance := after - - for i := range items { - items[i].API.AvailableBalance.After = runningBalance.String() - items[i].API.AvailableBalance.Before = runningBalance.Sub(items[i].Amount).String() - runningBalance = runningBalance.Sub(items[i].Amount) - } -} - -func mapCreditTransactions(txs []ledger.Transaction) ([]mappedCreditTransaction, error) { - items := make([]mappedCreditTransaction, 0, len(txs)) - - for _, tx := range txs { - item, err := mapCreditTransaction(tx) - if err != nil { - return nil, fmt.Errorf("convert ledger transaction %s: %w", tx.ID().ID, err) - } - - items = append(items, item) - } - - return items, nil -} - -func apiCreditTransactions(items []mappedCreditTransaction) []api.BillingCreditTransaction { - out := make([]api.BillingCreditTransaction, 0, len(items)) - for _, item := range items { - out = append(out, item.API) - } - - return out -} - -// mapCreditTransaction maps a ledger.Transaction to the API BillingCreditTransaction type plus its scoped FBO metadata. -func mapCreditTransaction(tx ledger.Transaction) (mappedCreditTransaction, error) { - entry, err := creditTransactionEntry(tx) - if err != nil { - return mappedCreditTransaction{}, err - } - - createdAt := tx.Cursor().CreatedAt - amount := entry.Amount() - currency := entry.PostingAddress().Route().Route().Currency - txType := creditTransactionType(amount) - - apiTx := api.BillingCreditTransaction{ - Id: tx.ID().ID, - CreatedAt: &createdAt, - BookedAt: tx.BookedAt(), - Type: txType, - Currency: api.BillingCurrencyCode(currency), - Amount: amount.String(), - Name: creditTransactionName(tx), - } - - labels := creditTransactionLabels(tx) - if len(labels) > 0 { - apiLabels := api.Labels(labels) - apiTx.Labels = &apiLabels - } - - return mappedCreditTransaction{ - API: apiTx, - Amount: amount, - Currency: currency, - Cursor: tx.Cursor(), - }, nil -} - -func creditTransactionEntry(tx ledger.Transaction) (ledger.Entry, error) { - for _, entry := range tx.Entries() { - if entry.PostingAddress().AccountType() != ledger.AccountTypeCustomerFBO { - continue - } - - return entry, nil - } - - return nil, fmt.Errorf("no customer FBO entry found in transaction %s", tx.ID().ID) -} - -// creditTransactionType determines the type based on the FBO impact sign. -// Positive = funded (balance went up), negative = consumed (balance went down). -func creditTransactionType(fboImpact alpacadecimal.Decimal) api.BillingCreditTransactionType { - if fboImpact.IsPositive() { - return api.BillingCreditTransactionTypeFunded - } - - if fboImpact.IsNegative() { - return api.BillingCreditTransactionTypeConsumed - } - - return api.BillingCreditTransactionTypeAdjusted -} - -func creditTransactionName(tx ledger.Transaction) string { - templateName, _ := ledger.TransactionTemplateNameFromAnnotations(tx.Annotations()) - if templateName != "" { - return templateName - } - - return "credit_transaction" -} - -func creditTransactionLabels(tx ledger.Transaction) map[string]string { - annotations := tx.Annotations() - labels := make(map[string]string) - - setLabel := func(key, annotationKey string) { - value := stringAnnotation(annotations, annotationKey) - if value != "" { - labels[key] = value - } - } - - setLabel("charge_id", ledger.AnnotationChargeID) - setLabel("subscription_id", ledger.AnnotationSubscriptionID) - setLabel("subscription_phase_id", ledger.AnnotationSubscriptionPhaseID) - setLabel("subscription_item_id", ledger.AnnotationSubscriptionItemID) - setLabel("feature_id", ledger.AnnotationFeatureID) - - return labels -} - -func stringAnnotation(annotations map[string]any, key string) string { - raw, ok := annotations[key] - if !ok { - return "" - } - - value, ok := raw.(string) - if !ok { - return "" - } - - return value -} diff --git a/api/v3/handlers/customers/credits/list_transactions_test.go b/api/v3/handlers/customers/credits/list_transactions_test.go index 17b280f074..68f041e874 100644 --- a/api/v3/handlers/customers/credits/list_transactions_test.go +++ b/api/v3/handlers/customers/credits/list_transactions_test.go @@ -1,7 +1,6 @@ package customerscredits import ( - "context" "testing" "time" @@ -10,186 +9,50 @@ import ( api "github.com/openmeterio/openmeter/api/v3" "github.com/openmeterio/openmeter/openmeter/ledger" - ledgeraccount "github.com/openmeterio/openmeter/openmeter/ledger/account" "github.com/openmeterio/openmeter/openmeter/ledger/customerbalance" - ledgerhistorical "github.com/openmeterio/openmeter/openmeter/ledger/historical" "github.com/openmeterio/openmeter/pkg/currencyx" "github.com/openmeterio/openmeter/pkg/models" ) -func TestCreditMovementFromTypeFilter_AdjustedReturnsEmpty(t *testing.T) { +func TestFromAPIBillingCreditTransactionType_Adjusted(t *testing.T) { filter := api.BillingCreditTransactionTypeAdjusted - movement, empty := creditMovementFromTypeFilter(&filter) + txType := fromAPIBillingCreditTransactionType(&filter) - require.Equal(t, ledger.ListTransactionsCreditMovementUnspecified, movement) - require.True(t, empty) + require.NotNil(t, txType) + require.Equal(t, customerbalance.CreditTransactionTypeAdjusted, *txType) } -func TestFBOAccountIDFromCustomerAccounts_ReturnsOnlyFBO(t *testing.T) { - fbo := mustCustomerFBOAccount(t, "ns", "fbo-account") - receivable := mustCustomerReceivableAccount(t, "ns", "receivable-account") - accrued := mustCustomerAccruedAccount(t, "ns", "accrued-account") +func TestToAPIBillingCreditTransaction(t *testing.T) { + createdAt := time.Date(2026, 4, 10, 9, 0, 0, 0, time.UTC) + bookedAt := createdAt.Add(time.Second) - accountID := fboAccountIDFromCustomerAccounts(ledger.CustomerAccounts{ - FBOAccount: fbo, - ReceivableAccount: receivable, - AccruedAccount: accrued, - }) - - require.Equal(t, "fbo-account", accountID) -} - -func TestCustomerFBOBalance_UsesCurrencyAndCursor(t *testing.T) { - usd := currencyx.Code("USD") - cursor := &ledger.TransactionCursor{ - BookedAt: time.Date(2026, 4, 10, 9, 0, 0, 0, time.UTC), - CreatedAt: time.Date(2026, 4, 10, 9, 0, 1, 0, time.UTC), + tx := toAPIBillingCreditTransaction(customerbalance.CreditTransaction{ ID: models.NamespacedID{ Namespace: "ns", ID: "tx-1", }, - } - facade := &capturingBalanceFacade{ - balance: alpacadecimal.NewFromInt(42), - } - h := handler{ - balanceFacade: facade, - } - - total, err := h.customerFBOBalance(t.Context(), ListCreditTransactionsRequest{ - Namespace: "ns", - CustomerID: "customer-1", - CurrencyFilter: &usd, - }, usd, cursor) - require.NoError(t, err) - require.True(t, total.Equal(alpacadecimal.NewFromInt(42))) - require.Equal(t, usd, facade.lastBalanceInput.Currency) - require.Equal(t, cursor, facade.lastBalanceInput.After) -} - -func TestMapCreditTransaction_UsesFBOEntry(t *testing.T) { - usd := currencyx.Code("USD") - tx := mustHistoricalTransaction(t, []ledgerhistorical.EntryData{ - mustEntryData(t, "entry-usd", ledger.AccountTypeCustomerFBO, usd, alpacadecimal.NewFromInt(-10)), - mustEntryData(t, "entry-accrued", ledger.AccountTypeCustomerAccrued, usd, alpacadecimal.NewFromInt(10)), - }) - - item, err := mapCreditTransaction(tx) - require.NoError(t, err) - require.Equal(t, api.BillingCreditTransactionTypeConsumed, item.API.Type) - require.Equal(t, api.BillingCurrencyCode("USD"), item.API.Currency) - require.Equal(t, api.Numeric("-10"), item.API.Amount) - require.True(t, item.Amount.Equal(alpacadecimal.NewFromInt(-10))) -} - -func TestApplyCreditTransactionBalances(t *testing.T) { - items := []mappedCreditTransaction{ - { - API: api.BillingCreditTransaction{ - Amount: api.Numeric("-10"), - }, - Amount: alpacadecimal.NewFromInt(-10), + CreatedAt: createdAt, + BookedAt: bookedAt, + Type: customerbalance.CreditTransactionTypeConsumed, + Currency: currencyx.Code("USD"), + Amount: alpacadecimal.NewFromInt(-10), + Balance: customerbalance.CreditTransactionBalance{ + Before: alpacadecimal.NewFromInt(52), + After: alpacadecimal.NewFromInt(42), }, - } - - applyCreditTransactionBalances(items, alpacadecimal.NewFromInt(42)) - - require.Equal(t, api.Numeric("42"), items[0].API.AvailableBalance.After) - require.Equal(t, api.Numeric("52"), items[0].API.AvailableBalance.Before) -} - -type capturingBalanceFacade struct { - lastBalanceInput customerbalance.GetBalanceInput - balance alpacadecimal.Decimal -} - -func (c *capturingBalanceFacade) GetBalance(_ context.Context, input customerbalance.GetBalanceInput) (alpacadecimal.Decimal, error) { - c.lastBalanceInput = input - return c.balance, nil -} - -func (c *capturingBalanceFacade) GetBalances(_ context.Context, _ customerbalance.GetBalancesInput) ([]customerbalance.BalanceByCurrency, error) { - return nil, nil -} - -func mustCustomerFBOAccount(t *testing.T, namespace, id string) *ledgeraccount.CustomerFBOAccount { - t.Helper() - - account := mustAccount(t, namespace, id, ledger.AccountTypeCustomerFBO) - fbo, err := account.AsCustomerFBOAccount() - require.NoError(t, err) - - return fbo -} - -func mustCustomerReceivableAccount(t *testing.T, namespace, id string) *ledgeraccount.CustomerReceivableAccount { - t.Helper() - - account := mustAccount(t, namespace, id, ledger.AccountTypeCustomerReceivable) - receivable, err := account.AsCustomerReceivableAccount() - require.NoError(t, err) - - return receivable -} - -func mustCustomerAccruedAccount(t *testing.T, namespace, id string) *ledgeraccount.CustomerAccruedAccount { - t.Helper() - - account := mustAccount(t, namespace, id, ledger.AccountTypeCustomerAccrued) - accrued, err := account.AsCustomerAccruedAccount() - require.NoError(t, err) - - return accrued -} - -func mustAccount(t *testing.T, namespace, id string, accountType ledger.AccountType) *ledgeraccount.Account { - t.Helper() - - account, err := ledgeraccount.NewAccountFromData(ledgeraccount.AccountData{ - ID: models.NamespacedID{ - Namespace: namespace, - ID: id, + Name: "credit_transaction", + Annotations: models.Annotations{ + ledger.AnnotationChargeID: "charge-1", }, - AccountType: accountType, - }, ledgeraccount.AccountLiveServices{}) - require.NoError(t, err) - - return account -} - -func mustHistoricalTransaction(t *testing.T, entries []ledgerhistorical.EntryData) ledger.Transaction { - t.Helper() - - tx, err := ledgerhistorical.NewTransactionFromData(ledgerhistorical.TransactionData{ - ID: "tx-1", - Namespace: "ns", - CreatedAt: time.Now().UTC(), - BookedAt: time.Now().UTC(), - }, entries) - require.NoError(t, err) - - return tx -} - -func mustEntryData(t *testing.T, id string, accountType ledger.AccountType, currency currencyx.Code, amount alpacadecimal.Decimal) ledgerhistorical.EntryData { - t.Helper() - - route := ledger.Route{Currency: currency} - key, err := ledger.BuildRoutingKey(ledger.RoutingKeyVersionV1, route) - require.NoError(t, err) + }) - return ledgerhistorical.EntryData{ - ID: id, - Namespace: "ns", - CreatedAt: time.Now().UTC(), - SubAccountID: id + "-subaccount", - AccountType: accountType, - Route: route, - RouteID: id + "-route", - RouteKey: key.Value(), - RouteKeyVer: key.Version(), - Amount: amount, - TransactionID: "tx-1", - } + require.Equal(t, api.ULID("tx-1"), tx.Id) + require.Equal(t, api.BillingCreditTransactionTypeConsumed, tx.Type) + require.Equal(t, api.BillingCurrencyCode("USD"), tx.Currency) + require.Equal(t, api.Numeric("-10"), tx.Amount) + require.Equal(t, api.Numeric("52"), tx.AvailableBalance.Before) + require.Equal(t, api.Numeric("42"), tx.AvailableBalance.After) + require.NotNil(t, tx.Labels) + require.Equal(t, "charge-1", (*tx.Labels)["charge_id"]) } diff --git a/api/v3/server/routes.go b/api/v3/server/routes.go index 85c7514ad2..372ce1de12 100644 --- a/api/v3/server/routes.go +++ b/api/v3/server/routes.go @@ -323,7 +323,7 @@ func (s *Server) DeletePlanAddon(w http.ResponseWriter, r *http.Request, planId var unimplemented = api.Unimplemented{} func (s *Server) GetCustomerCreditBalance(w http.ResponseWriter, r *http.Request, customerId api.ULID, params api.GetCustomerCreditBalanceParams) { - if s.customersCreditsHandler == nil { + if !s.Credits.Enabled || s.customersCreditsHandler == nil { unimplemented.GetCustomerCreditBalance(w, r, customerId, params) return } @@ -335,7 +335,7 @@ func (s *Server) GetCustomerCreditBalance(w http.ResponseWriter, r *http.Request } func (s *Server) ListCreditGrants(w http.ResponseWriter, r *http.Request, customerId api.ULID, params api.ListCreditGrantsParams) { - if s.customersCreditsHandler == nil || s.CreditGrantService == nil { + if !s.Credits.Enabled || s.customersCreditsHandler == nil || s.CreditGrantService == nil { unimplemented.ListCreditGrants(w, r, customerId, params) return } @@ -347,7 +347,7 @@ func (s *Server) ListCreditGrants(w http.ResponseWriter, r *http.Request, custom } func (s *Server) CreateCreditGrant(w http.ResponseWriter, r *http.Request, customerId api.ULID) { - if s.customersCreditsHandler == nil || s.CreditGrantService == nil { + if !s.Credits.Enabled || s.customersCreditsHandler == nil || s.CreditGrantService == nil { unimplemented.CreateCreditGrant(w, r, customerId) return } @@ -358,7 +358,7 @@ func (s *Server) CreateCreditGrant(w http.ResponseWriter, r *http.Request, custo } func (s *Server) GetCreditGrant(w http.ResponseWriter, r *http.Request, customerId api.ULID, creditGrantId api.ULID) { - if s.customersCreditsHandler == nil || s.CreditGrantService == nil { + if !s.Credits.Enabled || s.customersCreditsHandler == nil || s.CreditGrantService == nil { unimplemented.GetCreditGrant(w, r, customerId, creditGrantId) return } @@ -378,7 +378,7 @@ func (s *Server) UpdateCreditGrantExternalSettlement(w http.ResponseWriter, r *h } func (s *Server) ListCreditTransactions(w http.ResponseWriter, r *http.Request, customerId api.ULID, params api.ListCreditTransactionsParams) { - if s.customersCreditsHandler == nil || s.Ledger == nil { + if !s.Credits.Enabled || s.customersCreditsHandler == nil || s.Ledger == nil { unimplemented.ListCreditTransactions(w, r, customerId, params) return } diff --git a/api/v3/server/server.go b/api/v3/server/server.go index b4131e97b4..f4f98fdcbc 100644 --- a/api/v3/server/server.go +++ b/api/v3/server/server.go @@ -42,6 +42,7 @@ import ( "github.com/openmeterio/openmeter/openmeter/ingest" "github.com/openmeterio/openmeter/openmeter/ledger" "github.com/openmeterio/openmeter/openmeter/ledger/customerbalance" + ledgernoop "github.com/openmeterio/openmeter/openmeter/ledger/noop" "github.com/openmeterio/openmeter/openmeter/llmcost" "github.com/openmeterio/openmeter/openmeter/meter" "github.com/openmeterio/openmeter/openmeter/namespace/namespacedriver" @@ -158,6 +159,24 @@ func (c *Config) Validate() error { errs = append(errs, errors.New("feature connector is required")) } + if c.Credits.Enabled { + if c.CustomerBalanceFacade == nil { + errs = append(errs, errors.New("customer balance facade is required when credits are enabled")) + } + + if c.CreditGrantService == nil { + errs = append(errs, errors.New("credit grant service is required when credits are enabled")) + } + + if c.Ledger == nil { + errs = append(errs, errors.New("ledger is required when credits are enabled")) + } + + if c.AccountResolver == nil { + errs = append(errs, errors.New("account resolver is required when credits are enabled")) + } + } + return errors.Join(errs...) } @@ -218,10 +237,21 @@ func NewServer(config *Config) (*Server, error) { eventsHandler := eventshandler.New(resolveNamespace, config.IngestService, httptransport.WithErrorHandler(config.ErrorHandler)) customersHandler := customershandler.New(resolveNamespace, config.CustomerService, httptransport.WithErrorHandler(config.ErrorHandler)) customersBillingHandler := customersbillinghandler.New(resolveNamespace, config.BillingService, config.CustomerService, config.StripeService, httptransport.WithErrorHandler(config.ErrorHandler)) - var customersCreditsHandler customerscreditshandler.Handler - if config.CustomerBalanceFacade != nil && config.Credits.Enabled { - customersCreditsHandler = customerscreditshandler.New(resolveNamespace, config.CustomerService, config.CustomerBalanceFacade, config.CreditGrantService, config.Ledger, config.AccountResolver, httptransport.WithErrorHandler(config.ErrorHandler)) + customerBalanceFacade := config.CustomerBalanceFacade + creditGrantService := config.CreditGrantService + ledgerService := config.Ledger + accountResolver := config.AccountResolver + if !config.Credits.Enabled { + customerBalanceFacade, err = customerbalance.NewFacade(customerbalance.NewNoopService()) + if err != nil { + return nil, fmt.Errorf("create noop customer balance facade: %w", err) + } + + creditGrantService = creditgrant.NewNoopService() + ledgerService = ledgernoop.Ledger{} + accountResolver = ledgernoop.AccountResolver{} } + customersCreditsHandler := customerscreditshandler.New(resolveNamespace, config.CustomerService, customerBalanceFacade, creditGrantService, ledgerService, accountResolver, httptransport.WithErrorHandler(config.ErrorHandler)) customersEntitlementHandler := customersentitlementhandler.New(resolveNamespace, config.CustomerService, config.EntitlementService, httptransport.WithErrorHandler(config.ErrorHandler)) metersHandler := metershandler.New(resolveNamespace, config.MeterService, config.StreamingConnector, config.CustomerService, httptransport.WithErrorHandler(config.ErrorHandler)) subscriptionsHandler := subscriptionshandler.New(resolveNamespace, config.CustomerService, config.PlanService, config.PlanSubscriptionService, config.SubscriptionService, httptransport.WithErrorHandler(config.ErrorHandler)) diff --git a/app/common/customerbalance.go b/app/common/customerbalance.go index 0558a5d079..fc1942885e 100644 --- a/app/common/customerbalance.go +++ b/app/common/customerbalance.go @@ -1,19 +1,12 @@ package common import ( - "log/slog" - "github.com/google/wire" "github.com/openmeterio/openmeter/app/config" - "github.com/openmeterio/openmeter/openmeter/billing/rating" - entdb "github.com/openmeterio/openmeter/openmeter/ent/db" "github.com/openmeterio/openmeter/openmeter/ledger" ledgeraccount "github.com/openmeterio/openmeter/openmeter/ledger/account" "github.com/openmeterio/openmeter/openmeter/ledger/customerbalance" - "github.com/openmeterio/openmeter/openmeter/productcatalog/feature" - "github.com/openmeterio/openmeter/openmeter/streaming" - "github.com/openmeterio/openmeter/pkg/framework/lockr" ) var CustomerBalance = wire.NewSet( @@ -23,16 +16,10 @@ var CustomerBalance = wire.NewSet( func NewCustomerBalanceService( creditsConfig config.CreditsConfiguration, - logger *slog.Logger, - db *entdb.Client, - locker *lockr.Locker, historicalLedger ledger.Ledger, accountResolver ledger.AccountResolver, accountService ledgeraccount.Service, billingRegistry BillingRegistry, - featureConnector feature.FeatureConnector, - ratingService rating.Service, - streamingConnector streaming.Connector, ) (customerbalance.FacadeService, error) { if !creditsConfig.Enabled { return customerbalance.NewNoopService(), nil @@ -43,6 +30,7 @@ func NewCustomerBalanceService( SubAccountService: accountService, ChargesService: billingRegistry.Charges.Service, UsageBasedService: billingRegistry.Charges.UsageBasedService, + Ledger: historicalLedger, }) } diff --git a/cmd/server/wire_gen.go b/cmd/server/wire_gen.go index 12ac68e152..2aa57b725b 100644 --- a/cmd/server/wire_gen.go +++ b/cmd/server/wire_gen.go @@ -487,7 +487,7 @@ func initializeApplication(ctx context.Context, conf config.Configuration) (Appl cleanup() return Application{}, nil, err } - facadeService, err := common.NewCustomerBalanceService(creditsConfiguration, logger, client, locker, ledger, accountResolver, accountService, billingRegistry, featureConnector, ratingService, connector) + facadeService, err := common.NewCustomerBalanceService(creditsConfiguration, ledger, accountResolver, accountService, billingRegistry) if err != nil { cleanup7() cleanup6() diff --git a/openmeter/billing/creditgrant/noop.go b/openmeter/billing/creditgrant/noop.go new file mode 100644 index 0000000000..99ce2e9462 --- /dev/null +++ b/openmeter/billing/creditgrant/noop.go @@ -0,0 +1,28 @@ +package creditgrant + +import ( + "context" + + "github.com/openmeterio/openmeter/openmeter/billing/charges/creditpurchase" + "github.com/openmeterio/openmeter/pkg/pagination" +) + +type NoopService struct{} + +var _ Service = NoopService{} + +func NewNoopService() Service { + return NoopService{} +} + +func (NoopService) Create(context.Context, CreateInput) (creditpurchase.Charge, error) { + return creditpurchase.Charge{}, nil +} + +func (NoopService) Get(context.Context, GetInput) (creditpurchase.Charge, error) { + return creditpurchase.Charge{}, nil +} + +func (NoopService) List(context.Context, ListInput) (pagination.Result[creditpurchase.Charge], error) { + return pagination.Result[creditpurchase.Charge]{}, nil +} diff --git a/openmeter/ledger/chargeadapter/annotations.go b/openmeter/ledger/chargeadapter/annotations.go index 4fcb01f267..d546a8dbee 100644 --- a/openmeter/ledger/chargeadapter/annotations.go +++ b/openmeter/ledger/chargeadapter/annotations.go @@ -1,6 +1,8 @@ package chargeadapter import ( + "github.com/samber/lo" + chargecreditpurchase "github.com/openmeterio/openmeter/openmeter/billing/charges/creditpurchase" chargeflatfee "github.com/openmeterio/openmeter/openmeter/billing/charges/flatfee" "github.com/openmeterio/openmeter/openmeter/billing/charges/meta" @@ -38,7 +40,7 @@ func chargeAnnotationsForUsageBasedCharge(charge chargeusagebased.Charge) models ID: charge.ID, }, charge.Intent.Subscription, - ptrIfNotEmpty(charge.State.FeatureID), + lo.EmptyableToPtr(charge.State.FeatureID), ) } @@ -61,11 +63,3 @@ func chargeTransactionAnnotations(chargeID models.NamespacedID, subscription *me FeatureID: featureID, }) } - -func ptrIfNotEmpty(value string) *string { - if value == "" { - return nil - } - - return &value -} diff --git a/openmeter/ledger/customerbalance/facade.go b/openmeter/ledger/customerbalance/facade.go index dbbb296f0c..217a9a07ae 100644 --- a/openmeter/ledger/customerbalance/facade.go +++ b/openmeter/ledger/customerbalance/facade.go @@ -78,6 +78,7 @@ type BalanceByCurrency struct { type FacadeService interface { GetBalance(ctx context.Context, customerID customer.CustomerID, filters ledger.RouteFilter, after *ledger.TransactionCursor) (ledger.Balance, error) + ListCreditTransactions(ctx context.Context, input ListCreditTransactionsInput) (ListCreditTransactionsResult, error) getFBOCurrencies(ctx context.Context, customerID customer.CustomerID) ([]currencyx.Code, error) } @@ -155,6 +156,18 @@ func (f *Facade) GetBalance(ctx context.Context, input GetBalanceInput) (alpacad return balance.Settled(), nil } +func (f *Facade) ListCreditTransactions(ctx context.Context, input ListCreditTransactionsInput) (ListCreditTransactionsResult, error) { + if f == nil { + return ListCreditTransactionsResult{}, errors.New("facade is required") + } + + if err := input.Validate(); err != nil { + return ListCreditTransactionsResult{}, err + } + + return f.service.ListCreditTransactions(ctx, input) +} + func routeFilter(currency currencyx.Code) ledger.RouteFilter { return ledger.RouteFilter{ Currency: currency, diff --git a/openmeter/ledger/customerbalance/noop.go b/openmeter/ledger/customerbalance/noop.go index 692c7351d6..c69b034e26 100644 --- a/openmeter/ledger/customerbalance/noop.go +++ b/openmeter/ledger/customerbalance/noop.go @@ -28,6 +28,10 @@ func (NoopService) GetBalance(context.Context, customer.CustomerID, ledger.Route return noopBalance{}, nil } +func (NoopService) ListCreditTransactions(context.Context, ListCreditTransactionsInput) (ListCreditTransactionsResult, error) { + return ListCreditTransactionsResult{}, nil +} + func (NoopService) getFBOCurrencies(context.Context, customer.CustomerID) ([]currencyx.Code, error) { return nil, nil } diff --git a/openmeter/ledger/customerbalance/service.go b/openmeter/ledger/customerbalance/service.go index 254ab25127..8bcc5e5a33 100644 --- a/openmeter/ledger/customerbalance/service.go +++ b/openmeter/ledger/customerbalance/service.go @@ -48,6 +48,7 @@ type Service struct { SubAccountService subAccountLister ChargesService chargesService UsageBasedService usageBasedTotalsService + Ledger ledger.Ledger balanceCalculator chargePendingBalanceCalculator } @@ -57,6 +58,7 @@ type Config struct { SubAccountService subAccountLister ChargesService chargesService UsageBasedService usageBasedTotalsService + Ledger ledger.Ledger } func (c Config) Validate() error { @@ -78,6 +80,10 @@ func (c Config) Validate() error { errs = append(errs, errors.New("usage based service is required")) } + if c.Ledger == nil { + errs = append(errs, errors.New("ledger is required")) + } + return errors.Join(errs...) } @@ -91,6 +97,7 @@ func New(config Config) (*Service, error) { SubAccountService: config.SubAccountService, ChargesService: config.ChargesService, UsageBasedService: config.UsageBasedService, + Ledger: config.Ledger, balanceCalculator: chargePendingBalanceCalculator{}, }, nil } diff --git a/openmeter/ledger/customerbalance/testenv_test.go b/openmeter/ledger/customerbalance/testenv_test.go index a249bb79a4..717b3003d8 100644 --- a/openmeter/ledger/customerbalance/testenv_test.go +++ b/openmeter/ledger/customerbalance/testenv_test.go @@ -211,6 +211,7 @@ func newTestEnv(t *testing.T) *testEnv { usageBasedService: usageService, }, UsageBasedService: usageService, + Ledger: base.Deps.HistoricalLedger, }) require.NoError(t, err) diff --git a/openmeter/ledger/customerbalance/transactions.go b/openmeter/ledger/customerbalance/transactions.go new file mode 100644 index 0000000000..e9dbdaf9a3 --- /dev/null +++ b/openmeter/ledger/customerbalance/transactions.go @@ -0,0 +1,261 @@ +package customerbalance + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/alpacahq/alpacadecimal" + "github.com/samber/lo" + + "github.com/openmeterio/openmeter/openmeter/customer" + "github.com/openmeterio/openmeter/openmeter/ledger" + ledgeraccount "github.com/openmeterio/openmeter/openmeter/ledger/account" + "github.com/openmeterio/openmeter/pkg/currencyx" + "github.com/openmeterio/openmeter/pkg/models" + pagepagination "github.com/openmeterio/openmeter/pkg/pagination" +) + +type CreditTransactionType string + +const ( + CreditTransactionTypeFunded CreditTransactionType = "funded" + CreditTransactionTypeConsumed CreditTransactionType = "consumed" + CreditTransactionTypeAdjusted CreditTransactionType = "adjusted" +) + +func (t CreditTransactionType) Validate() error { + switch t { + case CreditTransactionTypeFunded, CreditTransactionTypeConsumed, CreditTransactionTypeAdjusted: + return nil + default: + return fmt.Errorf("invalid credit transaction type: %s", t) + } +} + +type ListCreditTransactionsInput struct { + CustomerID customer.CustomerID + Page pagepagination.Page + + Type *CreditTransactionType + Currency *currencyx.Code +} + +func (i ListCreditTransactionsInput) Validate() error { + var errs []error + + if err := i.CustomerID.Validate(); err != nil { + errs = append(errs, fmt.Errorf("customer ID: %w", err)) + } + + if err := i.Page.Validate(); err != nil { + errs = append(errs, fmt.Errorf("page: %w", err)) + } + + if i.Type != nil { + if err := i.Type.Validate(); err != nil { + errs = append(errs, fmt.Errorf("type: %w", err)) + } + } + + if i.Currency != nil { + if err := i.Currency.Validate(); err != nil { + errs = append(errs, fmt.Errorf("currency: %w", err)) + } + } + + return models.NewNillableGenericValidationError(errors.Join(errs...)) +} + +type CreditTransaction struct { + ID models.NamespacedID + CreatedAt time.Time + BookedAt time.Time + Type CreditTransactionType + Currency currencyx.Code + Amount alpacadecimal.Decimal + Balance CreditTransactionBalance + Name string + Annotations models.Annotations +} + +type CreditTransactionBalance struct { + Before alpacadecimal.Decimal + After alpacadecimal.Decimal +} + +type ListCreditTransactionsResult = pagepagination.Result[CreditTransaction] + +func (s *Service) ListCreditTransactions(ctx context.Context, input ListCreditTransactionsInput) (ListCreditTransactionsResult, error) { + if err := input.Validate(); err != nil { + return ListCreditTransactionsResult{}, err + } + + creditMovement, empty, err := ledgerCreditMovement(input.Type) + if err != nil { + return ListCreditTransactionsResult{}, err + } + if empty { + return emptyCreditTransactions(input.Page), nil + } + + accountID, err := s.customerFBOAccountID(ctx, input.CustomerID) + if err != nil { + return ListCreditTransactionsResult{}, fmt.Errorf("resolve customer FBO account: %w", err) + } + if accountID == "" { + return emptyCreditTransactions(input.Page), nil + } + + result, err := s.Ledger.ListTransactionsByPage(ctx, ledger.ListTransactionsByPageInput{ + Page: input.Page, + Namespace: input.CustomerID.Namespace, + AccountIDs: []string{accountID}, + Currency: input.Currency, + CreditMovement: creditMovement, + }) + if err != nil { + return ListCreditTransactionsResult{}, fmt.Errorf("list ledger transactions: %w", err) + } + + items, err := creditTransactionsFromLedgerTransactions(result.Items) + if err != nil { + return ListCreditTransactionsResult{}, err + } + + if len(items) > 0 { + runningBalance, err := s.GetBalance(ctx, input.CustomerID, routeFilter(items[0].Currency), lo.ToPtr(result.Items[0].Cursor())) + if err != nil { + return ListCreditTransactionsResult{}, fmt.Errorf("get FBO balance after transaction %s: %w", result.Items[0].ID().ID, err) + } + + applyCreditTransactionBalances(items, runningBalance.Settled()) + } + + return ListCreditTransactionsResult{ + Page: result.Page, + TotalCount: result.TotalCount, + Items: items, + }, nil +} + +func emptyCreditTransactions(page pagepagination.Page) ListCreditTransactionsResult { + return ListCreditTransactionsResult{ + Page: page, + TotalCount: 0, + Items: []CreditTransaction{}, + } +} + +func ledgerCreditMovement(txType *CreditTransactionType) (ledger.ListTransactionsCreditMovement, bool, error) { + if txType == nil { + return ledger.ListTransactionsCreditMovementUnspecified, false, nil + } + + switch *txType { + case CreditTransactionTypeFunded: + return ledger.ListTransactionsCreditMovementPositive, false, nil + case CreditTransactionTypeConsumed: + return ledger.ListTransactionsCreditMovementNegative, false, nil + case CreditTransactionTypeAdjusted: + return ledger.ListTransactionsCreditMovementUnspecified, true, nil + default: + return ledger.ListTransactionsCreditMovementUnspecified, false, fmt.Errorf("unsupported credit transaction type: %s", *txType) + } +} + +func (s *Service) customerFBOAccountID(ctx context.Context, customerID customer.CustomerID) (string, error) { + accounts, err := s.AccountResolver.GetCustomerAccounts(ctx, customerID) + if err != nil { + return "", err + } + + return fboAccountIDFromCustomerAccounts(accounts), nil +} + +func fboAccountIDFromCustomerAccounts(accounts ledger.CustomerAccounts) string { + if fbo, ok := accounts.FBOAccount.(*ledgeraccount.CustomerFBOAccount); ok { + return fbo.ID().ID + } + + return "" +} + +func creditTransactionsFromLedgerTransactions(txs []ledger.Transaction) ([]CreditTransaction, error) { + items := make([]CreditTransaction, 0, len(txs)) + + for _, tx := range txs { + item, err := creditTransactionFromLedgerTransaction(tx) + if err != nil { + return nil, fmt.Errorf("convert ledger transaction %s: %w", tx.ID().ID, err) + } + + items = append(items, item) + } + + return items, nil +} + +func creditTransactionFromLedgerTransaction(tx ledger.Transaction) (CreditTransaction, error) { + entry, err := creditTransactionEntry(tx) + if err != nil { + return CreditTransaction{}, err + } + + amount := entry.Amount() + + return CreditTransaction{ + ID: tx.ID(), + CreatedAt: tx.Cursor().CreatedAt, + BookedAt: tx.BookedAt(), + Type: creditTransactionType(amount), + Currency: entry.PostingAddress().Route().Route().Currency, + Amount: amount, + Name: creditTransactionName(tx), + Annotations: tx.Annotations(), + }, nil +} + +func creditTransactionEntry(tx ledger.Transaction) (ledger.Entry, error) { + for _, entry := range tx.Entries() { + if entry.PostingAddress().AccountType() != ledger.AccountTypeCustomerFBO { + continue + } + + return entry, nil + } + + return nil, fmt.Errorf("no customer FBO entry found in transaction %s", tx.ID().ID) +} + +func applyCreditTransactionBalances(items []CreditTransaction, after alpacadecimal.Decimal) { + runningBalance := after + + for i := range items { + items[i].Balance.After = runningBalance + items[i].Balance.Before = runningBalance.Sub(items[i].Amount) + runningBalance = runningBalance.Sub(items[i].Amount) + } +} + +func creditTransactionType(fboImpact alpacadecimal.Decimal) CreditTransactionType { + if fboImpact.IsPositive() { + return CreditTransactionTypeFunded + } + + if fboImpact.IsNegative() { + return CreditTransactionTypeConsumed + } + + return CreditTransactionTypeAdjusted +} + +func creditTransactionName(tx ledger.Transaction) string { + templateName, _ := ledger.TransactionTemplateNameFromAnnotations(tx.Annotations()) + if templateName != "" { + return templateName + } + + return "credit_transaction" +} diff --git a/openmeter/ledger/customerbalance/transactions_test.go b/openmeter/ledger/customerbalance/transactions_test.go new file mode 100644 index 0000000000..bf2badbc90 --- /dev/null +++ b/openmeter/ledger/customerbalance/transactions_test.go @@ -0,0 +1,147 @@ +package customerbalance + +import ( + "testing" + "time" + + "github.com/alpacahq/alpacadecimal" + "github.com/stretchr/testify/require" + + "github.com/openmeterio/openmeter/openmeter/ledger" + ledgeraccount "github.com/openmeterio/openmeter/openmeter/ledger/account" + ledgerhistorical "github.com/openmeterio/openmeter/openmeter/ledger/historical" + "github.com/openmeterio/openmeter/pkg/currencyx" + "github.com/openmeterio/openmeter/pkg/models" +) + +func TestLedgerCreditMovement_AdjustedReturnsEmpty(t *testing.T) { + txType := CreditTransactionTypeAdjusted + + movement, empty, err := ledgerCreditMovement(&txType) + + require.NoError(t, err) + require.Equal(t, ledger.ListTransactionsCreditMovementUnspecified, movement) + require.True(t, empty) +} + +func TestFBOAccountIDFromCustomerAccounts_ReturnsOnlyFBO(t *testing.T) { + fbo := mustCustomerFBOAccount(t, "ns", "fbo-account") + receivable := mustCustomerReceivableAccount(t, "ns", "receivable-account") + accrued := mustCustomerAccruedAccount(t, "ns", "accrued-account") + + accountID := fboAccountIDFromCustomerAccounts(ledger.CustomerAccounts{ + FBOAccount: fbo, + ReceivableAccount: receivable, + AccruedAccount: accrued, + }) + + require.Equal(t, "fbo-account", accountID) +} + +func TestCreditTransactionFromLedgerTransaction_UsesFBOEntry(t *testing.T) { + usd := currencyx.Code("USD") + tx := mustHistoricalTransaction(t, []ledgerhistorical.EntryData{ + mustEntryData(t, "entry-usd", ledger.AccountTypeCustomerFBO, usd, alpacadecimal.NewFromInt(-10)), + mustEntryData(t, "entry-accrued", ledger.AccountTypeCustomerAccrued, usd, alpacadecimal.NewFromInt(10)), + }) + + item, err := creditTransactionFromLedgerTransaction(tx) + require.NoError(t, err) + require.Equal(t, CreditTransactionTypeConsumed, item.Type) + require.Equal(t, currencyx.Code("USD"), item.Currency) + require.True(t, item.Amount.Equal(alpacadecimal.NewFromInt(-10))) +} + +func TestApplyCreditTransactionBalances(t *testing.T) { + items := []CreditTransaction{ + { + Amount: alpacadecimal.NewFromInt(-10), + }, + } + + applyCreditTransactionBalances(items, alpacadecimal.NewFromInt(42)) + + require.True(t, items[0].Balance.After.Equal(alpacadecimal.NewFromInt(42))) + require.True(t, items[0].Balance.Before.Equal(alpacadecimal.NewFromInt(52))) +} + +func mustCustomerFBOAccount(t *testing.T, namespace, id string) *ledgeraccount.CustomerFBOAccount { + t.Helper() + + account := mustAccount(t, namespace, id, ledger.AccountTypeCustomerFBO) + fbo, err := account.AsCustomerFBOAccount() + require.NoError(t, err) + + return fbo +} + +func mustCustomerReceivableAccount(t *testing.T, namespace, id string) *ledgeraccount.CustomerReceivableAccount { + t.Helper() + + account := mustAccount(t, namespace, id, ledger.AccountTypeCustomerReceivable) + receivable, err := account.AsCustomerReceivableAccount() + require.NoError(t, err) + + return receivable +} + +func mustCustomerAccruedAccount(t *testing.T, namespace, id string) *ledgeraccount.CustomerAccruedAccount { + t.Helper() + + account := mustAccount(t, namespace, id, ledger.AccountTypeCustomerAccrued) + accrued, err := account.AsCustomerAccruedAccount() + require.NoError(t, err) + + return accrued +} + +func mustAccount(t *testing.T, namespace, id string, accountType ledger.AccountType) *ledgeraccount.Account { + t.Helper() + + account, err := ledgeraccount.NewAccountFromData(ledgeraccount.AccountData{ + ID: models.NamespacedID{ + Namespace: namespace, + ID: id, + }, + AccountType: accountType, + }, ledgeraccount.AccountLiveServices{}) + require.NoError(t, err) + + return account +} + +func mustHistoricalTransaction(t *testing.T, entries []ledgerhistorical.EntryData) ledger.Transaction { + t.Helper() + + tx, err := ledgerhistorical.NewTransactionFromData(ledgerhistorical.TransactionData{ + ID: "tx-1", + Namespace: "ns", + CreatedAt: time.Now().UTC(), + BookedAt: time.Now().UTC(), + }, entries) + require.NoError(t, err) + + return tx +} + +func mustEntryData(t *testing.T, id string, accountType ledger.AccountType, currency currencyx.Code, amount alpacadecimal.Decimal) ledgerhistorical.EntryData { + t.Helper() + + route := ledger.Route{Currency: currency} + key, err := ledger.BuildRoutingKey(ledger.RoutingKeyVersionV1, route) + require.NoError(t, err) + + return ledgerhistorical.EntryData{ + ID: id, + Namespace: "ns", + CreatedAt: time.Now().UTC(), + SubAccountID: id + "-subaccount", + AccountType: accountType, + Route: route, + RouteID: id + "-route", + RouteKey: key.Value(), + RouteKeyVer: key.Version(), + Amount: amount, + TransactionID: "tx-1", + } +} diff --git a/openmeter/ledger/historical/ledger.go b/openmeter/ledger/historical/ledger.go index 1615f8fb92..07724985ed 100644 --- a/openmeter/ledger/historical/ledger.go +++ b/openmeter/ledger/historical/ledger.go @@ -59,6 +59,10 @@ func (l *Ledger) ListTransactions(ctx context.Context, params ledger.ListTransac } func (l *Ledger) ListTransactionsByPage(ctx context.Context, params ledger.ListTransactionsByPageInput) (pagepagination.Result[ledger.Transaction], error) { + if err := params.Validate(); err != nil { + return pagepagination.Result[ledger.Transaction]{}, fmt.Errorf("failed to validate list transactions by page input: %w", err) + } + res, err := l.repo.ListTransactionsByPage(ctx, params) if err != nil { return pagepagination.Result[ledger.Transaction]{}, fmt.Errorf("list transactions by page: %w", err) diff --git a/openmeter/ledger/primitives.go b/openmeter/ledger/primitives.go index 98b81b2508..60e8fee0b7 100644 --- a/openmeter/ledger/primitives.go +++ b/openmeter/ledger/primitives.go @@ -211,3 +211,15 @@ type ListTransactionsByPageInput struct { CreditMovement ListTransactionsCreditMovement AnnotationFilters map[string]string } + +func (i ListTransactionsByPageInput) Validate() error { + if err := i.Page.Validate(); err != nil { + return ErrListTransactionsInputInvalid.WithAttrs(models.Attributes{ + "reason": "page_invalid", + "page": i.Page, + "error": err, + }) + } + + return nil +} From a9724051cf87d6fec1491c852cf0f907ca7ff2e7 Mon Sep 17 00:00:00 2001 From: Alex Goth <64845621+GAlexIHU@users.noreply.github.com> Date: Thu, 16 Apr 2026 11:34:44 +0200 Subject: [PATCH 3/3] fix: lint --- api/v3/handlers/customers/credits/convert.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/v3/handlers/customers/credits/convert.go b/api/v3/handlers/customers/credits/convert.go index e987e04f17..65ee2f82b8 100644 --- a/api/v3/handlers/customers/credits/convert.go +++ b/api/v3/handlers/customers/credits/convert.go @@ -352,8 +352,8 @@ func toAPIBillingCreditTransaction(tx customerbalance.CreditTransaction) api.Bil After api.Numeric `json:"after"` Before api.Numeric `json:"before"` }{ - Before: api.Numeric(tx.Balance.Before.String()), - After: api.Numeric(tx.Balance.After.String()), + Before: tx.Balance.Before.String(), + After: tx.Balance.After.String(), }, }