Skip to content

Commit 24ca2f5

Browse files
committed
fix: custom header support for invoke and dev commands
Three fixes for custom header handling: 1. Allow empty requestHeaderAllowlist in add-agent flows: the TextInput for header allowlist now accepts empty input (allowEmpty) so users can press Enter to skip, matching the prompt text. 2. Add -H/--header flag to invoke command: repeatable flag that accepts "Name: Value" pairs, normalizes header names with the AgentCore prefix, and injects them via SDK middleware for deployed invocations. Works for HTTP, MCP, and A2A protocols in both CLI and TUI modes. 3. Add -H/--header flag to dev command: same flag format, threaded through all local dev invoke paths (streaming, non-streaming, MCP tool calls, A2A) in both non-interactive and TUI modes.
1 parent 8a1af21 commit 24ca2f5

17 files changed

Lines changed: 227 additions & 56 deletions

File tree

src/cli/aws/agentcore.ts

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,34 @@ import {
66
InvokeAgentRuntimeCommand,
77
StopRuntimeSessionCommand,
88
} from '@aws-sdk/client-bedrock-agentcore';
9+
import type { HttpRequest } from '@smithy/protocol-http';
910
import type { DocumentType } from '@smithy/types';
1011

12+
/**
13+
* Create a BedrockAgentCoreClient with optional custom header injection middleware.
14+
*/
15+
function createAgentCoreClient(region: string, headers?: Record<string, string>): BedrockAgentCoreClient {
16+
const client = new BedrockAgentCoreClient({
17+
region,
18+
credentials: getCredentialProvider(),
19+
});
20+
21+
if (headers && Object.keys(headers).length > 0) {
22+
client.middlewareStack.add(
23+
next => async args => {
24+
const request = args.request as HttpRequest;
25+
for (const [name, value] of Object.entries(headers)) {
26+
request.headers[name] = value;
27+
}
28+
return next(args);
29+
},
30+
{ step: 'build', name: 'addCustomHeaders' }
31+
);
32+
}
33+
34+
return client;
35+
}
36+
1137
/** Logger interface for SSE events */
1238
export interface SSELogger {
1339
logSSEEvent(rawLine: string): void;
@@ -25,6 +51,8 @@ export interface InvokeAgentRuntimeOptions {
2551
userId?: string;
2652
/** Optional logger for SSE event debugging */
2753
logger?: SSELogger;
54+
/** Custom headers to forward to the agent runtime */
55+
headers?: Record<string, string>;
2856
}
2957

3058
export interface InvokeAgentRuntimeResult {
@@ -109,10 +137,7 @@ export function extractResult(text: string): string {
109137
* Returns an object with the stream generator and session ID.
110138
*/
111139
export async function invokeAgentRuntimeStreaming(options: InvokeAgentRuntimeOptions): Promise<StreamingInvokeResult> {
112-
const client = new BedrockAgentCoreClient({
113-
region: options.region,
114-
credentials: getCredentialProvider(),
115-
});
140+
const client = createAgentCoreClient(options.region, options.headers);
116141

117142
const command = new InvokeAgentRuntimeCommand({
118143
agentRuntimeArn: options.runtimeArn,
@@ -205,10 +230,7 @@ export async function invokeAgentRuntimeStreaming(options: InvokeAgentRuntimeOpt
205230
* Invoke an AgentCore Runtime and return the response.
206231
*/
207232
export async function invokeAgentRuntime(options: InvokeAgentRuntimeOptions): Promise<InvokeAgentRuntimeResult> {
208-
const client = new BedrockAgentCoreClient({
209-
region: options.region,
210-
credentials: getCredentialProvider(),
211-
});
233+
const client = createAgentCoreClient(options.region, options.headers);
212234

213235
const command = new InvokeAgentRuntimeCommand({
214236
agentRuntimeArn: options.runtimeArn,
@@ -349,6 +371,8 @@ export interface McpInvokeOptions {
349371
userId?: string;
350372
mcpSessionId?: string;
351373
logger?: SSELogger;
374+
/** Custom headers to forward to the agent runtime */
375+
headers?: Record<string, string>;
352376
}
353377

354378
export interface McpToolDef {
@@ -372,10 +396,7 @@ interface McpRpcResult {
372396

373397
/** Send a JSON-RPC payload through InvokeAgentRuntime and return the parsed response. */
374398
async function mcpRpcCall(options: McpInvokeOptions, body: Record<string, unknown>): Promise<McpRpcResult> {
375-
const client = new BedrockAgentCoreClient({
376-
region: options.region,
377-
credentials: getCredentialProvider(),
378-
});
399+
const client = createAgentCoreClient(options.region, options.headers);
379400

380401
options.logger?.logSSEEvent(`MCP request: ${JSON.stringify(body)}`);
381402

@@ -420,10 +441,7 @@ async function mcpRpcCallStrict(options: McpInvokeOptions, body: Record<string,
420441

421442
/** Send a JSON-RPC notification (no id, no response expected). */
422443
async function mcpRpcNotify(options: McpInvokeOptions, body: Record<string, unknown>): Promise<void> {
423-
const client = new BedrockAgentCoreClient({
424-
region: options.region,
425-
credentials: getCredentialProvider(),
426-
});
444+
const client = createAgentCoreClient(options.region, options.headers);
427445

428446
const command = new InvokeAgentRuntimeCommand({
429447
agentRuntimeArn: options.runtimeArn,
@@ -592,6 +610,8 @@ export interface A2AInvokeOptions {
592610
runtimeArn: string;
593611
userId?: string;
594612
logger?: SSELogger;
613+
/** Custom headers to forward to the agent runtime */
614+
headers?: Record<string, string>;
595615
}
596616

597617
let a2aRequestId = 1;
@@ -601,10 +621,7 @@ let a2aRequestId = 1;
601621
* Streams text parts from the response artifacts.
602622
*/
603623
export async function invokeA2ARuntime(options: A2AInvokeOptions, message: string): Promise<StreamingInvokeResult> {
604-
const client = new BedrockAgentCoreClient({
605-
region: options.region,
606-
credentials: getCredentialProvider(),
607-
});
624+
const client = createAgentCoreClient(options.region, options.headers);
608625

609626
const body = {
610627
jsonrpc: '2.0',

src/cli/commands/dev/command.tsx

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import { findConfigRoot, getWorkingDirectory, readEnvFile } from '../../../lib';
2+
import { parseHeaderFlags } from '../shared/header-utils';
23
import { getErrorMessage } from '../../errors';
34
import { ExecLogger } from '../../logging';
45
import {
@@ -29,16 +30,16 @@ const ENTER_ALT_SCREEN = '\x1B[?1049h\x1B[H';
2930
const EXIT_ALT_SCREEN = '\x1B[?1049l';
3031
const SHOW_CURSOR = '\x1B[?25h';
3132

32-
async function invokeDevServer(port: number, prompt: string, stream: boolean): Promise<void> {
33+
async function invokeDevServer(port: number, prompt: string, stream: boolean, headers?: Record<string, string>): Promise<void> {
3334
try {
3435
if (stream) {
3536
// Stream response to stdout
36-
for await (const chunk of invokeAgentStreaming(port, prompt)) {
37+
for await (const chunk of invokeAgentStreaming({ port, message: prompt, headers })) {
3738
process.stdout.write(chunk);
3839
}
3940
process.stdout.write('\n');
4041
} else {
41-
const response = await invokeAgent(port, prompt);
42+
const response = await invokeAgent({ port, message: prompt, headers });
4243
console.log(response);
4344
}
4445
} catch (err) {
@@ -52,9 +53,9 @@ async function invokeDevServer(port: number, prompt: string, stream: boolean): P
5253
}
5354
}
5455

55-
async function invokeA2ADevServer(port: number, prompt: string): Promise<void> {
56+
async function invokeA2ADevServer(port: number, prompt: string, headers?: Record<string, string>): Promise<void> {
5657
try {
57-
for await (const chunk of invokeForProtocol('A2A', { port, message: prompt })) {
58+
for await (const chunk of invokeForProtocol('A2A', { port, message: prompt, headers })) {
5859
process.stdout.write(chunk);
5960
}
6061
process.stdout.write('\n');
@@ -69,10 +70,10 @@ async function invokeA2ADevServer(port: number, prompt: string): Promise<void> {
6970
}
7071
}
7172

72-
async function handleMcpInvoke(port: number, invokeValue: string, toolName?: string, input?: string): Promise<void> {
73+
async function handleMcpInvoke(port: number, invokeValue: string, toolName?: string, input?: string, headers?: Record<string, string>): Promise<void> {
7374
try {
7475
if (invokeValue === 'list-tools') {
75-
const { tools } = await listMcpTools(port);
76+
const { tools } = await listMcpTools(port, undefined, headers);
7677
if (tools.length === 0) {
7778
console.log('No tools available.');
7879
return;
@@ -89,7 +90,7 @@ async function handleMcpInvoke(port: number, invokeValue: string, toolName?: str
8990
process.exit(1);
9091
}
9192
// Initialize session first, then call tool with the session ID
92-
const { sessionId } = await listMcpTools(port);
93+
const { sessionId } = await listMcpTools(port, undefined, headers);
9394
let args: Record<string, unknown> = {};
9495
if (input) {
9596
try {
@@ -100,7 +101,7 @@ async function handleMcpInvoke(port: number, invokeValue: string, toolName?: str
100101
process.exit(1);
101102
}
102103
}
103-
const result = await callMcpTool(port, toolName, args, sessionId);
104+
const result = await callMcpTool(port, toolName, args, sessionId, undefined, headers);
104105
console.log(result);
105106
} else {
106107
console.error(`Error: Unknown MCP invoke command "${invokeValue}"`);
@@ -132,10 +133,17 @@ export const registerDev = (program: Command) => {
132133
.option('-l, --logs', 'Run dev server with logs to stdout [non-interactive]')
133134
.option('--tool <name>', 'MCP tool name (used with --invoke call-tool)')
134135
.option('--input <json>', 'MCP tool arguments as JSON (used with --invoke call-tool)')
136+
.option('-H, --header <header>', 'Custom header to forward to the agent (format: "Name: Value", repeatable)', (val: string, prev: string[]) => [...prev, val], [] as string[])
135137
.action(async opts => {
136138
try {
137139
const port = parseInt(opts.port, 10);
138140

141+
// Parse custom headers
142+
let headers: Record<string, string> | undefined;
143+
if (opts.header && opts.header.length > 0) {
144+
headers = parseHeaderFlags(opts.header);
145+
}
146+
139147
// If --invoke provided, call the dev server and exit
140148
if (opts.invoke) {
141149
const invokeProject = await loadProjectConfig(getWorkingDirectory());
@@ -166,11 +174,11 @@ export const registerDev = (program: Command) => {
166174

167175
// Protocol-aware dispatch
168176
if (protocol === 'MCP') {
169-
await handleMcpInvoke(invokePort, opts.invoke, opts.tool, opts.input);
177+
await handleMcpInvoke(invokePort, opts.invoke, opts.tool, opts.input, headers);
170178
} else if (protocol === 'A2A') {
171-
await invokeA2ADevServer(invokePort, opts.invoke);
179+
await invokeA2ADevServer(invokePort, opts.invoke, headers);
172180
} else {
173-
await invokeDevServer(invokePort, opts.invoke, opts.stream ?? false);
181+
await invokeDevServer(invokePort, opts.invoke, opts.stream ?? false, headers);
174182
}
175183
return;
176184
}
@@ -308,6 +316,7 @@ export const registerDev = (program: Command) => {
308316
workingDir={workingDir}
309317
port={port}
310318
agentName={opts.agent}
319+
headers={headers}
311320
/>
312321
</LayoutProvider>
313322
);

src/cli/commands/invoke/action.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ export async function handleInvoke(context: InvokeContext, options: InvokeOption
9595
region: targetConfig.region,
9696
runtimeArn: agentState.runtimeArn,
9797
userId: options.userId,
98+
headers: options.headers,
9899
};
99100

100101
// list-tools: list available MCP tools
@@ -167,7 +168,7 @@ export async function handleInvoke(context: InvokeContext, options: InvokeOption
167168
if (agentSpec.protocol === 'A2A') {
168169
try {
169170
const a2aResult = await invokeA2ARuntime(
170-
{ region: targetConfig.region, runtimeArn: agentState.runtimeArn, userId: options.userId },
171+
{ region: targetConfig.region, runtimeArn: agentState.runtimeArn, userId: options.userId, headers: options.headers },
171172
options.prompt
172173
);
173174
let response = '';
@@ -214,6 +215,7 @@ export async function handleInvoke(context: InvokeContext, options: InvokeOption
214215
sessionId: options.sessionId,
215216
userId: options.userId,
216217
logger, // Pass logger for SSE event debugging
218+
headers: options.headers,
217219
});
218220

219221
for await (const chunk of result.stream) {
@@ -245,6 +247,7 @@ export async function handleInvoke(context: InvokeContext, options: InvokeOption
245247
payload: options.prompt,
246248
sessionId: options.sessionId,
247249
userId: options.userId,
250+
headers: options.headers,
248251
});
249252

250253
logger.logResponse(response.content);

src/cli/commands/invoke/command.tsx

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import { getErrorMessage } from '../../errors';
22
import { COMMAND_DESCRIPTIONS } from '../../tui/copy';
33
import { requireProject } from '../../tui/guards';
44
import { InvokeScreen } from '../../tui/screens/invoke';
5+
import { parseHeaderFlags } from '../shared/header-utils';
56
import { handleInvoke, loadInvokeConfig } from './action';
67
import type { InvokeOptions } from './types';
78
import { validateInvokeOptions } from './validate';
@@ -103,6 +104,7 @@ export const registerInvoke = (program: Command) => {
103104
.option('--stream', 'Stream response in real-time (TUI streams by default) [non-interactive]')
104105
.option('--tool <name>', 'MCP tool name (use with "call-tool" prompt) [non-interactive]')
105106
.option('--input <json>', 'MCP tool arguments as JSON (use with --tool) [non-interactive]')
107+
.option('-H, --header <header>', 'Custom header to forward to the agent (format: "Name: Value", repeatable)', (val: string, prev: string[]) => [...prev, val], [] as string[])
106108
.action(
107109
async (
108110
positionalPrompt: string | undefined,
@@ -116,13 +118,20 @@ export const registerInvoke = (program: Command) => {
116118
stream?: boolean;
117119
tool?: string;
118120
input?: string;
121+
header?: string[];
119122
}
120123
) => {
121124
try {
122125
requireProject();
123126
// --prompt flag takes precedence over positional argument
124127
const prompt = cliOptions.prompt ?? positionalPrompt;
125128

129+
// Parse custom headers
130+
let headers: Record<string, string> | undefined;
131+
if (cliOptions.header && cliOptions.header.length > 0) {
132+
headers = parseHeaderFlags(cliOptions.header);
133+
}
134+
126135
// CLI mode if any CLI-specific options provided (follows deploy command pattern)
127136
if (
128137
prompt ||
@@ -142,15 +151,17 @@ export const registerInvoke = (program: Command) => {
142151
stream: cliOptions.stream,
143152
tool: cliOptions.tool,
144153
input: cliOptions.input,
154+
headers,
145155
});
146156
} else {
147-
// No CLI options - interactive TUI mode
157+
// No CLI options - interactive TUI mode (headers still passed if provided)
148158
const { waitUntilExit } = render(
149159
<InvokeScreen
150160
isInteractive={true}
151161
onExit={() => process.exit(0)}
152162
initialSessionId={cliOptions.sessionId}
153163
initialUserId={cliOptions.userId}
164+
initialHeaders={headers}
154165
/>
155166
);
156167
await waitUntilExit();

src/cli/commands/invoke/types.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ export interface InvokeOptions {
1010
tool?: string;
1111
/** MCP tool arguments as JSON string (used with --tool) */
1212
input?: string;
13+
/** Custom headers to forward to the agent runtime (key-value pairs) */
14+
headers?: Record<string, string>;
1315
}
1416

1517
export interface InvokeResult {

0 commit comments

Comments
 (0)