Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions pkg/plugins/nemo/request_guard.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ func (p *NemoRequestGuardPlugin) ProcessRequest(ctx context.Context, _ *framewor
// extractMessages returns user-supplied text as a message slice suitable for NeMo's
// OpenAI-compatible chat endpoint. It supports two payload formats:
//
// 1. OpenAI chat: top-level "messages" array → forwards all non-system messages.
// 1. OpenAI chat: top-level "messages" array → forwards all messages.
// 2. MCP JSON-RPC: {"jsonrpc":"2.0","params":{"arguments":{…}}} → concatenates
// all string argument values into a single user message.
//
Expand All @@ -148,25 +148,24 @@ func extractMessages(body map[string]any) ([]map[string]string, error) {
return nil, nil // not an inference request (e.g. API key management, model listing)
}

// extractOpenAIMessages parses an OpenAI-style "messages" value. System messages are
// filtered. TrustyAI does not support this role. All other roles are forwarded so NeMo can evaluate
// the full conversation context.
// extractOpenAIMessages parses an OpenAI-style "messages" value. All messages are forwarded
// so NeMo can evaluate the full conversation context.
func extractOpenAIMessages(raw any) ([]map[string]string, error) {
slice, ok := raw.([]any)
if !ok {
return nil, fmt.Errorf("messages is not an array")
}
var messages []map[string]string
if len(slice) == 0 {
return nil, nil
}
messages := make([]map[string]string, 0, len(slice))

for _, m := range slice {
msg, ok := m.(map[string]any)
if !ok {
continue
}
role, _ := msg["role"].(string)
if role == "system" {
Copy link
Copy Markdown
Contributor

@nirrozenbaum nirrozenbaum May 4, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would it be correct to just return “slice” without copying?
in current code we copy the whole messages array on every request.
so if we have a long multi turn chat, memory allocation for each turn is expensive.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I added a pre-allocation and an early return for empty slices to avoid unnecessary allocations.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

my question was actually about the messages themselves, not only about the case of empty slice.
my intention was - if we have something along these lines:

func extractOpenAIMessages(raw any) ([]map[string]string, error) {
	slice, ok := raw.([]any)
	if !ok {
		return nil, fmt.Errorf("messages is not an array")
	}

    if len(slice) == 0 {
		return nil, nil
	}

    return slice

would it work?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should work with some extra changes, but returning the slice directly means forwarding all original message fields to NeMo (e.g., tool_calls, function_call) instead of only role and content. NeMo ignores extra fields, but it sends more data over the wire for no reason. It also requires changing the return type from []map[string]string to []any and updating all callers and tests accordingly. If we want to go in that direction, it would be better as a follow-up PR.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure. I'm just thinking about the tradeoffs.
do we prefer allocating memory for every request and copying the whole chat? (that could be very large) or do we prefer sending the same request that we received over the wire to yet another http hop? (but at least we don't do both - allocate memory + sending http).

there are pros/cons for both sides.
my gut feeling says sending the original request as is would probably be "cheaper".
we can handle in a follow up, for sure.

continue
}
content, _ := msg["content"].(string)
messages = append(messages, map[string]string{"role": role, "content": content})
}
Expand Down
25 changes: 7 additions & 18 deletions pkg/plugins/nemo/request_guard_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -362,10 +362,9 @@ func TestNemoRequestGuardSendsCorrectPayloadMCP(t *testing.T) {
assert.Equal(t, "hello world", msg["content"])
}

// TestNemoRequestGuardForwardsAllNonSystemMessages verifies that user, assistant, and tool
// messages are forwarded to NeMo while system messages are filtered out (the /v1/guardrail/checks
// endpoint rejects the system role with status "error").
func TestNemoRequestGuardForwardsAllNonSystemMessages(t *testing.T) {
// TestNemoRequestGuardForwardsAllMessages verifies that all messages including system
// are forwarded to NeMo for evaluation.
func TestNemoRequestGuardForwardsAllMessages(t *testing.T) {
var capturedReq map[string]any
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.NoError(t, json.NewDecoder(r.Body).Decode(&capturedReq))
Expand All @@ -390,13 +389,13 @@ func TestNemoRequestGuardForwardsAllNonSystemMessages(t *testing.T) {

messages, ok := capturedReq["messages"].([]any)
require.True(t, ok, "messages should be an array")
require.Len(t, messages, 4, "system message must be filtered, 4 remaining forwarded")
require.Len(t, messages, 5, "all messages including system should be forwarded")

roles := make([]string, len(messages))
for i, m := range messages {
roles[i] = m.(map[string]any)["role"].(string)
}
assert.Equal(t, []string{"user", "assistant", "tool", "user"}, roles)
assert.Equal(t, []string{"system", "user", "assistant", "tool", "user"}, roles)
}

// TestNemoRequestGuardBaseURLTrailingSlash ensures a trailing slash in baseURL doesn't double up.
Expand Down Expand Up @@ -479,26 +478,16 @@ func TestExtractMessages(t *testing.T) {
},
},
{
name: "system message filtered out — /v1/guardrail/checks rejects system role",
name: "single system message",
body: map[string]any{
"messages": []any{
map[string]any{"role": "system", "content": "You are helpful"},
map[string]any{"role": "user", "content": "Hello"},
},
},
want: []map[string]string{
{"role": "user", "content": "Hello"},
{"role": "system", "content": "You are helpful"},
},
},
{
name: "system-only conversation — all filtered, returns nil",
body: map[string]any{
"messages": []any{
map[string]any{"role": "system", "content": "You are helpful"},
},
},
want: nil,
},
{
name: "tool message included — not filtered out",
body: map[string]any{
Expand Down
Loading