diff --git a/internal/common/task_utils.go b/internal/common/task_utils.go index b89a8d5..5f5b84e 100644 --- a/internal/common/task_utils.go +++ b/internal/common/task_utils.go @@ -58,6 +58,11 @@ func AugmentArgsForTask(task *types.Task, args []string, opts TaskAugmentOptions args = append(args, "--no-computer-use") } } + + if task.AgentConfigSnapshot.UseAwsBedrockInference != nil && + *task.AgentConfigSnapshot.UseAwsBedrockInference { + args = append(args, "--use-aws-bedrock-inference") + } } if task.AgentConfigSnapshot != nil && task.AgentConfigSnapshot.EnvironmentID != nil { diff --git a/internal/common/task_utils_test.go b/internal/common/task_utils_test.go index 3348a66..7771669 100644 --- a/internal/common/task_utils_test.go +++ b/internal/common/task_utils_test.go @@ -9,6 +9,7 @@ import ( func strPtr(v string) *string { return &v } func intPtr(v int) *int { return &v } +func boolPtr(v bool) *bool { return &v } func TestAugmentArgsForTask_IdleOnCompletePrecedence(t *testing.T) { baseArgs := []string{"agent", "run"} @@ -66,6 +67,24 @@ func TestAugmentArgsForTask_IdleOnCompletePrecedence(t *testing.T) { opts: TaskAugmentOptions{}, expected: []string{"agent", "run", "--model", "claude-sonnet-4", "--idle-on-complete", "12m"}, }, + { + name: "passes the Bedrock inference flag when enabled", + task: &types.Task{ + AgentConfigSnapshot: &types.AmbientAgentConfig{ + ModelID: strPtr("claude-sonnet-4"), + UseAwsBedrockInference: boolPtr(true), + }, + }, + opts: TaskAugmentOptions{}, + expected: []string{ + "agent", + "run", + "--model", + "claude-sonnet-4", + "--use-aws-bedrock-inference", + "--idle-on-complete", + }, + }, } for _, tt := range tests { diff --git a/internal/types/messages.go b/internal/types/messages.go index e11aa9d..88fca30 100644 --- a/internal/types/messages.go +++ b/internal/types/messages.go @@ -67,14 +67,15 @@ type TaskDefinition struct { // AmbientAgentConfig represents the agent configuration type AmbientAgentConfig struct { - EnvironmentID *string `json:"environment_id,omitempty"` - BasePrompt *string `json:"base_prompt,omitempty"` - ModelID *string `json:"model_id,omitempty"` - ProfileID *string `json:"profile_id,omitempty"` - SkillSpec *string `json:"skill_spec,omitempty"` - MCPServers map[string]json.RawMessage `json:"mcp_servers,omitempty"` - ComputerUseEnabled *bool `json:"computer_use_enabled,omitempty"` - IdleTimeoutMinutes *int `json:"idle_timeout_minutes,omitempty"` + EnvironmentID *string `json:"environment_id,omitempty"` + BasePrompt *string `json:"base_prompt,omitempty"` + ModelID *string `json:"model_id,omitempty"` + ProfileID *string `json:"profile_id,omitempty"` + SkillSpec *string `json:"skill_spec,omitempty"` + MCPServers map[string]json.RawMessage `json:"mcp_servers,omitempty"` + ComputerUseEnabled *bool `json:"computer_use_enabled,omitempty"` + UseAwsBedrockInference *bool `json:"use_aws_bedrock_inference,omitempty"` + IdleTimeoutMinutes *int `json:"idle_timeout_minutes,omitempty"` } // Task represents an ambient agent job.