Skip to content

Commit 6aa4b9b

Browse files
committed
Add assumerole
1 parent eab5398 commit 6aa4b9b

8 files changed

Lines changed: 215 additions & 2 deletions

File tree

deploy/docker/docker-compose.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,9 @@ services:
193193
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY:-}
194194
- AWS_SESSION_TOKEN=${AWS_SESSION_TOKEN:-}
195195
- AWS_REGION=${AWS_REGION:-}
196+
- AWS_ASSUME_ROLE_ARN=${AWS_ASSUME_ROLE_ARN:-}
197+
- AWS_EXTERNAL_ID=${AWS_EXTERNAL_ID:-}
198+
- AWS_ROLE_SESSION_NAME=${AWS_ROLE_SESSION_NAME:-crapi-chatbot-session}
196199
- GOOGLE_APPLICATION_CREDENTIALS=${GOOGLE_APPLICATION_CREDENTIALS:-}
197200
- VERTEX_PROJECT=${VERTEX_PROJECT:-}
198201
- VERTEX_LOCATION=${VERTEX_LOCATION:-}

deploy/helm/templates/chatbot/config.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ data:
4141
AWS_SECRET_ACCESS_KEY: {{ .Values.awsSecretAccessKey | quote }}
4242
AWS_SESSION_TOKEN: {{ .Values.awsSessionToken | quote }}
4343
AWS_REGION: {{ .Values.awsRegion | quote }}
44+
AWS_ASSUME_ROLE_ARN: {{ .Values.awsAssumeRoleArn | quote }}
45+
AWS_EXTERNAL_ID: {{ .Values.awsExternalId | quote }}
46+
AWS_ROLE_SESSION_NAME: {{ .Values.awsRoleSessionName | quote }}
4447
GOOGLE_APPLICATION_CREDENTIALS: {{ .Values.googleApplicationCredentials | quote }}
4548
VERTEX_PROJECT: {{ .Values.vertexProject | quote }}
4649
VERTEX_LOCATION: {{ .Values.vertexLocation | quote }}

deploy/helm/values.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ awsAccessKeyId: ""
5252
awsSecretAccessKey: ""
5353
awsSessionToken: ""
5454
awsRegion: ""
55+
# AWS IAM Assume Role configuration (optional - use instead of static credentials)
56+
awsAssumeRoleArn: ""
57+
awsExternalId: ""
58+
awsRoleSessionName: "crapi-chatbot-session"
5559

5660
# Google Vertex AI configuration
5761
googleApplicationCredentials: ""

