Skip to content

Commit 3c2bc06

Browse files
committed
- stripe_session_id -> stripe_invoice_id
- Add priceToPlan - Add /api/subscription - Add HandleInvoicePaid
1 parent 825c66f commit 3c2bc06

File tree

6 files changed

+175
-29
lines changed

6 files changed

+175
-29
lines changed

cmd/web/main.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,14 @@ import (
1515
"github.com/stripe/stripe-go/v84"
1616
)
1717

18-
// TODO: 提供升级订阅的选择
18+
// TODO: 禁止重复订阅
19+
// TODO: 防止 /checkout/success 被滥用
20+
// TODO: 日志
21+
// TODO: 创建订阅集合迁移文件
22+
// TODO: 创建用户集合迁移文件
23+
// TODO: 用.env初始化SMTP和设置
24+
// TODO: 添加限速
25+
// TODO: 发送各种邮件
1926

2027
const version string = "v1.0.0-alpha"
2128

config/config.go

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,31 @@ type Config struct {
88
StripeKey string
99
StripeSignKey string
1010
PlanToPrice map[string]string
11+
12+
PriceToPlan map[string]string
1113
}
1214

1315
func New() *Config {
16+
// 原始映射
17+
planToPrice := map[string]string{
18+
"starter": getEnv("STRIPE_PLAN_STARTER", ""),
19+
"pro": getEnv("STRIPE_PLAN_PRO", ""),
20+
"plus": getEnv("STRIPE_PLAN_PLUS", ""),
21+
}
22+
23+
// 自动生成反向映射
24+
priceToPlan := make(map[string]string)
25+
for plan, price := range planToPrice {
26+
if price != "" {
27+
priceToPlan[price] = plan
28+
}
29+
}
30+
1431
return &Config{
1532
StripeKey: getEnv("STRIPE_KEY", ""),
1633
StripeSignKey: getEnv("STRIPE_SIGN_KEY", ""),
17-
// 动态从环境变量读取价格 ID
18-
PlanToPrice: map[string]string{
19-
"starter": os.Getenv("STRIPE_PLAN_STARTER"),
20-
"pro": os.Getenv("STRIPE_PLAN_PRO"),
21-
"plus": os.Getenv("STRIPE_PLAN_PLUS"),
22-
},
34+
PlanToPrice: planToPrice,
35+
PriceToPlan: priceToPlan,
2336
}
2437
}
2538

@@ -28,5 +41,9 @@ func getEnv(key, defaultValue string) string {
2841
return value
2942
}
3043

44+
if defaultValue == "" {
45+
panic("environment variable '" + key + "' has not been set, and is required")
46+
}
47+
3148
return defaultValue
3249
}

internal/subscriptions/handler.go

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
package subscriptions
22

33
import (
4+
"database/sql"
5+
"encoding/json"
6+
"errors"
47
"fmt"
58
"io"
69
"net/http"
710

811
"github.com/pocketbase/pocketbase/apis"
912
"github.com/pocketbase/pocketbase/core"
13+
"github.com/stripe/stripe-go/v84"
1014
"github.com/stripe/stripe-go/v84/webhook"
1115
)
1216

@@ -50,14 +54,67 @@ func (h *SubscriptionHandler) StripeWebhook(e *core.RequestEvent) error {
5054
return e.BadRequestError("Read body failed", nil)
5155
}
5256

53-
event, err := webhook.ConstructEvent(payload, e.Request.Header.Get("Stripe-Signature"), h.service.cfg.StripeSignKey)
57+
event, err := webhook.ConstructEventWithOptions(
58+
payload,
59+
e.Request.Header.Get("Stripe-Signature"),
60+
h.service.cfg.StripeSignKey,
61+
webhook.ConstructEventOptions{
62+
IgnoreAPIVersionMismatch: true, // 忽略版本不一致报错
63+
},
64+
)
5465
if err != nil {
66+
fmt.Println(err)
5567
return e.BadRequestError("Invalid signature", nil)
5668
}
5769

