Skip to content

Commit 093c261

Browse files
authored
Merge pull request #38 from appbaseio/feat/plan-restriction
feat: Plan based access control
2 parents 8dc7226 + 513b891 commit 093c261

4 files changed

Lines changed: 151 additions & 5 deletions

File tree

middleware/validate/plan.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
package validate
2+
3+
import (
4+
"net/http"
5+
6+
"github.com/appbaseio/arc/middleware"
7+
"github.com/appbaseio/arc/util"
8+
)
9+
10+
// Plan returns a middleware that validates the user's plan.
11+
// For e.g `validate.Plan([]util.Plan{util.ArcBasic}),` restricts the route to arc-basic users.
12+
func Plan(restrictedPlans []util.Plan) middleware.Middleware {
13+
if util.ValidatedPlans(restrictedPlans) {
14+
return validPlan
15+
}
16+
return invalidPlan
17+
}
18+
19+
// Throws the payment required error
20+
func invalidPlan(h http.HandlerFunc) http.HandlerFunc {
21+
return func(w http.ResponseWriter, req *http.Request) {
22+
msg := "This feature is not available for the " + util.Tier.String() + " plan users."
23+
util.WriteBackError(w, msg, http.StatusPaymentRequired)
24+
}
25+
}
26+
27+
// Authorize to access the request
28+
func validPlan(h http.HandlerFunc) http.HandlerFunc {
29+
return func(w http.ResponseWriter, req *http.Request) {
30+
h(w, req)
31+
}
32+
}

model/category/category.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ const (
3131
Streams
3232
Rules
3333
Templates
34+
Suggestions
3435
)
3536

3637
// String is an implementation of Stringer interface that returns the string representation of category.Categories.
@@ -48,6 +49,7 @@ func (c Category) String() string {
4849
"streams",
4950
"rules",
5051
"templates",
52+
"suggestions",
5153
}[c]
5254
}
5355

@@ -83,6 +85,8 @@ func (c *Category) UnmarshalJSON(bytes []byte) error {
8385
*c = Rules
8486
case Templates.String():
8587
*c = Templates
88+
case Suggestions.String():
89+
*c = Suggestions
8690
default:
8791
return fmt.Errorf("invalid category encountered: %v", category)
8892
}
@@ -117,6 +121,8 @@ func (c Category) MarshalJSON() ([]byte, error) {
117121
category = Rules.String()
118122
case Templates:
119123
category = Templates.String()
124+
case Suggestions:
125+
category = Suggestions.String()
120126
default:
121127
return nil, fmt.Errorf("invalid category encountered: %v" + c.String())
122128
}

util/billing.go

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ var ACCAPI = "https://accapi.appbase.io/"
2121
// TimeValidity to be obtained from ACCAPI
2222
var TimeValidity int64
2323

24+
// Tier is the value of the user's plan
25+
var Tier *Plan
26+
2427
// MaxErrorTime before showing errors if invalid trial / plan in hours
2528
var MaxErrorTime int64 = 24 // in hrs
2629

@@ -65,7 +68,7 @@ type ArcInstanceDetails struct {
6568
TrialValidity int64 `json:"trial_validity"`
6669
ArcID string `json:"arc_id"`
6770
CreatedAt int64 `json:"created_at"`
68-
Tier string `json:"tier"`
71+
Tier *Plan `json:"tier"`
6972
TierValidity int64 `json:"tier_validity"`
7073
TimeValidity int64 `json:"time_validity"`
7174
Metadata map[string]interface{} `json:"metadata"`
@@ -88,6 +91,18 @@ func BillingMiddleware(next http.Handler) http.Handler {
8891
})
8992
}
9093

