Skip to content

Commit 655f992

Browse files
committed
fix(gateway,context): 修复 ACL 方法名互换和 nil source 保护
- 修复 handleSessionAssetRead 中 ACL 检查使用了 sessionAssetDeleteMethod 的 bug - 修复 handleSessionAssetDelete 中 ACL 检查使用了 sessionAssetReadMethod 的 bug - 为 newStablePromptSources 添加 nil source 过滤,防止 collectPromptSections panic - 补充 ACL 独立生效测试(read-only 和 delete-only ACL 场景) - 补充 nil source 和 mixed nil/valid source 测试 - 补充 multi_workspace Delete 不支持 SessionAssetPort 的错误返回测试 - 补充图片投影空消息和混合消息的边界测试
1 parent f5700b3 commit 655f992

6 files changed

Lines changed: 223 additions & 5 deletions

File tree

internal/context/builder.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ func newStablePromptSources(extra ...SectionSource) []promptSectionSource {
1919
newRulesPromptSource(nil),
2020
}
2121
for _, src := range extra {
22-
sources = append(sources, src)
22+
if src != nil {
23+
sources = append(sources, src)
24+
}
2325
}
2426
return sources
2527
}

internal/context/builder_test.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -645,6 +645,40 @@ func TestNewConfiguredBuilder(t *testing.T) {
645645
}
646646
})
647647

