Skip to content

Commit f7c9a28

Browse files
[API-49] Give correct access based on self id and grants (#63)
This PR fixes 3 issues with access: 1. Currently you will get a stream URL back even if you don't have access (oops!) 2. Access checker now grants access if authedUserId = myId 3. Access checker now grants access if authedWallet is a granted app/manager This feels a bit prop drill-y but I couldn't easily come up with something better. Would be great to get feedback.
1 parent 0c32664 commit f7c9a28

24 files changed

Lines changed: 378 additions & 137 deletions

api/auth_middleware.go

Lines changed: 41 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package api
22

33
import (
4+
"context"
45
"fmt"
56
"strings"
67

@@ -62,14 +63,14 @@ func (app *ApiServer) recoverAuthorityFromSignatureHeaders(c *fiber.Ctx) (int32,
6263
}
6364

6465
// Checks if authedWallet is authorized to act on behalf of userId
65-
func (app *ApiServer) isAuthorizedRequest(c *fiber.Ctx, userId int32, authedWallet string) bool {
66+
func (app *ApiServer) isAuthorizedRequest(ctx context.Context, userId int32, authedWallet string) bool {
6667
cacheKey := fmt.Sprintf("%d:%s", userId, authedWallet)
6768
if hit, ok := app.resolveGrantCache.Get(cacheKey); ok {
6869
return hit
6970
}
7071

7172
var isAuthorized bool
72-
err := app.pool.QueryRow(c.Context(), `
73+
err := app.pool.QueryRow(ctx, `
7374
SELECT EXISTS (
7475
SELECT 1
7576
FROM grants
@@ -99,37 +100,56 @@ func (app *ApiServer) getAuthedWallet(c *fiber.Ctx) string {
99100
}
100101

101102
// Middleware to set authedUserId and authedWallet in context
103+
// Returns a 403 if either
104+
// - the user is not authorized to act on behalf of "myId"
105+
// - the user is not authorized to act on behalf of "requestedWallet"
102106
func (app *ApiServer) authMiddleware(c *fiber.Ctx) error {
103107
userId, wallet := app.recoverAuthorityFromSignatureHeaders(c)
104108
c.Locals("authedUserId", userId)
105109
c.Locals("authedWallet", wallet)
106110

111+
myId := app.getMyId(c)
112+
requestedWallet := c.Params("wallet")
113+
114+
// Not authorized to act on behalf of myId
115+
if myId != 0 {
116+
if userId != myId && !app.isAuthorizedRequest(c.Context(), myId, wallet) {
117+
return fiber.NewError(
118+
fiber.StatusForbidden,
119+
fmt.Sprintf(
120+
"You are not authorized to make this request authedUserId=%d authedWallet=%s myId=%d",
121+
userId,
122+
wallet,
123+
myId,
124+
),
125+
)
126+
}
127+
}
128+
129+
// Not authorized to act on behalf of requestedWallet
130+
if requestedWallet != "" && wallet != "" {
131+
if !strings.EqualFold(requestedWallet, wallet) {
132+
return fiber.NewError(
133+
fiber.StatusForbidden,
134+
fmt.Sprintf(
135+
"You are not authorized to make this request authedUserId=%d authedWallet=%s requestedWallet=%s",
136+
userId,
137+
wallet,
138+
requestedWallet,
139+
),
140+
)
141+
}
142+
}
143+
107144
return c.Next()
108145
}
109146

110-
// Middleware that asserts the authedUserId is valid and is the same as the userId in
111-
// the request path or a managed user of the authedUserId
147+
// Middleware that asserts that there is an authedUserId
112148
func (app *ApiServer) requireAuthMiddleware(c *fiber.Ctx) error {
113149
authedUserId := app.getAuthedUserId(c)
114-
authedWallet := app.getAuthedWallet(c)
115-
myId := app.getMyId(c)
116-
wallet := c.Params("wallet")
117150
if authedUserId == 0 {
118151
return fiber.NewError(fiber.StatusUnauthorized, "You must be logged in to make this request")
119152
}
120153

121-
if myId != 0 && myId == authedUserId {
122-
return c.Next()
123-
}
124-
125-
if wallet != "" && strings.EqualFold(wallet, authedWallet) {
126-
return c.Next()
127-
}
128-
129-
if app.isAuthorizedRequest(c, myId, authedWallet) {
130-
return c.Next()
131-
}
132-
133-
msg := fmt.Sprintf("You are not authorized to make this request authedUserId=%d authedWallet=%s myId=%d wallet=%s", authedUserId, authedWallet, myId, wallet)
134-
return fiber.NewError(fiber.StatusForbidden, msg)
154+
return c.Next()
135155
}

api/auth_middleware_test.go

Lines changed: 54 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -29,52 +29,81 @@ func TestRecoverAuthorityFromSignatureHeaders(t *testing.T) {
2929
assert.Equal(t, "0x7d273271690538cf855e5b3002a0dd8c154bb060", wallet)
3030
}
3131

32-
func TestRequireAuthMiddleware(t *testing.T) {
33-
// Create a dummy endpoint to test the requireAuthMiddleware
32+
func TestAuthorized(t *testing.T) {
33+
// Create a dummy endpoint to test the authMiddleware
3434
testApp := fiber.New()
35-
testApp.Get("/", app.resolveMyIdMiddleware, app.authMiddleware, app.requireAuthMiddleware, func(c *fiber.Ctx) error {
35+
testApp.Get("/", app.resolveMyIdMiddleware, app.authMiddleware, func(c *fiber.Ctx) error {
36+
return c.SendStatus(fiber.StatusOK)
37+
})
38+
testApp.Get("/account/:wallet", app.resolveMyIdMiddleware, app.authMiddleware, func(c *fiber.Ctx) error {
3639
return c.SendStatus(fiber.StatusOK)
3740
})
38-
39-
// Unauthorized when no auth headers
40-
req1 := httptest.NewRequest("GET", "/", nil)
41-
res, err := testApp.Test(req1, -1)
42-
assert.NoError(t, err)
43-
assert.Equal(t, fiber.StatusUnauthorized, res.StatusCode)
4441

4542
// Forbidden when not authorized
46-
req2 := httptest.NewRequest("GET", "/?user_id=1", nil)
43+
req := httptest.NewRequest("GET", "/?user_id=7eP5n", nil)
4744
// wallet: 0x681c616ae836ceca1effe00bd07f2fdbf9a082bc
48-
req2.Header.Set("Encoded-Data-Message", "signature:1745543704165")
49-
req2.Header.Set("Encoded-Data-Signature", "0x4af765948dccd72026f1059a59c7a6a1172628255d7d387d1590c0fe43961c5908fc6011443805ca0dbd39156300c04dc21bbfa9adce50acea9ad29a7e2fde2a1b")
50-
res, err = testApp.Test(req2, -1)
45+
req.Header.Set("Encoded-Data-Message", "signature:1745543704165")
46+
req.Header.Set("Encoded-Data-Signature", "0x4af765948dccd72026f1059a59c7a6a1172628255d7d387d1590c0fe43961c5908fc6011443805ca0dbd39156300c04dc21bbfa9adce50acea9ad29a7e2fde2a1b")
47+
res, err := testApp.Test(req, -1)
5148
assert.NoError(t, err)
5249
assert.Equal(t, fiber.StatusForbidden, res.StatusCode)
5350

5451
// Forbidden when grant is revoked
55-
req3 := httptest.NewRequest("GET", "/?user_id=1", nil)
52+
req = httptest.NewRequest("GET", "/?user_id=7eP5n", nil)
5653
// wallet: 0xc451c1f8943b575158310552b41230c61844a1c1
57-
req3.Header.Set("Encoded-Data-Message", "signature:1745542789211")
58-
req3.Header.Set("Encoded-Data-Signature", "0xffd5f92c0d253c7222cd407cf3398fac664530ef968bd4435ea698ba1daee1d73353330848b65d212eeeaae9f41e177e49078c4efa1131e5e517090626f6dd961c")
59-
res, err = testApp.Test(req3, -1)
54+
req.Header.Set("Encoded-Data-Message", "signature:1745542789211")
55+
req.Header.Set("Encoded-Data-Signature", "0xffd5f92c0d253c7222cd407cf3398fac664530ef968bd4435ea698ba1daee1d73353330848b65d212eeeaae9f41e177e49078c4efa1131e5e517090626f6dd961c")
56+
res, err = testApp.Test(req, -1)
6057
assert.NoError(t, err)
6158
assert.Equal(t, fiber.StatusForbidden, res.StatusCode)
6259

6360
// Authorized when grant is approved
64-
req4 := httptest.NewRequest("GET", "/?user_id=1", nil)
61+
req = httptest.NewRequest("GET", "/?user_id=7eP5n", nil)
6562
// wallet: 0x5f1a372b28956c8363f8bc3a231a6e9e1186ead8
66-
req4.Header.Set("Encoded-Data-Message", "signature:1745544459796")
67-
req4.Header.Set("Encoded-Data-Signature", "0x1c9cb405d8437d28ff5596918551f7a45f981e81618d65ee10892313292a8c7a325af002231d115b28ca2d244b082abe1bde4a7d9610f8140d3738a9be5c4fd91b")
68-
res, err = testApp.Test(req4, -1)
63+
req.Header.Set("Encoded-Data-Message", "signature:1745544459796")
64+
req.Header.Set("Encoded-Data-Signature", "0x1c9cb405d8437d28ff5596918551f7a45f981e81618d65ee10892313292a8c7a325af002231d115b28ca2d244b082abe1bde4a7d9610f8140d3738a9be5c4fd91b")
65+
res, err = testApp.Test(req, -1)
6966
assert.NoError(t, err)
7067
assert.Equal(t, fiber.StatusOK, res.StatusCode)
7168

7269
// Authorized when own user
73-
req5 := httptest.NewRequest("GET", "/?user_id=1", nil)
70+
req = httptest.NewRequest("GET", "/?user_id=7eP5n", nil)
7471
// wallet: 0x7d273271690538cf855e5b3002a0dd8c154bb060
75-
req5.Header.Set("Encoded-Data-Message", "signature:1744763856446")
76-
req5.Header.Set("Encoded-Data-Signature", "0xbb202be3a7f3a0aa22c1458ef6a3f2f8360fb86791c7b137e8562df0707825c11fa1db01096efd2abc5e6613c4d1e8d4ae1e2b993abdd555fe270c1b17bff0d21c")
77-
res, err = testApp.Test(req5, -1)
72+
req.Header.Set("Encoded-Data-Message", "signature:1744763856446")
73+
req.Header.Set("Encoded-Data-Signature", "0xbb202be3a7f3a0aa22c1458ef6a3f2f8360fb86791c7b137e8562df0707825c11fa1db01096efd2abc5e6613c4d1e8d4ae1e2b993abdd555fe270c1b17bff0d21c")
74+
res, err = testApp.Test(req, -1)
75+
assert.NoError(t, err)
76+
assert.Equal(t, fiber.StatusOK, res.StatusCode)
77+
78+
// Forbidden when not authorized to act on behalf of requested wallet
79+
req = httptest.NewRequest("GET", "/account/0x111c616ae836ceca1effe00bd07f2fdbf9a082bc", nil)
80+
// wallet: 0x681c616ae836ceca1effe00bd07f2fdbf9a082bc
81+
req.Header.Set("Encoded-Data-Message", "signature:1745543704165")
82+
req.Header.Set("Encoded-Data-Signature", "0x4af765948dccd72026f1059a59c7a6a1172628255d7d387d1590c0fe43961c5908fc6011443805ca0dbd39156300c04dc21bbfa9adce50acea9ad29a7e2fde2a1b")
83+
res, err = testApp.Test(req, -1)
84+
assert.NoError(t, err)
85+
assert.Equal(t, fiber.StatusForbidden, res.StatusCode)
86+
87+
// Authorized when requesting wallet matches authed wallet
88+
req = httptest.NewRequest("GET", "/account/0x681c616ae836ceca1effe00bd07f2fdbf9a082bc", nil)
89+
// wallet: 0x681c616ae836ceca1effe00bd07f2fdbf9a082bc
90+
req.Header.Set("Encoded-Data-Message", "signature:1745543704165")
91+
req.Header.Set("Encoded-Data-Signature", "0x4af765948dccd72026f1059a59c7a6a1172628255d7d387d1590c0fe43961c5908fc6011443805ca0dbd39156300c04dc21bbfa9adce50acea9ad29a7e2fde2a1b")
92+
res, err = testApp.Test(req, -1)
7893
assert.NoError(t, err)
7994
assert.Equal(t, fiber.StatusOK, res.StatusCode)
8095
}
96+
97+
func TestRequireAuthMiddleware(t *testing.T) {
98+
// Create a dummy endpoint to test the requireAuthMiddleware
99+
testApp := fiber.New()
100+
testApp.Get("/", app.resolveMyIdMiddleware, app.authMiddleware, app.requireAuthMiddleware, func(c *fiber.Ctx) error {
101+
return c.SendStatus(fiber.StatusOK)
102+
})
103+
104+
// Unauthorized when no auth headers
105+
req1 := httptest.NewRequest("GET", "/", nil)
106+
res, err := testApp.Test(req1, -1)
107+
assert.NoError(t, err)
108+
assert.Equal(t, fiber.StatusUnauthorized, res.StatusCode)
109+
}

api/dbv1/access.go

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,33 @@ type Access struct {
99
Download bool `json:"download"`
1010
}
1111

12-
func (q *Queries) GetTrackAccess(ctx context.Context, myId int32, conditions *AccessGate, track *GetTracksRow, user *FullUser) bool {
12+
func (q *Queries) GetTrackAccess(
13+
ctx context.Context,
14+
myId int32,
15+
conditions *AccessGate,
16+
track *GetTracksRow,
17+
user *FullUser,
18+
) bool {
19+
// No track? no access
1320
if track == nil || user == nil {
1421
return false
1522
}
1623

17-
// no conditions means open access
24+
// No conditions means open access
1825
if conditions == nil {
1926
return true
2027
}
2128

29+
// No myId? no access. we need to know who you are if there are conditions.
30+
if myId == 0 {
31+
return false
32+
}
33+
34+
// You always have access to your own content
35+
if myId == user.UserID {
36+
return true
37+
}
38+
2239
switch {
2340
case conditions.FollowUserID != nil:
2441
return user.DoesCurrentUserFollow
@@ -114,11 +131,28 @@ func (q *Queries) GetTrackAccess(ctx context.Context, myId int32, conditions *Ac
114131
return false
115132
}
116133

117-
func (q *Queries) GetPlaylistAccess(ctx context.Context, myId int32, conditions *AccessGate, playlist *GetPlaylistsRow, user *FullUser) bool {
134+
func (q *Queries) GetPlaylistAccess(
135+
ctx context.Context,
136+
myId int32,
137+
conditions *AccessGate,
138+
playlist *GetPlaylistsRow,
139+
user *FullUser,
140+
) bool {
141+
// No playlist? no access.
142+
if playlist == nil || user == nil {
143+
return false
144+
}
145+
146+
// no conditions means open access
118147
if conditions == nil {
119148
return true
120149
}
121150

151+
// I always have access to my own content
152+
if myId != 0 && myId == user.UserID {
153+
return true
154+
}
155+
122156
switch {
123157
case conditions.FollowUserID != nil:
124158
return user.DoesCurrentUserFollow

api/dbv1/full_playlists.go

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ import (
88
"github.com/jackc/pgx/v5/pgtype"
99
)
1010

11+
type FullPlaylistsParams struct {
12+
GetPlaylistsParams
13+
}
14+
1115
type FullPlaylist struct {
1216
GetPlaylistsRow
1317

@@ -31,8 +35,8 @@ type FullPlaylistContentsItem struct {
3135
MetadataTime int64 `json:"metadata_timestamp"`
3236
}
3337

34-
func (q *Queries) FullPlaylistsKeyed(ctx context.Context, arg GetPlaylistsParams) (map[int32]FullPlaylist, error) {
35-
rawPlaylists, err := q.GetPlaylists(ctx, arg)
38+
func (q *Queries) FullPlaylistsKeyed(ctx context.Context, arg FullPlaylistsParams) (map[int32]FullPlaylist, error) {
39+
rawPlaylists, err := q.GetPlaylists(ctx, arg.GetPlaylistsParams)
3640
if err != nil {
3741
return nil, err
3842
}
@@ -51,7 +55,7 @@ func (q *Queries) FullPlaylistsKeyed(ctx context.Context, arg GetPlaylistsParams
5155
loaded, err := q.Parallel(ctx, ParallelParams{
5256
UserIds: userIds,
5357
TrackIds: trackIds,
54-
MyID: arg.MyID,
58+
MyID: arg.MyID.(int32),
5559
})
5660
if err != nil {
5761
return nil, err
@@ -88,7 +92,12 @@ func (q *Queries) FullPlaylistsKeyed(ctx context.Context, arg GetPlaylistsParams
8892
}
8993

9094
// For playlists, download access is the same as stream access
91-
streamAccess := q.GetPlaylistAccess(ctx, arg.MyID.(int32), playlist.StreamConditions, &playlist, &user)
95+
streamAccess := q.GetPlaylistAccess(
96+
ctx,
97+
arg.MyID.(int32),
98+
playlist.StreamConditions,
99+
&playlist,
100+
&user)
92101
downloadAccess := streamAccess
93102

94103
var playlistType string
@@ -120,7 +129,7 @@ func (q *Queries) FullPlaylistsKeyed(ctx context.Context, arg GetPlaylistsParams
120129
return playlistMap, nil
121130
}
122131

123-
func (q *Queries) FullPlaylists(ctx context.Context, arg GetPlaylistsParams) ([]FullPlaylist, error) {
132+
func (q *Queries) FullPlaylists(ctx context.Context, arg FullPlaylistsParams) ([]FullPlaylist, error) {
124133
playlistMap, err := q.FullPlaylistsKeyed(ctx, arg)
125134
if err != nil {
126135
return nil, err

0 commit comments

Comments
 (0)