94+
// Returns the arc instance by ID
95+
func getArcInstanceByID(arcID string, arcInstances []ArcInstanceDetails) ArcInstanceDetails {
96+
var arcInstance ArcInstanceDetails
97+
for _, instance := range arcInstances {
98+
if instance.ArcID == arcID {
99+
arcInstance = instance
100+
break
101+
}
102+
}
103+
return arcInstance
104+
}
105+
91106
func getArcInstance(arcID string) (ArcInstance, error) {
92107
arcInstance := ArcInstance{}
93108
response := ArcInstanceResponse{}
@@ -109,8 +124,10 @@ func getArcInstance(arcID string) (ArcInstance, error) {
109124
}
110125
err = json.Unmarshal(body, &response)
111126
if len(response.ArcInstances) != 0 {
112-
arcInstance.SubscriptionID = response.ArcInstances[0].SubscriptionID
113-
TimeValidity = response.ArcInstances[0].TimeValidity
127+
arcInstanceByID := getArcInstanceByID(arcID, response.ArcInstances)
128+
arcInstance.SubscriptionID = arcInstanceByID.SubscriptionID
129+
TimeValidity = arcInstanceByID.TimeValidity
130+
Tier = arcInstanceByID.Tier
114131
} else {
115132
return arcInstance, errors.New("No valid instance found for the provided ARC_ID")
116133
}
@@ -148,7 +165,10 @@ func getArcClusterInstance(clusterID string) (ArcInstance, error) {
148165
return arcInstance, err
149166
}
150167
if len(response.ArcInstances) != 0 {
151-
arcInstance.SubscriptionID = response.ArcInstances[0].SubscriptionID
168+
arcInstanceByID := getArcInstanceByID(clusterID, response.ArcInstances)
169+
arcInstance.SubscriptionID = arcInstanceByID.SubscriptionID
170+
TimeValidity = arcInstanceByID.TimeValidity
171+
Tier = arcInstanceByID.Tier
152172
} else {
153173
return arcInstance, errors.New("No valid instance found for the provided CLUSTER_ID")
154174
}
@@ -307,7 +327,7 @@ func ReportHostedArcUsage() {
307327
usageBody.Quantity = NodeCount
308328
response, err1 := reportClusterUsageRequest(usageBody)
309329
if err1 != nil {
310-
log.Println("Please contact support@appbase.io with your ARC_ID or registered e-mail address. Usage is not getting reported: ", err1)
330+
log.Println("Please contact support@appbase.io with your CLUSTER_ID or registered e-mail address. Usage is not getting reported: ", err1)
311331
}
312332

313333
// TimeValidity = response.TimeValidity

util/plans.go

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
package util
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
)
7+
8+
// An enum having a list of valid plans
9+
type Plan int
10+
11+
const (
12+
ArcBasic Plan = iota
13+
ArcEnterprise
14+
HostedArcEnterprise
15+
ProductionFirst
16+
ProductionSecond
17+
ProductionThird
18+
)
19+
20+
// String is the implementation of Stringer interface that returns the string representation of Plan type.
21+
func (o Plan) String() string {
22+
return [...]string{
23+
"arc-basic",
24+
"arc-enterprise",
25+
"hosted-arc-enterprise",
26+
"2019-production-1",
27+
"2019-production-2",
28+
"2019-production-3",
29+
}[o]
30+
}
31+
32+
// UnmarshalJSON is the implementation of the Unmarshaler interface for unmarshaling Plan type.
33+
func (o *Plan) UnmarshalJSON(bytes []byte) error {
34+
var plan string
35+
err := json.Unmarshal(bytes, &plan)
36+
if err != nil {
37+
return err
38+
}
39+
switch plan {
40+
case ArcBasic.String():
41+
*o = ArcBasic
42+
case ArcEnterprise.String():
43+
*o = ArcEnterprise
44+
case HostedArcEnterprise.String():
45+
*o = HostedArcEnterprise
46+
case ProductionFirst.String():
47+
*o = ProductionFirst
48+
case ProductionSecond.String():
49+
*o = ProductionSecond
50+
case ProductionThird.String():
51+
*o = ProductionThird
52+
default:
53+
return fmt.Errorf("invalid plan encountered: %v", plan)
54+
}
55+
return nil
56+
}
57+
58+
// MarshalJSON is the implementation of the Marshaler interface for marshaling Plan type.
59+
func (o Plan) MarshalJSON() ([]byte, error) {
60+
var plan string
61+
switch o {
62+
case ArcBasic:
63+
plan = ArcBasic.String()
64+
case ArcEnterprise:
65+
plan = ArcEnterprise.String()
66+
case HostedArcEnterprise:
67+
plan = HostedArcEnterprise.String()
68+
case ProductionFirst:
69+
plan = ProductionFirst.String()
70+
case ProductionSecond:
71+
plan = ProductionSecond.String()
72+
case ProductionThird:
73+
plan = ProductionThird.String()
74+
default:
75+
return nil, fmt.Errorf("invalid plan encountered: %v", o)
76+
}
77+
return json.Marshal(plan)
78+
}
79+
80+
// A util function to validate the user's plan against the restricted plans
81+
func ValidatedPlans(restrictedPlans []Plan) bool {
82+
for _, restrictedPlan := range restrictedPlans {
83+
if Billing == "true" && Tier.String() == restrictedPlan.String() {
84+
return false
85+
}
86+
}
87+
return true
88+
}

0 commit comments

Comments
 (0)