58-
if event.Type == "checkout.session.completed" {
59-
// ... 反序列化并调用 h.service.HandleCheckoutCompleted(sess)
70+
switch event.Type {
71+
case "checkout.session.completed":
72+
var session stripe.CheckoutSession
73+
err := json.Unmarshal(event.Data.Raw, &session)
74+
if err != nil {
75+
fmt.Println(err)
76+
return e.BadRequestError("JSON unmarshal failed", nil)
77+
}
78+
79+
// 🌟 调用 Service 层处理业务(如更新用户订阅状态、发货等)
80+
// 传入 e.App (PocketBase 实例) 以便在 Service 里操作数据库
81+
if err := h.service.HandleCheckoutCompleted(session); err != nil {
82+
fmt.Println(err)
83+
84+
return e.InternalServerError("Handle checkout failed", err)
85+
}
86+
case "invoice.paid":
87+
88+
var inv stripe.Invoice
89+
err := json.Unmarshal(event.Data.Raw, &inv)
90+
if err != nil {
91+
fmt.Println(err)
92+
return e.BadRequestError("Parsing invoice failed", err)
93+
}
94+
95+
err = h.service.HandleInvoicePaid(inv)
96+
if err != nil {
97+
fmt.Println(err)
98+
return e.InternalServerError("Handle checkout failed", err)
99+
}
100+
60101
}
61102

62103
return e.NoContent(http.StatusOK)
63104
}
105+
106+
func (h *SubscriptionHandler) CheckSubscription(e *core.RequestEvent) error {
107+
subscription, err := h.service.CheckValidSubscription(e.Auth.Original())
108+
109+
if errors.Is(err, sql.ErrNoRows) {
110+
111+
return e.BadRequestError("No subscription", nil)
112+
}
113+
114+
if err != nil {
115+
116+
return e.InternalServerError("Check subscription failed", err)
117+
}
118+
119+
return e.JSON(http.StatusOK, subscription)
120+
}

internal/subscriptions/router.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,5 @@ func RegisterRoutes(app *pocketbase.PocketBase, se *core.ServeEvent, cfg *config
1515
// 路由注册
1616
se.Router.POST("/api/webhook/stripe", handler.StripeWebhook)
1717
se.Router.POST("/api/checkout/subscription", handler.Checkout).Bind(apis.RequireAuth())
18+
se.Router.GET("/api/subscription", handler.CheckSubscription).Bind(apis.RequireAuth())
1819
}

internal/subscriptions/service.go

Lines changed: 81 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
package subscriptions
22

33
import (
4+
"errors"
45
"fmt"
56
"time"
67
"website-pb/config"
78

9+
"github.com/pocketbase/dbx"
810
"github.com/pocketbase/pocketbase"
911
"github.com/pocketbase/pocketbase/core"
1012
"github.com/stripe/stripe-go/v84"
1113
"github.com/stripe/stripe-go/v84/checkout/session"
12-
"github.com/stripe/stripe-go/v84/subscription"
1314
)
1415

1516
type SubscriptionService struct {
@@ -21,6 +22,31 @@ func NewService(app *pocketbase.PocketBase, cfg *config.Config) *SubscriptionSer
2122
return &SubscriptionService{app: app, cfg: cfg}
2223
}
2324

25+
func (s *SubscriptionService) CheckValidSubscription(user *core.Record) (*core.Record, error) {
26+
27+
now := time.Now().UTC().Format("2006-01-02 15:04:05.000Z")
28+
29+
record := &core.Record{}
30+
31+
err := s.app.RecordQuery("subscriptions").
32+
// 1. 基础过滤:用户 ID
33+
AndWhere(dbx.HashExp{"user_id": user.Id}).
34+
// 2. 时间过滤:未过期
35+
AndWhere(dbx.NewExp("expires_at > {:now}", dbx.Params{"now": now})).
36+
// 3. 排序:将到期时间最远的排在最前面
37+
OrderBy("expires_at DESC").
38+
// 4. 只取一条
39+
Limit(1).
40+
// 5. 将结果映射到 record 对象
41+
One(record)
42+
43+
if err != nil {
44+
return nil, err // 没找到会返回 sql.ErrNoRows
45+
}
46+
47+
return record, nil
48+
}
49+
2450
// CreateCheckoutSession 处理 Stripe 会话创建
2551
func (s *SubscriptionService) CreateCheckoutSession(user *core.Record, plan string, frontendURL string) (string, error) {
2652
priceID, exists := s.cfg.PlanToPrice[plan]
@@ -33,8 +59,11 @@ func (s *SubscriptionService) CreateCheckoutSession(user *core.Record, plan stri
3359
CancelURL: stripe.String(frontendURL),
3460
Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)),
3561
ClientReferenceID: stripe.String(user.Id),
36-
Metadata: map[string]string{"plan": plan},
37-
LineItems: []*stripe.CheckoutSessionLineItemParams{{Price: stripe.String(priceID), Quantity: stripe.Int64(1)}},
62+
SubscriptionData: &stripe.CheckoutSessionSubscriptionDataParams{
63+
Metadata: map[string]string{"plan": "pro"}, // 存到订阅对象里
64+
},
65+
Metadata: map[string]string{"plan": plan},
66+
LineItems: []*stripe.CheckoutSessionLineItemParams{{Price: stripe.String(priceID), Quantity: stripe.Int64(1)}},
3867
}
3968

