Skip to content

Commit 1066276

Browse files
authored
fix: custom header support for invoke and dev commands (#652)
* fix: custom header support for invoke and dev commands Three fixes for custom header handling: 1. Allow empty requestHeaderAllowlist in add-agent flows: the TextInput for header allowlist now accepts empty input (allowEmpty) so users can press Enter to skip, matching the prompt text. 2. Add -H/--header flag to invoke command: repeatable flag that accepts "Name: Value" pairs, normalizes header names with the AgentCore prefix, and injects them via SDK middleware for deployed invocations. Works for HTTP, MCP, and A2A protocols in both CLI and TUI modes. 3. Add -H/--header flag to dev command: same flag format, threaded through all local dev invoke paths (streaming, non-streaming, MCP tool calls, A2A) in both non-interactive and TUI modes. * style: fix prettier formatting in 7 files
1 parent acd300d commit 1066276

File tree

17 files changed

+280
-56
lines changed

17 files changed

+280
-56
lines changed

src/cli/aws/agentcore.ts

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,34 @@ import {
66
InvokeAgentRuntimeCommand,
77
StopRuntimeSessionCommand,
88
} from '@aws-sdk/client-bedrock-agentcore';
9+
import type { HttpRequest } from '@smithy/protocol-http';
910
import type { DocumentType } from '@smithy/types';
1011

12+
/**
13+
* Create a BedrockAgentCoreClient with optional custom header injection middleware.
14+
*/
15+
function createAgentCoreClient(region: string, headers?: Record<string, string>): BedrockAgentCoreClient {
16+
const client = new BedrockAgentCoreClient({
17+
region,
18+
credentials: getCredentialProvider(),
19+
});
20+
21+
if (headers && Object.keys(headers).length > 0) {
22+
client.middlewareStack.add(
23+
next => async args => {
24+
const request = args.request as HttpRequest;
25+
for (const [name, value] of Object.entries(headers)) {
26+
request.headers[name] = value;
27+
}
28+
return next(args);
29+
},
30+
{ step: 'build', name: 'addCustomHeaders' }
31+
);
32+
}
33+
34+
return client;
35+
}
36+
1137
/** Logger interface for SSE events */
1238
export interface SSELogger {
1339
logSSEEvent(rawLine: string): void;
@@ -25,6 +51,8 @@ export interface InvokeAgentRuntimeOptions {
2551
userId?: string;
2652
/** Optional logger for SSE event debugging */
2753
logger?: SSELogger;
54+
/** Custom headers to forward to the agent runtime */
55+
headers?: Record<string, string>;
2856
}
2957

3058
export interface InvokeAgentRuntimeResult {
@@ -109,10 +137,7 @@ export function extractResult(text: string): string {
109137
* Returns an object with the stream generator and session ID.
110138
*/
111139
export async function invokeAgentRuntimeStreaming(options: InvokeAgentRuntimeOptions): Promise<StreamingInvokeResult> {
112-
const client = new BedrockAgentCoreClient({
113-
region: options.region,
114-
credentials: getCredentialProvider(),
115-
});
140+
const client = createAgentCoreClient(options.region, options.headers);
116141

117142
const command = new InvokeAgentRuntimeCommand({
118143
agentRuntimeArn: options.runtimeArn,
@@ -205,10 +230,7 @@ export async function invokeAgentRuntimeStreaming(options: InvokeAgentRuntimeOpt
205230
* Invoke an AgentCore Runtime and return the response.
206231
*/
207232
export async function invokeAgentRuntime(options: InvokeAgentRuntimeOptions): Promise<InvokeAgentRuntimeResult> {
208-
const client = new BedrockAgentCoreClient({
209-
region: options.region,
210-
credentials: getCredentialProvider(),
211-
});
233+
const client = createAgentCoreClient(options.region, options.headers);
212234

213235
const command = new InvokeAgentRuntimeCommand({
214236
agentRuntimeArn: options.runtimeArn,
@@ -349,6 +371,8 @@ export interface McpInvokeOptions {
349371
userId?: string;
350372
mcpSessionId?: string;
351373
logger?: SSELogger;
374+
/** Custom headers to forward to the agent runtime */
375+
headers?: Record<string, string>;
352376
}
353377

354378
export interface McpToolDef {
@@ -372,10 +396,7 @@ interface McpRpcResult {
372396

373397
/** Send a JSON-RPC payload through InvokeAgentRuntime and return the parsed response. */
374398
async function mcpRpcCall(options: McpInvokeOptions, body: Record<string, unknown>): Promise<McpRpcResult> {
375-
const client = new BedrockAgentCoreClient({
376-
region: options.region,
377-
credentials: getCredentialProvider(),
378-
});
399+
const client = createAgentCoreClient(options.region, options.headers);
379400

380401
options.logger?.logSSEEvent(`MCP request: ${JSON.stringify(body)}`);
381402

@@ -420,10 +441,7 @@ async function mcpRpcCallStrict(options: McpInvokeOptions, body: Record<string,
420441

421442
/** Send a JSON-RPC notification (no id, no response expected). */
422443
async function mcpRpcNotify(options: McpInvokeOptions, body: Record<string, unknown>): Promise<void> {
423-
const client = new BedrockAgentCoreClient({
424-
region: options.region,
425-
credentials: getCredentialProvider(),
426-
});
444+
const client = createAgentCoreClient(options.region, options.headers);
427445

428446
const command = new InvokeAgentRuntimeCommand({
429447
agentRuntimeArn: options.runtimeArn,
@@ -592,6 +610,8 @@ export interface A2AInvokeOptions {
592610
runtimeArn: string;
593611
userId?: string;
594612
logger?: SSELogger;
613+
/** Custom headers to forward to the agent runtime */
614+
headers?: Record<string, string>;
595615
}
596616

597617
let a2aRequestId = 1;
@@ -601,10 +621,7 @@ let a2aRequestId = 1;
601621
* Streams text parts from the response artifacts.
602622
*/
603623
export async function invokeA2ARuntime(options: A2AInvokeOptions, message: string): Promise<StreamingInvokeResult> {
604-
const client = new BedrockAgentCoreClient({
605-
region: options.region,
606-
credentials: getCredentialProvider(),
607-
});
624+
const client = createAgentCoreClient(options.region, options.headers);
608625

609626
const body = {
610627
jsonrpc: '2.0',

src/cli/commands/dev/command.tsx

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import { FatalError } from '../../tui/components';
2020
import { LayoutProvider } from '../../tui/context';
2121
import { COMMAND_DESCRIPTIONS } from '../../tui/copy';
2222
import { requireProject } from '../../tui/guards';
23+
import { parseHeaderFlags } from '../shared/header-utils';
2324
import type { Command } from '@commander-js/extra-typings';
2425
import { Text, render } from 'ink';
2526
import React from 'react';
@@ -29,16 +30,21 @@ const ENTER_ALT_SCREEN = '\x1B[?1049h\x1B[H';
2930
const EXIT_ALT_SCREEN = '\x1B[?1049l';
3031
const SHOW_CURSOR = '\x1B[?25h';
3132

32-
async function invokeDevServer(port: number, prompt: string, stream: boolean): Promise<void> {
33+
async function invokeDevServer(
34+
port: number,
35+
prompt: string,
36+
stream: boolean,
37+
headers?: Record<string, string>
38+
): Promise<void> {
3339
try {
3440
if (stream) {
3541
// Stream response to stdout
36-
for await (const chunk of invokeAgentStreaming(port, prompt)) {
42+
for await (const chunk of invokeAgentStreaming({ port, message: prompt, headers })) {
3743
process.stdout.write(chunk);
3844
}
3945
process.stdout.write('\n');
4046
} else {
41-
const response = await invokeAgent(port, prompt);
47+
const response = await invokeAgent({ port, message: prompt, headers });
4248
console.log(response);
4349
}
4450
} catch (err) {
@@ -52,9 +58,9 @@ async function invokeDevServer(port: number, prompt: string, stream: boolean): P
5258
}
5359
}
5460

55-
async function invokeA2ADevServer(port: number, prompt: string): Promise<void> {
61+
async function invokeA2ADevServer(port: number, prompt: string, headers?: Record<string, string>): Promise<void> {
5662
try {
57-
for await (const chunk of invokeForProtocol('A2A', { port, message: prompt })) {
63+
for await (const chunk of invokeForProtocol('A2A', { port, message: prompt, headers })) {
5864
process.stdout.write(chunk);
5965
}
6066
process.stdout.write('\n');
@@ -69,10 +75,16 @@ async function invokeA2ADevServer(port: number, prompt: string): Promise<void> {
6975
}
7076
}
7177

72-
async function handleMcpInvoke(port: number, invokeValue: string, toolName?: string, input?: string): Promise<void> {
78+
async function handleMcpInvoke(
79+
port: number,
80+
invokeValue: string,
81+
toolName?: string,
82+
input?: string,
83+
headers?: Record<string, string>
84+
): Promise<void> {
7385
try {
7486
if (invokeValue === 'list-tools') {
75-
const { tools } = await listMcpTools(port);
87+
const { tools } = await listMcpTools(port, undefined, headers);
7688
if (tools.length === 0) {
7789
console.log('No tools available.');
7890
return;
@@ -89,7 +101,7 @@ async function handleMcpInvoke(port: number, invokeValue: string, toolName?: str
89101
process.exit(1);
90102
}
91103
// Initialize session first, then call tool with the session ID
92-
const { sessionId } = await listMcpTools(port);
104+
const { sessionId } = await listMcpTools(port, undefined, headers);
93105
let args: Record<string, unknown> = {};
94106
if (input) {
95107
try {
@@ -100,7 +112,7 @@ async function handleMcpInvoke(port: number, invokeValue: string, toolName?: str
100112
process.exit(1);
101113
}
102114
}
103-
const result = await callMcpTool(port, toolName, args, sessionId);
115+
const result = await callMcpTool(port, toolName, args, sessionId, undefined, headers);
104116
console.log(result);
105117
} else {
106118
console.error(`Error: Unknown MCP invoke command "${invokeValue}"`);
@@ -132,10 +144,22 @@ export const registerDev = (program: Command) => {
132144
.option('-l, --logs', 'Run dev server with logs to stdout [non-interactive]')
133145
.option('--tool <name>', 'MCP tool name (used with --invoke call-tool)')
134146
.option('--input <json>', 'MCP tool arguments as JSON (used with --invoke call-tool)')
147+
.option(
148+
'-H, --header <header>',
149+
'Custom header to forward to the agent (format: "Name: Value", repeatable)',
150+
(val: string, prev: string[]) => [...prev, val],
151+
[] as string[]
152+
)
135153
.action(async opts => {
136154
try {
137155
const port = parseInt(opts.port, 10);
138156

157+
// Parse custom headers
158+
let headers: Record<string, string> | undefined;
159+
if (opts.header && opts.header.length > 0) {
160+
headers = parseHeaderFlags(opts.header);
161+
}
162+
139163
// If --invoke provided, call the dev server and exit
140164
if (opts.invoke) {
141165
const invokeProject = await loadProjectConfig(getWorkingDirectory());
@@ -166,11 +190,11 @@ export const registerDev = (program: Command) => {
166190

167191
// Protocol-aware dispatch
168192
if (protocol === 'MCP') {
169-
await handleMcpInvoke(invokePort, opts.invoke, opts.tool, opts.input);
193+
await handleMcpInvoke(invokePort, opts.invoke, opts.tool, opts.input, headers);
170194
} else if (protocol === 'A2A') {
171-
await invokeA2ADevServer(invokePort, opts.invoke);
195+
await invokeA2ADevServer(invokePort, opts.invoke, headers);
172196
} else {
173-
await invokeDevServer(invokePort, opts.invoke, opts.stream ?? false);
197+
await invokeDevServer(invokePort, opts.invoke, opts.stream ?? false, headers);
174198
}
175199
return;
176200
}
@@ -308,6 +332,7 @@ export const registerDev = (program: Command) => {
308332
workingDir={workingDir}
309333
port={port}
310334
agentName={opts.agent}
335+
headers={headers}
311336
/>
312337
</LayoutProvider>
313338
);

src/cli/commands/invoke/action.ts

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ export async function handleInvoke(context: InvokeContext, options: InvokeOption
9595
region: targetConfig.region,
9696
runtimeArn: agentState.runtimeArn,
9797
userId: options.userId,
98+
headers: options.headers,
9899
};
99100

100101
// list-tools: list available MCP tools
@@ -167,7 +168,12 @@ export async function handleInvoke(context: InvokeContext, options: InvokeOption
167168
if (agentSpec.protocol === 'A2A') {
168169
try {
169170
const a2aResult = await invokeA2ARuntime(
170-
{ region: targetConfig.region, runtimeArn: agentState.runtimeArn, userId: options.userId },
171+
{
172+
region: targetConfig.region,
173+
runtimeArn: agentState.runtimeArn,
174+
userId: options.userId,
175+
headers: options.headers,
176+
},
171177
options.prompt
172178
);
173179
let response = '';
@@ -214,6 +220,7 @@ export async function handleInvoke(context: InvokeContext, options: InvokeOption
214220
sessionId: options.sessionId,
215221
userId: options.userId,
216222
logger, // Pass logger for SSE event debugging
223+
headers: options.headers,
217224
});
218225

219226
for await (const chunk of result.stream) {
@@ -245,6 +252,7 @@ export async function handleInvoke(context: InvokeContext, options: InvokeOption
245252
payload: options.prompt,
246253
sessionId: options.sessionId,
247254
userId: options.userId,
255+
headers: options.headers,
248256
});
249257

250258
logger.logResponse(response.content);

src/cli/commands/invoke/command.tsx

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import { getErrorMessage } from '../../errors';
22
import { COMMAND_DESCRIPTIONS } from '../../tui/copy';
33
import { requireProject } from '../../tui/guards';
44
import { InvokeScreen } from '../../tui/screens/invoke';
5+
import { parseHeaderFlags } from '../shared/header-utils';
56
import { handleInvoke, loadInvokeConfig } from './action';
67
import type { InvokeOptions } from './types';
78
import { validateInvokeOptions } from './validate';
@@ -103,6 +104,12 @@ export const registerInvoke = (program: Command) => {
103104
.option('--stream', 'Stream response in real-time (TUI streams by default) [non-interactive]')
104105
.option('--tool <name>', 'MCP tool name (use with "call-tool" prompt) [non-interactive]')
105106
.option('--input <json>', 'MCP tool arguments as JSON (use with --tool) [non-interactive]')
107+
.option(
108+
'-H, --header <header>',
109+
'Custom header to forward to the agent (format: "Name: Value", repeatable)',
110+
(val: string, prev: string[]) => [...prev, val],
111+
[] as string[]
112+
)
106113
.action(
107114
async (
108115
positionalPrompt: string | undefined,
@@ -116,13 +123,20 @@ export const registerInvoke = (program: Command) => {
116123
stream?: boolean;
117124
tool?: string;
118125
input?: string;
126+
header?: string[];
119127
}
120128
) => {
121129
try {
122130
requireProject();
123131
// --prompt flag takes precedence over positional argument
124132
const prompt = cliOptions.prompt ?? positionalPrompt;
125133

134+
// Parse custom headers
135+
let headers: Record<string, string> | undefined;
136+
if (cliOptions.header && cliOptions.header.length > 0) {
137+
headers = parseHeaderFlags(cliOptions.header);
138+
}
139+
126140
// CLI mode if any CLI-specific options provided (follows deploy command pattern)
127141
if (
128142
prompt ||
@@ -142,15 +156,17 @@ export const registerInvoke = (program: Command) => {
142156
stream: cliOptions.stream,
143157
tool: cliOptions.tool,
144158
input: cliOptions.input,
159+
headers,
145160
});
146161
} else {
147-
// No CLI options - interactive TUI mode
162+
// No CLI options - interactive TUI mode (headers still passed if provided)
148163
const { waitUntilExit } = render(
149164
<InvokeScreen
150165
isInteractive={true}
151166
onExit={() => process.exit(0)}
152167
initialSessionId={cliOptions.sessionId}
153168
initialUserId={cliOptions.userId}
169+
initialHeaders={headers}
154170
/>
155171
);
156172
await waitUntilExit();

src/cli/commands/invoke/types.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ export interface InvokeOptions {
1010
tool?: string;
1111
/** MCP tool arguments as JSON string (used with --tool) */
1212
input?: string;
13+
/** Custom headers to forward to the agent runtime (key-value pairs) */
14+
headers?: Record<string, string>;
1315
}
1416

1517
export interface InvokeResult {

0 commit comments

Comments
 (0)