diff --git a/backend/internal/service/payment_fulfillment.go b/backend/internal/service/payment_fulfillment.go index b6b19ca0f36..51ecefe250f 100644 --- a/backend/internal/service/payment_fulfillment.go +++ b/backend/internal/service/payment_fulfillment.go @@ -11,6 +11,8 @@ import ( "strings" "time" + "entgo.io/ent/dialect" + dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/paymentauditlog" "github.com/Wei-Shaw/sub2api/ent/paymentorder" @@ -429,16 +431,22 @@ func (s *PaymentService) doSub(ctx context.Context, o *dbent.PaymentOrder) error if err != nil || g.Status != payment.EntityStatusActive { return fmt.Errorf("group %d no longer exists or inactive", gid) } - // Idempotency: check audit log to see if subscription was already assigned. - // Prevents double-extension on retry after markCompleted fails. - if s.hasAuditLog(ctx, o.ID, "SUBSCRIPTION_SUCCESS") { + assigned := s.hasAuditLog(ctx, o.ID, "SUBSCRIPTION_ASSIGNED") || s.hasAuditLog(ctx, o.ID, "SUBSCRIPTION_SUCCESS") + if !assigned { + orderNote := fmt.Sprintf("payment order %d", o.ID) + _, _, err = s.subscriptionSvc.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{UserID: o.UserID, GroupID: gid, ValidityDays: days, AssignedBy: 0, Notes: orderNote}) + if err != nil { + return fmt.Errorf("assign subscription: %w", err) + } + s.writeAuditLog(ctx, o.ID, "SUBSCRIPTION_ASSIGNED", "system", map[string]any{ + "groupID": gid, + "validityDays": days, + }) + } else { slog.Info("subscription already assigned for order, skipping", "orderID", o.ID, "groupID", gid) - return s.markCompleted(ctx, o, "SUBSCRIPTION_SUCCESS") } - orderNote := fmt.Sprintf("payment order %d", o.ID) - _, _, err = s.subscriptionSvc.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{UserID: o.UserID, GroupID: gid, ValidityDays: days, AssignedBy: 0, Notes: orderNote}) - if err != nil { - return fmt.Errorf("assign subscription: %w", err) + if err := s.applyAffiliateRebateForOrder(ctx, o); err != nil { + return err } return s.markCompleted(ctx, o, "SUBSCRIPTION_SUCCESS") } @@ -452,7 +460,8 @@ func (s *PaymentService) hasAuditLog(ctx context.Context, orderID int64, action } func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *dbent.PaymentOrder) error { - if o == nil || o.OrderType != payment.OrderTypeBalance || o.Amount <= 0 { + baseAmount := affiliateRebateBaseAmount(o) + if o == nil || baseAmount <= 0 { return nil } if s.affiliateService == nil { @@ -469,7 +478,7 @@ func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *db defer func() { _ = tx.Rollback() }() txCtx := dbent.NewTxContext(ctx, tx) - claimed, err := s.tryClaimAffiliateRebateAudit(txCtx, tx.Client(), o.ID, o.Amount) + claimed, err := s.tryClaimAffiliateRebateAudit(txCtx, tx.Client(), o.ID, baseAmount) if err != nil { s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{ "error": err.Error(), @@ -481,7 +490,7 @@ func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *db } sourceOrderID := o.ID - rebateAmount, err := s.affiliateService.AccrueInviteRebateForOrder(txCtx, o.UserID, o.Amount, &sourceOrderID) + rebateAmount, err := s.affiliateService.AccrueInviteRebateForOrder(txCtx, o.UserID, baseAmount, &sourceOrderID) if err != nil { s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{ "error": err.Error(), @@ -491,7 +500,7 @@ func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *db if rebateAmount <= 0 { if err := s.updateClaimedAffiliateRebateAudit(txCtx, tx.Client(), o.ID, "AFFILIATE_REBATE_SKIPPED", map[string]any{ - "baseAmount": o.Amount, + "baseAmount": baseAmount, "reason": "no inviter bound or rebate amount <= 0", }); err != nil { s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{ @@ -509,7 +518,7 @@ func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *db } if err := s.updateClaimedAffiliateRebateAudit(txCtx, tx.Client(), o.ID, "AFFILIATE_REBATE_APPLIED", map[string]any{ - "baseAmount": o.Amount, + "baseAmount": baseAmount, "rebateAmount": rebateAmount, }); err != nil { s.writeAuditLog(ctx, o.ID, "AFFILIATE_REBATE_FAILED", "system", map[string]any{ @@ -527,6 +536,18 @@ func (s *PaymentService) applyAffiliateRebateForOrder(ctx context.Context, o *db return nil } +func affiliateRebateBaseAmount(o *dbent.PaymentOrder) float64 { + if o == nil { + return 0 + } + switch o.OrderType { + case payment.OrderTypeBalance, payment.OrderTypeSubscription: + return o.Amount + default: + return 0 + } +} + func (s *PaymentService) tryClaimAffiliateRebateAudit(ctx context.Context, client *dbent.Client, orderID int64, baseAmount float64) (bool, error) { if client == nil { return false, errors.New("nil payment client") @@ -536,17 +557,8 @@ func (s *PaymentService) tryClaimAffiliateRebateAudit(ctx context.Context, clien "baseAmount": baseAmount, "status": "reserved", }) - rows, err := client.QueryContext(ctx, ` -INSERT INTO payment_audit_logs (order_id, action, detail, operator, created_at) -SELECT $1::text, 'AFFILIATE_REBATE_APPLIED', $2::text, 'system', NOW() -WHERE NOT EXISTS ( - SELECT 1 - FROM payment_audit_logs - WHERE order_id = $1::text - AND action IN ('AFFILIATE_REBATE_APPLIED', 'AFFILIATE_REBATE_SKIPPED') -) -ON CONFLICT (order_id, action) DO NOTHING -RETURNING id`, oid, string(detail)) + query, args := buildAffiliateRebateAuditClaimQuery(client, oid, string(detail)) + rows, err := client.QueryContext(ctx, query, args...) if err != nil { return false, err } @@ -564,6 +576,48 @@ RETURNING id`, oid, string(detail)) return true, nil } +func buildAffiliateRebateAuditClaimQuery(client *dbent.Client, orderID, detail string) (string, []any) { + nowExpr := paymentAuditCurrentTimestampExpr(client) + if paymentAuditDialect(client) == dialect.Postgres { + return fmt.Sprintf(` +INSERT INTO payment_audit_logs (order_id, action, detail, operator, created_at) +SELECT $1::text, 'AFFILIATE_REBATE_APPLIED', $2::text, 'system', %s +WHERE NOT EXISTS ( + SELECT 1 + FROM payment_audit_logs + WHERE order_id = $1::text + AND action IN ('AFFILIATE_REBATE_APPLIED', 'AFFILIATE_REBATE_SKIPPED') +) +ON CONFLICT (order_id, action) DO NOTHING +RETURNING id`, nowExpr), []any{orderID, detail} + } + return fmt.Sprintf(` +INSERT INTO payment_audit_logs (order_id, action, detail, operator, created_at) +SELECT ?, 'AFFILIATE_REBATE_APPLIED', ?, 'system', %s +WHERE NOT EXISTS ( + SELECT 1 + FROM payment_audit_logs + WHERE order_id = ? + AND action IN ('AFFILIATE_REBATE_APPLIED', 'AFFILIATE_REBATE_SKIPPED') +) +ON CONFLICT (order_id, action) DO NOTHING +RETURNING id`, nowExpr), []any{orderID, detail, orderID} +} + +func paymentAuditCurrentTimestampExpr(client *dbent.Client) string { + if paymentAuditDialect(client) == dialect.Postgres { + return "NOW()" + } + return "CURRENT_TIMESTAMP" +} + +func paymentAuditDialect(client *dbent.Client) string { + if client == nil || client.Driver() == nil { + return "" + } + return client.Driver().Dialect() +} + func (s *PaymentService) updateClaimedAffiliateRebateAudit(ctx context.Context, client *dbent.Client, orderID int64, action string, detail map[string]any) error { if client == nil { return errors.New("nil payment client") diff --git a/backend/internal/service/payment_fulfillment_test.go b/backend/internal/service/payment_fulfillment_test.go index f46cf037a1c..b46d6a1fc81 100644 --- a/backend/internal/service/payment_fulfillment_test.go +++ b/backend/internal/service/payment_fulfillment_test.go @@ -6,11 +6,15 @@ import ( "context" "errors" "math" + "strconv" "testing" + "time" dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/paymentauditlog" "github.com/Wei-Shaw/sub2api/internal/payment" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) type paymentFulfillmentTestProvider struct { @@ -36,6 +40,169 @@ func (p paymentFulfillmentTestProvider) Refund(ctx context.Context, req payment. panic("unexpected call") } +type paymentFulfillmentAffiliateAccrueCall struct { + inviterID int64 + inviteeUserID int64 + amount float64 + freezeHours int + sourceOrderID *int64 +} + +type paymentFulfillmentAffiliateRepoStub struct { + inviteeSummary *AffiliateSummary + inviterSummary *AffiliateSummary + accrueCalls []paymentFulfillmentAffiliateAccrueCall +} + +func (r *paymentFulfillmentAffiliateRepoStub) EnsureUserAffiliate(_ context.Context, userID int64) (*AffiliateSummary, error) { + switch { + case r.inviteeSummary != nil && r.inviteeSummary.UserID == userID: + cp := *r.inviteeSummary + return &cp, nil + case r.inviterSummary != nil && r.inviterSummary.UserID == userID: + cp := *r.inviterSummary + return &cp, nil + default: + return &AffiliateSummary{UserID: userID, AffCode: "AFFTEST", CreatedAt: time.Now().Add(-time.Hour)}, nil + } +} + +func (r *paymentFulfillmentAffiliateRepoStub) GetAffiliateByCode(context.Context, string) (*AffiliateSummary, error) { + panic("unexpected GetAffiliateByCode call") +} + +func (r *paymentFulfillmentAffiliateRepoStub) BindInviter(context.Context, int64, int64) (bool, error) { + panic("unexpected BindInviter call") +} + +func (r *paymentFulfillmentAffiliateRepoStub) AccrueQuota(_ context.Context, inviterID, inviteeUserID int64, amount float64, freezeHours int, sourceOrderID *int64) (bool, error) { + var sourceCopy *int64 + if sourceOrderID != nil { + v := *sourceOrderID + sourceCopy = &v + } + r.accrueCalls = append(r.accrueCalls, paymentFulfillmentAffiliateAccrueCall{ + inviterID: inviterID, + inviteeUserID: inviteeUserID, + amount: amount, + freezeHours: freezeHours, + sourceOrderID: sourceCopy, + }) + return true, nil +} + +func (r *paymentFulfillmentAffiliateRepoStub) GetAccruedRebateFromInvitee(context.Context, int64, int64) (float64, error) { + return 0, nil +} + +func (r *paymentFulfillmentAffiliateRepoStub) ThawFrozenQuota(context.Context, int64) (float64, error) { + panic("unexpected ThawFrozenQuota call") +} + +func (r *paymentFulfillmentAffiliateRepoStub) TransferQuotaToBalance(context.Context, int64) (float64, float64, error) { + panic("unexpected TransferQuotaToBalance call") +} + +func (r *paymentFulfillmentAffiliateRepoStub) ListInvitees(context.Context, int64, int) ([]AffiliateInvitee, error) { + panic("unexpected ListInvitees call") +} + +func (r *paymentFulfillmentAffiliateRepoStub) UpdateUserAffCode(context.Context, int64, string) error { + panic("unexpected UpdateUserAffCode call") +} + +func (r *paymentFulfillmentAffiliateRepoStub) ResetUserAffCode(context.Context, int64) (string, error) { + panic("unexpected ResetUserAffCode call") +} + +func (r *paymentFulfillmentAffiliateRepoStub) SetUserRebateRate(context.Context, int64, *float64) error { + panic("unexpected SetUserRebateRate call") +} + +func (r *paymentFulfillmentAffiliateRepoStub) BatchSetUserRebateRate(context.Context, []int64, *float64) error { + panic("unexpected BatchSetUserRebateRate call") +} + +func (r *paymentFulfillmentAffiliateRepoStub) ListUsersWithCustomSettings(context.Context, AffiliateAdminFilter) ([]AffiliateAdminEntry, int64, error) { + panic("unexpected ListUsersWithCustomSettings call") +} + +func (r *paymentFulfillmentAffiliateRepoStub) ListAffiliateInviteRecords(context.Context, AffiliateRecordFilter) ([]AffiliateInviteRecord, int64, error) { + panic("unexpected ListAffiliateInviteRecords call") +} + +func (r *paymentFulfillmentAffiliateRepoStub) ListAffiliateRebateRecords(context.Context, AffiliateRecordFilter) ([]AffiliateRebateRecord, int64, error) { + panic("unexpected ListAffiliateRebateRecords call") +} + +func (r *paymentFulfillmentAffiliateRepoStub) ListAffiliateTransferRecords(context.Context, AffiliateRecordFilter) ([]AffiliateTransferRecord, int64, error) { + panic("unexpected ListAffiliateTransferRecords call") +} + +func (r *paymentFulfillmentAffiliateRepoStub) GetAffiliateUserOverview(context.Context, int64) (*AffiliateUserOverview, error) { + panic("unexpected GetAffiliateUserOverview call") +} + +type paymentFulfillmentSettingRepoStub struct { + values map[string]string +} + +func (s *paymentFulfillmentSettingRepoStub) Get(context.Context, string) (*Setting, error) { + return nil, ErrSettingNotFound +} + +func (s *paymentFulfillmentSettingRepoStub) GetValue(_ context.Context, key string) (string, error) { + if s.values == nil { + return "", ErrSettingNotFound + } + value, ok := s.values[key] + if !ok { + return "", ErrSettingNotFound + } + return value, nil +} + +func (s *paymentFulfillmentSettingRepoStub) Set(_ context.Context, key, value string) error { + if s.values == nil { + s.values = map[string]string{} + } + s.values[key] = value + return nil +} + +func (s *paymentFulfillmentSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + out[key] = s.values[key] + } + return out, nil +} + +func (s *paymentFulfillmentSettingRepoStub) SetMultiple(_ context.Context, values map[string]string) error { + if s.values == nil { + s.values = map[string]string{} + } + for key, value := range values { + s.values[key] = value + } + return nil +} + +func (s *paymentFulfillmentSettingRepoStub) GetAll(context.Context) (map[string]string, error) { + return s.values, nil +} + +func (s *paymentFulfillmentSettingRepoStub) Delete(_ context.Context, key string) error { + delete(s.values, key) + return nil +} + +func ensurePaymentAuditOrderActionUniqueIndex(t *testing.T, ctx context.Context, client *dbent.Client) { + t.Helper() + _, err := client.ExecContext(ctx, "CREATE UNIQUE INDEX IF NOT EXISTS idx_payment_audit_logs_order_action_uniq ON payment_audit_logs(order_id, action)") + require.NoError(t, err) +} + // --------------------------------------------------------------------------- // resolveRedeemAction — pure idempotency decision logic // --------------------------------------------------------------------------- @@ -418,3 +585,179 @@ func TestPaymentAmountToleranceForThreeDecimalCurrency(t *testing.T) { assert.Equal(t, amountToleranceCNY, paymentAmountToleranceForCurrency("JPY")) assert.InDelta(t, 0.0005, paymentAmountToleranceForCurrency("KWD"), 1e-12) } + +func TestExecuteSubscriptionFulfillmentAppliesAffiliateRebate(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + ensurePaymentAuditOrderActionUniqueIndex(t, ctx, client) + + user, err := client.User.Create(). + SetEmail("subscription-affiliate@example.com"). + SetPasswordHash("hash"). + SetUsername("subscription-affiliate-user"). + Save(ctx) + require.NoError(t, err) + + order, err := client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(120). + SetPayAmount(120). + SetFeeRate(0). + SetRechargeCode("PAY-SUB-AFFILIATE"). + SetOutTradeNo("sub2_subscription_affiliate"). + SetPaymentType(payment.TypeAlipay). + SetPaymentTradeNo("trade-sub-affiliate"). + SetOrderType(payment.OrderTypeSubscription). + SetPlanID(99). + SetSubscriptionGroupID(7). + SetSubscriptionDays(30). + SetStatus(OrderStatusPaid). + SetExpiresAt(time.Now().Add(time.Hour)). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + Save(ctx) + require.NoError(t, err) + + inviterID := int64(9001) + affiliateRepo := &paymentFulfillmentAffiliateRepoStub{ + inviteeSummary: &AffiliateSummary{ + UserID: user.ID, + AffCode: "INVITEE", + InviterID: &inviterID, + CreatedAt: time.Now().Add(-24 * time.Hour), + }, + inviterSummary: &AffiliateSummary{ + UserID: inviterID, + AffCode: "INVITER", + CreatedAt: time.Now().Add(-48 * time.Hour), + }, + } + settingSvc := NewSettingService(&paymentFulfillmentSettingRepoStub{values: map[string]string{ + SettingKeyAffiliateEnabled: "true", + SettingKeyAffiliateRebateRate: "20", + SettingKeyAffiliateRebateFreezeHours: "0", + }}, nil) + subRepo := newSubscriptionUserSubRepoStub() + subscriptionSvc := NewSubscriptionService(&subscriptionGroupRepoStub{ + group: &Group{ID: 7, Status: payment.EntityStatusActive, SubscriptionType: SubscriptionTypeSubscription}, + }, subRepo, nil, nil, nil) + svc := &PaymentService{ + entClient: client, + groupRepo: &subscriptionGroupRepoStub{group: &Group{ID: 7, Status: payment.EntityStatusActive, SubscriptionType: SubscriptionTypeSubscription}}, + subscriptionSvc: subscriptionSvc, + affiliateService: NewAffiliateService(affiliateRepo, settingSvc, nil, nil), + } + + err = svc.ExecuteSubscriptionFulfillment(ctx, order.ID) + require.NoError(t, err) + + reloaded, err := client.PaymentOrder.Get(ctx, order.ID) + require.NoError(t, err) + require.Equal(t, OrderStatusCompleted, reloaded.Status) + require.Len(t, affiliateRepo.accrueCalls, 1) + require.Equal(t, inviterID, affiliateRepo.accrueCalls[0].inviterID) + require.Equal(t, user.ID, affiliateRepo.accrueCalls[0].inviteeUserID) + require.Equal(t, 24.0, affiliateRepo.accrueCalls[0].amount) + require.NotNil(t, affiliateRepo.accrueCalls[0].sourceOrderID) + require.Equal(t, order.ID, *affiliateRepo.accrueCalls[0].sourceOrderID) + require.Equal(t, 1, subRepo.createCalls) + + applied, err := client.PaymentAuditLog.Query(). + Where(paymentauditlog.OrderIDEQ(strconv.FormatInt(order.ID, 10)), paymentauditlog.ActionEQ("AFFILIATE_REBATE_APPLIED")). + Only(ctx) + require.NoError(t, err) + require.Contains(t, applied.Detail, `"baseAmount":120`) + require.Contains(t, applied.Detail, `"rebateAmount":24`) +} + +func TestExecuteSubscriptionFulfillmentDoesNotDuplicateWorkAfterLegacySuccessAudit(t *testing.T) { + ctx := context.Background() + client := newPaymentConfigServiceTestClient(t) + ensurePaymentAuditOrderActionUniqueIndex(t, ctx, client) + + user, err := client.User.Create(). + SetEmail("subscription-affiliate-idempotent@example.com"). + SetPasswordHash("hash"). + SetUsername("subscription-affiliate-idempotent-user"). + Save(ctx) + require.NoError(t, err) + + order, err := client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(80). + SetPayAmount(80). + SetFeeRate(0). + SetRechargeCode("PAY-SUB-AFFILIATE-IDEMPOTENT"). + SetOutTradeNo("sub2_subscription_affiliate_idempotent"). + SetPaymentType(payment.TypeAlipay). + SetPaymentTradeNo("trade-sub-affiliate-idempotent"). + SetOrderType(payment.OrderTypeSubscription). + SetPlanID(100). + SetSubscriptionGroupID(7). + SetSubscriptionDays(30). + SetStatus(OrderStatusPaid). + SetExpiresAt(time.Now().Add(time.Hour)). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + Save(ctx) + require.NoError(t, err) + _, err = client.PaymentAuditLog.Create(). + SetOrderID(strconv.FormatInt(order.ID, 10)). + SetAction("SUBSCRIPTION_SUCCESS"). + SetDetail(`{"groupID":7,"validityDays":30}`). + SetOperator("system"). + Save(ctx) + require.NoError(t, err) + _, err = client.PaymentAuditLog.Create(). + SetOrderID(strconv.FormatInt(order.ID, 10)). + SetAction("AFFILIATE_REBATE_APPLIED"). + SetDetail(`{"baseAmount":80,"rebateAmount":16}`). + SetOperator("system"). + Save(ctx) + require.NoError(t, err) + + inviterID := int64(9001) + affiliateRepo := &paymentFulfillmentAffiliateRepoStub{ + inviteeSummary: &AffiliateSummary{ + UserID: user.ID, + AffCode: "INVITEE", + InviterID: &inviterID, + CreatedAt: time.Now().Add(-24 * time.Hour), + }, + inviterSummary: &AffiliateSummary{ + UserID: inviterID, + AffCode: "INVITER", + CreatedAt: time.Now().Add(-48 * time.Hour), + }, + } + settingSvc := NewSettingService(&paymentFulfillmentSettingRepoStub{values: map[string]string{ + SettingKeyAffiliateEnabled: "true", + SettingKeyAffiliateRebateRate: "20", + }}, nil) + subRepo := newSubscriptionUserSubRepoStub() + subscriptionSvc := NewSubscriptionService(&subscriptionGroupRepoStub{ + group: &Group{ID: 7, Status: payment.EntityStatusActive, SubscriptionType: SubscriptionTypeSubscription}, + }, subRepo, nil, nil, nil) + svc := &PaymentService{ + entClient: client, + groupRepo: &subscriptionGroupRepoStub{group: &Group{ID: 7, Status: payment.EntityStatusActive, SubscriptionType: SubscriptionTypeSubscription}}, + subscriptionSvc: subscriptionSvc, + affiliateService: NewAffiliateService(affiliateRepo, settingSvc, nil, nil), + } + + err = svc.ExecuteSubscriptionFulfillment(ctx, order.ID) + require.NoError(t, err) + + reloaded, err := client.PaymentOrder.Get(ctx, order.ID) + require.NoError(t, err) + require.Equal(t, OrderStatusCompleted, reloaded.Status) + require.Empty(t, affiliateRepo.accrueCalls) + require.Zero(t, subRepo.createCalls) +} + +var _ AffiliateRepository = (*paymentFulfillmentAffiliateRepoStub)(nil) +var _ SettingRepository = (*paymentFulfillmentSettingRepoStub)(nil)