deploy/k8s/base/chatbot/config.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ data:
4040
AWS_SECRET_ACCESS_KEY: ""
4141
AWS_SESSION_TOKEN: ""
4242
AWS_REGION: ""
43+
AWS_ASSUME_ROLE_ARN: ""
44+
AWS_EXTERNAL_ID: ""
45+
AWS_ROLE_SESSION_NAME: "crapi-chatbot-session"
4346
GOOGLE_APPLICATION_CREDENTIALS: ""
4447
VERTEX_PROJECT: ""
4548
VERTEX_LOCATION: ""
Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
"""AWS credentials helper with STS assume role support."""
2+
3+
import logging
4+
import os
5+
import threading
6+
import time
7+
from typing import Optional
8+
9+
import boto3
10+
from botocore.credentials import RefreshableCredentials
11+
from botocore.session import get_session
12+
13+
from .config import Config
14+
15+
logger = logging.getLogger(__name__)
16+
17+
# Cache for assumed role credentials
18+
_credentials_cache = {
19+
"credentials": None,
20+
"expiration": 0,
21+
"lock": threading.Lock(),
22+
}
23+
24+
# Refresh credentials 5 minutes before expiration
25+
CREDENTIALS_REFRESH_BUFFER_SECONDS = 300
26+
27+
28+
def _get_base_session():
29+
"""Get a boto3 session with base credentials (from env vars or instance profile)."""
30+
return boto3.Session(
31+
aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
32+
aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
33+
aws_session_token=os.getenv("AWS_SESSION_TOKEN"),
34+
region_name=os.getenv("AWS_REGION") or os.getenv("AWS_DEFAULT_REGION"),
35+
)
36+
37+
38+
def _assume_role() -> dict:
39+
"""Assume the configured IAM role and return temporary credentials."""
40+
role_arn = Config.AWS_ASSUME_ROLE_ARN
41+
external_id = Config.AWS_EXTERNAL_ID
42+
session_name = Config.AWS_ROLE_SESSION_NAME
43+
44+
logger.info("Assuming IAM role: %s", role_arn)
45+
46+
base_session = _get_base_session()
47+
sts_client = base_session.client("sts")
48+
49+
assume_role_kwargs = {
50+
"RoleArn": role_arn,
51+
"RoleSessionName": session_name,
52+
"DurationSeconds": 3600, # 1 hour
53+
}
54+
55+
if external_id:
56+
assume_role_kwargs["ExternalId"] = external_id
57+
58+
response = sts_client.assume_role(**assume_role_kwargs)
59+
credentials = response["Credentials"]
60+
61+
logger.info(
62+
"Successfully assumed role %s, expires at %s",
63+
role_arn,
64+
credentials["Expiration"],
65+
)
66+
67+
return {
68+
"access_key": credentials["AccessKeyId"],
69+
"secret_key": credentials["SecretAccessKey"],
70+
"token": credentials["SessionToken"],
71+
"expiry_time": credentials["Expiration"].timestamp(),
72+
}
73+
74+
75+
def _get_cached_credentials() -> Optional[dict]:
76+
"""Get cached credentials if they're still valid."""
77+
with _credentials_cache["lock"]:
78+
if _credentials_cache["credentials"] is None:
79+
return None
80+
81+
# Check if credentials are about to expire
82+
if time.time() >= _credentials_cache["expiration"] - CREDENTIALS_REFRESH_BUFFER_SECONDS:
83+
return None
84+
85+
return _credentials_cache["credentials"]
86+
87+
88+
def _set_cached_credentials(credentials: dict) -> None:
89+
"""Cache the credentials."""
90+
with _credentials_cache["lock"]:
91+
_credentials_cache["credentials"] = credentials
92+
_credentials_cache["expiration"] = credentials["expiry_time"]
93+
94+
95+
def get_aws_credentials() -> Optional[dict]:
96+
"""
97+
Get AWS credentials, using assume role if configured.
98+
99+
Returns:
100+
dict with 'access_key', 'secret_key', 'token' (optional), and 'region',
101+
or None if no credentials are configured/needed.
102+
"""
103+
region = os.getenv("AWS_REGION") or os.getenv("AWS_DEFAULT_REGION")
104+
105+
# If assume role is configured, use it
106+
if Config.AWS_ASSUME_ROLE_ARN:
107+
# Try to use cached credentials
108+
cached = _get_cached_credentials()
109+
if cached:
110+
return {
111+
"access_key": cached["access_key"],
112+
"secret_key": cached["secret_key"],
113+
"token": cached["token"],
114+
"region": region,
115+
}
116+
117+
# Assume role and cache credentials
118+
try:
119+
credentials = _assume_role()
120+
_set_cached_credentials(credentials)
121+
return {
122+
"access_key": credentials["access_key"],
123+
"secret_key": credentials["secret_key"],
124+
"token": credentials["token"],
125+
"region": region,
126+
}
127+
except Exception as e:
128+
logger.error("Failed to assume role %s: %s", Config.AWS_ASSUME_ROLE_ARN, e)
129+
raise
130+
131+
# If static credentials are provided via environment variables
132+
access_key = os.getenv("AWS_ACCESS_KEY_ID")
133+
secret_key = os.getenv("AWS_SECRET_ACCESS_KEY")
134+
session_token = os.getenv("AWS_SESSION_TOKEN")
135+
136+
if access_key and secret_key:
137+
result = {
138+
"access_key": access_key,
139+
"secret_key": secret_key,
140+
"region": region,
141+
}
142+
if session_token:
143+
result["token"] = session_token
144+
return result
145+
146+
# Return None to use default credential chain (instance profile, etc.)
147+
return None
148+
149+
150+
def get_boto3_session() -> boto3.Session:
151+
"""
152+
Get a boto3 session with the appropriate credentials.
153+
154+
This handles assume role if configured, otherwise uses the default credential chain.
155+
"""
156+
credentials = get_aws_credentials()
157+
158+
if credentials:
159+
return boto3.Session(
160+
aws_access_key_id=credentials["access_key"],
161+
aws_secret_access_key=credentials["secret_key"],
162+
aws_session_token=credentials.get("token"),
163+
region_name=credentials.get("region"),
164+
)
165+
166+
# Use default credential chain
167+
return boto3.Session(
168+
region_name=os.getenv("AWS_REGION") or os.getenv("AWS_DEFAULT_REGION")
169+
)
170+
171+
172+
def get_bedrock_credentials_kwargs() -> dict:
173+
"""
174+
Get kwargs to pass to ChatBedrock or BedrockEmbeddings for credentials.
175+
176+
Returns a dict that can be unpacked into the constructor.
177+
"""
178+
credentials = get_aws_credentials()
179+
region = os.getenv("AWS_REGION") or os.getenv("AWS_DEFAULT_REGION")
180+
181+
kwargs = {}
182+
183+
if region:
184+
kwargs["region_name"] = region
185+
186+
if credentials:
187+
kwargs["credentials_profile_name"] = None # Disable profile lookup
188+
kwargs["aws_access_key_id"] = credentials["access_key"]
189+
kwargs["aws_secret_access_key"] = credentials["secret_key"]
190+
if credentials.get("token"):
191+
kwargs["aws_session_token"] = credentials["token"]
192+
193+
return kwargs