648+
t.Run("nil extra source is safely ignored", func(t *testing.T) {
649+
t.Parallel()
650+
builder := NewConfiguredBuilder(nil)
651+
input := BuildInput{
652+
Messages: []providertypes.Message{{Role: "user", Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}}},
653+
Metadata: testMetadata(t.TempDir()),
654+
}
655+
result, err := builder.Build(stdcontext.Background(), input)
656+
if err != nil {
657+
t.Fatalf("Build() with nil source error = %v", err)
658+
}
659+
if result.SystemPrompt == "" {
660+
t.Fatal("expected non-empty system prompt even with nil extra source")
661+
}
662+
})
663+
664+
t.Run("mixed nil and valid extra sources", func(t *testing.T) {
665+
t.Parallel()
666+
builder := NewConfiguredBuilder(nil, stubPromptSectionSource{
667+
sections: []promptSection{{Title: "Valid", Content: "valid section"}},
668+
}, nil)
669+
input := BuildInput{
670+
Messages: []providertypes.Message{{Role: "user", Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}}},
671+
Metadata: testMetadata(t.TempDir()),
672+
}
673+
result, err := builder.Build(stdcontext.Background(), input)
674+
if err != nil {
675+
t.Fatalf("Build() with mixed nil/valid sources error = %v", err)
676+
}
677+
if !strings.Contains(result.SystemPrompt, "## Valid") {
678+
t.Fatal("expected valid section to be present while nil sources are ignored")
679+
}
680+
})
681+
648682
t.Run("multiple extra section sources are appended", func(t *testing.T) {
649683
builder := NewConfiguredBuilder(stubPromptSectionSource{
650684
sections: []promptSection{{Title: "First", Content: "first body"}},

internal/gateway/multi_workspace_runtime_test.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -864,6 +864,36 @@ var _ SessionAssetPort = (*MultiWorkspaceRuntime)(nil)
864864
var _ ManagementRuntimePort = (*MultiWorkspaceRuntime)(nil)
865865
var _ PlanApprovalRuntimePort = (*MultiWorkspaceRuntime)(nil)
866866

867+
// recordingPortWithoutSessionAsset 嵌套 recordingPort 但不实现 SessionAssetPort,
868+
// 用于验证 MultiWorkspaceRuntime 在底层 runtime 不支持附件时的降级处理。
869+
type recordingPortWithoutSessionAsset struct{ *recordingPort }
870+
871+
func TestMultiWorkspaceRuntime_DeleteSessionAssetUnsupportedRuntime(t *testing.T) {
872+
idx, alpha, _ := setupIndex(t)
873+
builder := newTestBuilder()
874+
// 将 alpha 的 port 替换为不支持 SessionAssetPort 的版本
875+
alphaPort := newRecordingPort("alpha-no-asset")
876+
builder.ports[alpha.Path] = alphaPort
877+
mw := NewMultiWorkspaceRuntime(idx, alpha.Hash, func(ctx context.Context, workdir string) (RuntimePort, func() error, error) {
878+
port, cleanup, err := builder.build(ctx, workdir)
879+
if err != nil {
880+
return nil, nil, err
881+
}
882+
rp := port.(*recordingPort)
883+
return &recordingPortWithoutSessionAsset{rp}, cleanup, nil
884+
})
885+
t.Cleanup(func() { _ = mw.Close() })
886+
887+
alphaCtx := ctxWithHash(t, alpha.Hash)
888+
err := mw.DeleteSessionAsset(alphaCtx, DeleteSessionAssetInput{SessionID: "s-1", AssetID: "a-1"})
889+
if err == nil {
890+
t.Fatal("expected error when runtime does not implement SessionAssetPort")
891+
}
892+
if !errors.Is(err, ErrRuntimeUnavailable) {
893+
t.Fatalf("error = %v, want ErrRuntimeUnavailable", err)
894+
}
895+
}
896+
867897
// guard helper: ensure recordingPort builds correctly under sync access.
868898
func TestRecordingPort_Concurrent(t *testing.T) {
869899
p := newRecordingPort("c")

internal/gateway/network_server.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -496,8 +496,8 @@ func (s *NetworkServer) handleSessionAssetRead(writer http.ResponseWriter, reque
496496
http.Error(writer, "unauthorized", http.StatusUnauthorized)
497497
return
498498
}
499-
if !s.isHTTPControlPlaneMethodAllowed(sessionAssetDeleteMethod) {
500-
s.writeHTTPAccessDenied(writer, sessionAssetDeleteMethod)
499+
if !s.isHTTPControlPlaneMethodAllowed(sessionAssetReadMethod) {
500+
s.writeHTTPAccessDenied(writer, sessionAssetReadMethod)
501501
return
502502
}
503503
assetPort, ok := runtimePort.(SessionAssetPort)
@@ -539,8 +539,8 @@ func (s *NetworkServer) handleSessionAssetDelete(writer http.ResponseWriter, req
539539
http.Error(writer, "unauthorized", http.StatusUnauthorized)
540540
return
541541
}
542-
if !s.isHTTPControlPlaneMethodAllowed(sessionAssetReadMethod) {
543-
s.writeHTTPAccessDenied(writer, sessionAssetReadMethod)
542+
if !s.isHTTPControlPlaneMethodAllowed(sessionAssetDeleteMethod) {
543+
s.writeHTTPAccessDenied(writer, sessionAssetDeleteMethod)
544544
return
545545
}
546546
assetPort, ok := runtimePort.(SessionAssetPort)

internal/gateway/network_server_test.go

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -747,6 +747,95 @@ func TestNetworkServerSessionAssetDeleteMissingIsIdempotent(t *testing.T) {
747747
}
748748
}
749749

750+
// TestNetworkServerSessionAssetACLIndependent 验证 GET 和 DELETE 的 ACL 检查相互独立:
751+
// 只允许 read 时 GET 通过但 DELETE 被拒;只允许 delete 时 DELETE 通过但 GET 被拒。
752+
func TestNetworkServerSessionAssetACLIndependent(t *testing.T) {
753+
t.Run("read allowed delete denied", func(t *testing.T) {
754+
readOnlyACL := &ControlPlaneACL{
755+
mode: ACLModeStrict,
756+
allow: map[RequestSource]map[string]struct{}{RequestSourceHTTP: {sessionAssetReadMethod: {}}},
757+
enabled: true,
758+
}
759+
runtimePort := &runtimePortEventStub{
760+
openAssetFn: func(context.Context, OpenSessionAssetInput) (OpenSessionAssetResult, error) {
761+
return OpenSessionAssetResult{
762+
Reader: io.NopCloser(bytes.NewReader(gatewayMinimalPNGBytes())),
763+
Meta: SessionAssetMeta{SessionID: "session-1", AssetID: "asset-1", MimeType: "image/png"},
764+
}, nil
765+
},
766+
deleteAssetFn: func(context.Context, DeleteSessionAssetInput) error {
767+
t.Fatal("DeleteSessionAsset should not be called when ACL denies delete")
768+
return nil
769+
},
770+
}
771+
server := &NetworkServer{
772+
authenticator: staticTokenAuthenticator{token: "gateway-token"},
773+
acl: readOnlyACL,
774+
metrics: NewGatewayMetrics(),
775+
}
776+
handler := server.buildHandler(runtimePort)
777+
778+
// GET should succeed (read allowed)
779+
readRequest := httptest.NewRequest(http.MethodGet, "/api/session-assets/session-1/asset-1", nil)
780+
readRequest.Header.Set("Authorization", "Bearer gateway-token")
781+
readRecorder := httptest.NewRecorder()
782+
handler.ServeHTTP(readRecorder, readRequest)
783+
if readRecorder.Code != http.StatusOK {
784+
t.Fatalf("read status = %d, want %d", readRecorder.Code, http.StatusOK)
785+
}
786+
787+
// DELETE should be forbidden (delete denied)
788+
deleteRequest := httptest.NewRequest(http.MethodDelete, "/api/session-assets/session-1/asset-1", nil)
789+
deleteRequest.Header.Set("Authorization", "Bearer gateway-token")
790+
deleteRecorder := httptest.NewRecorder()
791+
handler.ServeHTTP(deleteRecorder, deleteRequest)
792+
if deleteRecorder.Code != http.StatusForbidden {
793+
t.Fatalf("delete status = %d, want %d", deleteRecorder.Code, http.StatusForbidden)
794+
}
795+
})
796+
797+
t.Run("delete allowed read denied", func(t *testing.T) {
798+
deleteOnlyACL := &ControlPlaneACL{
799+
mode: ACLModeStrict,
800+
allow: map[RequestSource]map[string]struct{}{RequestSourceHTTP: {sessionAssetDeleteMethod: {}}},
801+
enabled: true,
802+
}
803+
runtimePort := &runtimePortEventStub{
804+
openAssetFn: func(context.Context, OpenSessionAssetInput) (OpenSessionAssetResult, error) {
805+
t.Fatal("OpenSessionAsset should not be called when ACL denies read")
806+
return OpenSessionAssetResult{}, nil
807+
},
808+
deleteAssetFn: func(context.Context, DeleteSessionAssetInput) error {
809+
return nil
810+
},
811+
}
812+
server := &NetworkServer{
813+
authenticator: staticTokenAuthenticator{token: "gateway-token"},
814+
acl: deleteOnlyACL,
815+
metrics: NewGatewayMetrics(),
816+
}
817+
handler := server.buildHandler(runtimePort)
818+
819+
// DELETE should succeed (delete allowed)
820+
deleteRequest := httptest.NewRequest(http.MethodDelete, "/api/session-assets/session-1/asset-1", nil)
821+
deleteRequest.Header.Set("Authorization", "Bearer gateway-token")
822+
deleteRecorder := httptest.NewRecorder()
823+
handler.ServeHTTP(deleteRecorder, deleteRequest)
824+
if deleteRecorder.Code != http.StatusOK {
825+
t.Fatalf("delete status = %d, want %d", deleteRecorder.Code, http.StatusOK)
826+
}
827+
828+
// GET should be forbidden (read denied)
829+
readRequest := httptest.NewRequest(http.MethodGet, "/api/session-assets/session-1/asset-1", nil)
830+
readRequest.Header.Set("Authorization", "Bearer gateway-token")
831+
readRecorder := httptest.NewRecorder()
832+
handler.ServeHTTP(readRecorder, readRequest)
833+
if readRecorder.Code != http.StatusForbidden {
834+
t.Fatalf("read status = %d, want %d", readRecorder.Code, http.StatusForbidden)
835+
}
836+
})
837+
}
838+
750839
func TestNetworkServerSessionAssetsRequireAssetPort(t *testing.T) {
751840
runtimePort := &runtimePortWithoutSessionAsset{RuntimePort: &runtimePortEventStub{}}
752841
server := &NetworkServer{authenticator: staticTokenAuthenticator{token: "gateway-token"}}

internal/runtime/runtime_test.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4932,6 +4932,69 @@ func TestProjectImagesForModelRequestCapabilityStates(t *testing.T) {
49324932
}
49334933
}
49344934

4935+
// TestProjectImagesForModelRequestEmptyMessages 验证空消息列表不会 panic 并正确返回空结果。
4936+
func TestProjectImagesForModelRequestEmptyMessages(t *testing.T) {
4937+
t.Parallel()
4938+
models := []providertypes.ModelDescriptor{{ID: "model-a", CapabilityHints: providertypes.ModelCapabilityHints{
4939+
ImageInput: providertypes.ModelCapabilityStateUnsupported,
4940+
}}}
4941+
projected, err := projectImagesForModelRequest("model-a", models, nil, []providertypes.ContentPart{providertypes.NewTextPart("hello")})
4942+
if err != nil {
4943+
t.Fatalf("unexpected error: %v", err)
4944+
}
4945+
if len(projected) != 0 {
4946+
t.Fatalf("expected empty projected messages, got %d", len(projected))
4947+
}
4948+
}
4949+
4950+
// TestProjectImagesForModelRequestMixedMessages 验证混合图片和文本的消息中只有图片被投影降级。
4951+
func TestProjectImagesForModelRequestMixedMessages(t *testing.T) {
4952+
t.Parallel()
4953+
messages := []providertypes.Message{
4954+
{
4955+
Role: providertypes.RoleUser,
4956+
Parts: []providertypes.ContentPart{providertypes.NewTextPart("text before image")},
4957+
},
4958+
{
4959+
Role: providertypes.RoleUser,
4960+
Parts: []providertypes.ContentPart{providertypes.NewSessionAssetImagePart("asset-1", "image/png")},
4961+
},
4962+
{
4963+
Role: providertypes.RoleAssistant,
4964+
Parts: []providertypes.ContentPart{providertypes.NewTextPart("response")},
4965+
},
4966+
}
4967+
models := []providertypes.ModelDescriptor{{ID: "model-a", CapabilityHints: providertypes.ModelCapabilityHints{
4968+
ImageInput: providertypes.ModelCapabilityStateUnsupported,
4969+
}}}
4970+
projected, err := projectImagesForModelRequest("model-a", models, messages, []providertypes.ContentPart{providertypes.NewTextPart("current text")})
4971+
if err != nil {
4972+
t.Fatalf("unexpected error: %v", err)
4973+
}
4974+
if len(projected) != 3 {
4975+
t.Fatalf("expected 3 messages, got %d", len(projected))
4976+
}
4977+
// 第一条消息应保持纯文本不变
4978+
if projected[0].Parts[0].Kind != providertypes.ContentPartText {
4979+
t.Fatalf("first message should be text, got %s", projected[0].Parts[0].Kind)
4980+
}
4981+
// 第二条消息中的图片应被替换为占位文本
4982+
if messagesContainImages(projected) {
4983+
t.Fatal("expected no images in projected messages for unsupported model")
4984+
}
4985+
if !strings.Contains(renderPartsForTest(projected[1].Parts), historicalImageOmittedForModel) {
4986+
t.Fatal("expected historical image omitted placeholder in second message")
4987+
}
4988+
// 第三条消息应保持不变
4989+
if projected[2].Parts[0].Kind != providertypes.ContentPartText {
4990+
t.Fatalf("third message should be text, got %s", projected[2].Parts[0].Kind)
4991+
}
4992+
// 原始消息不应被修改
4993+
if !messagesContainImages(messages) {
4994+
t.Fatal("original messages should still contain images")
4995+
}
4996+
}
4997+
49354998
func newRuntimeConfigManager(t *testing.T) *config.Manager {
49364999
return newRuntimeConfigManagerWithProviderEnvs(t, nil)
49375000
}

0 commit comments

Comments
 (0)