4069
// 关联已有 Stripe Customer ID
@@ -48,28 +77,63 @@ func (s *SubscriptionService) CreateCheckoutSession(user *core.Record, plan stri
4877
return sess.URL, err
4978
}
5079

51-
// HandleCheckoutCompleted 处理支付成功后的数据更新
52-
func (s *SubscriptionService) HandleCheckoutCompleted(sess stripe.CheckoutSession) error {
53-
// 1. 获取用户
54-
user, err := s.app.FindRecordById("users", sess.ClientReferenceID)
80+
func (s *SubscriptionService) HandleInvoicePaid(inv stripe.Invoice) error {
81+
82+
if inv.Customer == nil {
83+
fmt.Println("invoice customer is nil")
84+
return errors.New("invoice customer is nil")
85+
}
86+
87+
if len(inv.Lines.Data) == 0 {
88+
fmt.Println("invoice has no lines")
89+
90+
return errors.New("invoice has no lines")
91+
}
92+
93+
stripeCustomerID := inv.Customer.ID
94+
user, err := s.app.FindFirstRecordByFilter("users", "stripe_customer_id = {:id}", map[string]any{"id": stripeCustomerID})
5595
if err != nil {
5696
return err
5797
}
5898

59-
// 2. 更新用户信息
60-
user.Set("stripe_customer_id", sess.Customer.ID)
61-
user.Set("plan", sess.Metadata["plan"])
62-
if err := s.app.Save(user); err != nil {
63-
return err
99+
collection, err := s.app.FindCollectionByNameOrId("subscriptions")
100+
if err != nil {
101+
return errors.New("subscriptions collection not found")
64102
}
65103

66-
// 3. 创建订阅记录
67-
sub, _ := subscription.Get(sess.Subscription.ID, nil)
68-
collection, _ := s.app.FindCollectionByNameOrId("subscriptions")
69104
record := core.NewRecord(collection)
105+
106+
fmt.Println(inv.Lines.Data[0].Pricing.PriceDetails.Price)
107+
108+
priceID := inv.Lines.Data[0].Pricing.PriceDetails.Price
109+
110+
priceIDMap := s.cfg.PriceToPlan[priceID]
111+
fmt.Println(priceID)
112+
113+
if priceIDMap == "" {
114+
return errors.New("invalid price")
115+
}
116+
117+
expiresAt := time.Unix(inv.Lines.Data[0].Period.End, 0).UTC()
118+
fmt.Println("Expires at:", expiresAt)
119+
fmt.Println(inv.Lines.Data[0].Period.End)
120+
70121
record.Set("user_id", user.Id)
71-
record.Set("plan", sess.Metadata["plan"])
72-
record.Set("expires_at", time.Unix(sub.Items.Data[0].CurrentPeriodEnd, 0).UTC())
122+
record.Set("plan", priceIDMap)
123+
record.Set("stripe_invoice_id", inv.ID)
124+
record.Set("expires_at", expiresAt)
73125

74126
return s.app.Save(record)
75127
}
128+
129+
func (s *SubscriptionService) HandleCheckoutCompleted(sess stripe.CheckoutSession) error {
130+
user, err := s.app.FindRecordById("users", sess.ClientReferenceID)
131+
if err != nil {
132+
return err
133+
}
134+
135+
// 2. 更新用户信息
136+
user.Set("stripe_customer_id", sess.Customer.ID)
137+
138+
return s.app.Save(user)
139+
}

migrations/1769317614_collections_snapshot.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -848,7 +848,7 @@ func init() {
848848
"id": "text439437911",
849849
"max": 0,
850850
"min": 0,
851-
"name": "stripe_session_id",
851+
"name": "stripe_invoice_id",
852852
"pattern": "",
853853
"presentable": false,
854854
"primaryKey": false,
@@ -890,7 +890,7 @@ func init() {
890890
],
891891
"id": "pbc_4099809493",
892892
"indexes": [
893-
"CREATE UNIQUE INDEX ` + "`" + `idx_Ey9oa0Co7S` + "`" + ` ON ` + "`" + `subscriptions` + "`" + ` (` + "`" + `stripe_session_id` + "`" + `)"
893+
"CREATE UNIQUE INDEX ` + "`" + `idx_Ey9oa0Co7S` + "`" + ` ON ` + "`" + `subscriptions` + "`" + ` (` + "`" + `stripe_invoice_id` + "`" + `)"
894894
],
895895
"listRule": null,
896896
"name": "subscriptions",

0 commit comments

Comments
 (0)