services/chatbot/src/chatbot/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ class Config:
2727
AZURE_OPENAI_CHAT_DEPLOYMENT = os.getenv("AZURE_OPENAI_CHAT_DEPLOYMENT")
2828
AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT = os.getenv("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT")
2929
AWS_BEARER_TOKEN_BEDROCK = os.getenv("AWS_BEARER_TOKEN_BEDROCK")
30+
AWS_ASSUME_ROLE_ARN = os.getenv("AWS_ASSUME_ROLE_ARN", "")
31+
AWS_EXTERNAL_ID = os.getenv("AWS_EXTERNAL_ID", "")
32+
AWS_ROLE_SESSION_NAME = os.getenv("AWS_ROLE_SESSION_NAME", "crapi-chatbot-session")
3033
VERTEX_PROJECT = os.getenv("VERTEX_PROJECT", "")
3134
VERTEX_LOCATION = os.getenv("VERTEX_LOCATION", "")
3235
MAX_CONTENT_LENGTH = int(os.getenv("MAX_CONTENT_LENGTH", 50000))

services/chatbot/src/chatbot/langgraph_agent.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from langchain_openai import AzureChatOpenAI, ChatOpenAI
1313

1414
from .agent_utils import truncate_tool_messages
15+
from .aws_credentials import get_bedrock_credentials_kwargs
1516
from .config import Config
1617
from .extensions import postgresdb
1718
from .mcp_client import get_mcp_client
@@ -57,7 +58,8 @@ def _build_llm(api_key, model_name):
5758
kwargs["api_key"] = Config.AZURE_OPENAI_API_KEY
5859
return AzureChatOpenAI(**kwargs)
5960
if provider == "bedrock":
60-
return ChatBedrock(model_id=model_name)
61+
bedrock_kwargs = get_bedrock_credentials_kwargs()
62+
return ChatBedrock(model_id=model_name, **bedrock_kwargs)
6163
if provider == "vertex":
6264
return ChatVertexAI(
6365
model_name=model_name,

services/chatbot/src/chatbot/retriever_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from langchain_mistralai import MistralAIEmbeddings
1212
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
1313

14+
from .aws_credentials import get_bedrock_credentials_kwargs
1415
from .config import Config
1516

1617
logger = logging.getLogger(__name__)
@@ -103,7 +104,8 @@ def get_embedding_function(api_key, provider: str, llm_model: str | None):
103104
if not model_id:
104105
logger.warning("Bedrock embedding model not configured.")
105106
return _zero_embeddings()
106-
return BedrockEmbeddings(model_id=model_id)
107+
bedrock_kwargs = get_bedrock_credentials_kwargs()
108+
return BedrockEmbeddings(model_id=model_id, **bedrock_kwargs)
107109
if embeddings_provider == "vertex":
108110
vertex_model = (
109111
Config.EMBEDDINGS_MODEL

0 commit comments

Comments
 (0)