Skip to content

Commit 4555db8

Browse files
committed
fix: fix owner/user fields and auth for resources
1 parent 2889b7f commit 4555db8

6 files changed

Lines changed: 63 additions & 21 deletions

File tree

controllers/message_answer.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ func generateMessageAnswer(id string, responseWriter http.ResponseWriter, host s
309309
if questionMessage != nil {
310310
webSearchEnabled = questionMessage.WebSearchEnabled
311311
}
312-
mcpToolSet = object.MergeMcpTools(mcpToolSet, store, webSearchEnabled, lang)
312+
mcpToolSet = object.MergeMcpTools(mcpToolSet, store, webSearchEnabled, message.User, lang)
313313

314314
var knowledge []*model.RawMessage
315315
var vectorScores []object.VectorScore

controllers/resource.go

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,27 +41,39 @@ func (c *ApiController) GetGlobalResources() {
4141
sortField := c.Input().Get("sortField")
4242
sortOrder := c.Input().Get("sortOrder")
4343

44+
userName, ok := c.RequireSignedIn()
45+
if !ok {
46+
return
47+
}
48+
49+
filterUser := ""
50+
if !c.IsAdmin() {
51+
filterUser = userName
52+
}
53+
4454
if limit == "" || page == "" {
45-
resources, err := object.GetGlobalResources(owner)
55+
var resources []*object.Resource
56+
var err error
57+
if filterUser == "" {
58+
resources, err = object.GetGlobalResources(owner)
59+
} else {
60+
resources, err = object.GetResources(owner, filterUser)
61+
}
4662
if err != nil {
4763
c.ResponseError(err.Error())
4864
return
4965
}
5066
c.ResponseOk(resources)
5167
} else {
52-
if !c.RequireAdmin() {
53-
return
54-
}
55-
5668
limitInt := util.ParseInt(limit)
57-
count, err := object.GetResourceCount(owner, field, value)
69+
count, err := object.GetResourceCount(owner, filterUser, field, value)
5870
if err != nil {
5971
c.ResponseError(err.Error())
6072
return
6173
}
6274

6375
paginator := pagination.SetPaginator(c.Ctx, limitInt, count)
64-
resources, err := object.GetPaginationResources(owner, paginator.Offset(), limitInt, field, value, sortField, sortOrder)
76+
resources, err := object.GetPaginationResources(owner, filterUser, paginator.Offset(), limitInt, field, value, sortField, sortOrder)
6577
if err != nil {
6678
c.ResponseError(err.Error())
6779
return
@@ -81,12 +93,22 @@ func (c *ApiController) GetGlobalResources() {
8193
func (c *ApiController) GetResource() {
8294
id := c.Input().Get("id")
8395

96+
userName, ok := c.RequireSignedIn()
97+
if !ok {
98+
return
99+
}
100+
84101
resource, err := object.GetResource(id)
85102
if err != nil {
86103
c.ResponseError(err.Error())
87104
return
88105
}
89106

107+
if resource != nil && !c.IsAdmin() && resource.User != userName {
108+
c.ResponseError(c.T("auth:Unauthorized operation"))
109+
return
110+
}
111+
90112
c.ResponseOk(resource)
91113
}
92114

@@ -149,13 +171,23 @@ func (c *ApiController) AddResource() {
149171
// @Success 200 {object} controllers.Response The Response object
150172
// @router /delete-resource [post]
151173
func (c *ApiController) DeleteResource() {
174+
userName, ok := c.RequireSignedIn()
175+
if !ok {
176+
return
177+
}
178+
152179
var resource object.Resource
153180
err := json.NewDecoder(c.Ctx.Request.Body).Decode(&resource)
154181
if err != nil {
155182
c.ResponseError(err.Error())
156183
return
157184
}
158185

186+
if !c.IsAdmin() && resource.User != userName {
187+
c.ResponseError(c.T("auth:Unauthorized operation"))
188+
return
189+
}
190+
159191
err = object.DeleteResourceFile(&resource, c.GetAcceptLanguage())
160192
if err != nil {
161193
c.ResponseError(err.Error())

object/merge_agent_tools.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ import (
2020
"github.com/the-open-agent/openagent/util"
2121
)
2222

23-
func buildMergedBuiltinRegistry(store *Store, lang string) *tool.ToolRegistry {
23+
func buildMergedBuiltinRegistry(store *Store, user, lang string) *tool.ToolRegistry {
2424
reg := tool.NewToolRegistry()
2525

2626
if store == nil {
@@ -56,7 +56,7 @@ func buildMergedBuiltinRegistry(store *Store, lang string) *tool.ToolRegistry {
5656
}
5757
for _, bt := range tp.BuiltinTools() {
5858
wrapped := wrapSnapshotBuiltin(store.Owner, bt)
59-
wrapped = wrapGeneratedResourceBuiltin(wrapped)
59+
wrapped = wrapGeneratedResourceBuiltin(wrapped, store.Owner, user)
6060
reg.RegisterTool(wrapped)
6161
}
6262
}
@@ -66,15 +66,15 @@ func buildMergedBuiltinRegistry(store *Store, lang string) *tool.ToolRegistry {
6666

6767
// MergeMcpTools merges builtin tools (from the store's tool list) and the
6868
// web-search flag into an existing McpToolSet, creating one if needed.
69-
func MergeMcpTools(mcpToolSet *mcp.ToolSet, store *Store, webSearchEnabled bool, lang string) *mcp.ToolSet {
69+
func MergeMcpTools(mcpToolSet *mcp.ToolSet, store *Store, webSearchEnabled bool, user, lang string) *mcp.ToolSet {
7070
if webSearchEnabled {
7171
if mcpToolSet == nil {
7272
mcpToolSet = &mcp.ToolSet{}
7373
}
7474
mcpToolSet.WebSearchEnabled = true
7575
}
7676

77-
reg := buildMergedBuiltinRegistry(store, lang)
77+
reg := buildMergedBuiltinRegistry(store, user, lang)
7878
allTools := reg.GetToolsAsProtocolTools()
7979
if len(allTools) == 0 {
8080
return mcpToolSet

object/message_tool.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ func buildToolSetForBuiltinTool(toolName string, lang string) (*mcp.ToolSet, err
4545
reg := tool.NewToolRegistry()
4646
for _, t := range tp.BuiltinTools() {
4747
wrapped := wrapSnapshotBuiltin("admin", t)
48-
wrapped = wrapGeneratedResourceBuiltin(wrapped)
48+
wrapped = wrapGeneratedResourceBuiltin(wrapped, "admin", "")
4949
reg.RegisterTool(wrapped)
5050
}
5151

object/resource.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,14 +148,20 @@ func DeleteResource(resource *Resource) (bool, error) {
148148
return affected != 0, nil
149149
}
150150

151-
func GetResourceCount(owner, field, value string) (int64, error) {
151+
func GetResourceCount(owner, user, field, value string) (int64, error) {
152152
session := GetDbSession(owner, -1, -1, field, value, "", "")
153+
if user != "" {
154+
session = session.And("user = ?", user)
155+
}
153156
return session.Count(&Resource{})
154157
}
155158

156-
func GetPaginationResources(owner string, offset, limit int, field, value, sortField, sortOrder string) ([]*Resource, error) {
159+
func GetPaginationResources(owner, user string, offset, limit int, field, value, sortField, sortOrder string) ([]*Resource, error) {
157160
resources := []*Resource{}
158161
session := GetDbSession(owner, offset, limit, field, value, sortField, sortOrder)
162+
if user != "" {
163+
session = session.And("user = ?", user)
164+
}
159165
err := session.Find(&resources)
160166
if err != nil {
161167
return resources, err

object/resource_archive_tool.go

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,18 +29,22 @@ import (
2929

3030
type generatedResourceArchiveBuiltinTool struct {
3131
inner tool.BuiltinTool
32+
owner string
33+
user string
3234
}
3335

34-
var archiveGeneratedResourceFile = archiveGeneratedResourceFileToStorage
36+
var archiveGeneratedResourceFile = func(owner, user, path string) (*Resource, error) {
37+
return archiveGeneratedResourceFileToStorage(owner, user, path)
38+
}
3539

36-
func wrapGeneratedResourceBuiltin(builtin tool.BuiltinTool) tool.BuiltinTool {
40+
func wrapGeneratedResourceBuiltin(builtin tool.BuiltinTool, owner, user string) tool.BuiltinTool {
3741
if builtin == nil {
3842
return nil
3943
}
4044
if !isGeneratedResourceTool(builtin.GetName()) {
4145
return builtin
4246
}
43-
return &generatedResourceArchiveBuiltinTool{inner: builtin}
47+
return &generatedResourceArchiveBuiltinTool{inner: builtin, owner: owner, user: user}
4448
}
4549

4650
func isGeneratedResourceTool(toolName string) bool {
@@ -75,7 +79,7 @@ func (t *generatedResourceArchiveBuiltinTool) Execute(ctx context.Context, argum
7579
return result, innerErr
7680
}
7781

78-
resource, err := archiveGeneratedResourceFile(path)
82+
resource, err := archiveGeneratedResourceFile(t.owner, t.user, path)
7983
if err != nil {
8084
appendGeneratedResourceArchiveText(result, fmt.Sprintf("Resource archive warning: file was created but could not be saved to Resources: %s", err.Error()))
8185
return result, innerErr
@@ -107,7 +111,7 @@ func resourceArchiveStringArg(arguments map[string]interface{}, key string) stri
107111
return strings.TrimSpace(value)
108112
}
109113

110-
func archiveGeneratedResourceFileToStorage(path string) (*Resource, error) {
114+
func archiveGeneratedResourceFileToStorage(owner, user, path string) (*Resource, error) {
111115
info, err := os.Stat(path)
112116
if err != nil {
113117
return nil, err
@@ -139,7 +143,7 @@ func archiveGeneratedResourceFileToStorage(path string) (*Resource, error) {
139143
return nil, err
140144
}
141145

142-
resource := NewResourceFromUpload("admin", "", "generated", fileName, fileType, ext, fileUrl, storageName, len(fileBytes), "", "")
146+
resource := NewResourceFromUpload(owner, user, "generated", fileName, fileType, ext, fileUrl, storageName, len(fileBytes), "", "")
143147
if _, err = AddResource(resource); err != nil {
144148
return nil, err
145149
}

0 commit comments

Comments
 (0)