-
Notifications
You must be signed in to change notification settings - Fork 5.8k
Expand file tree
/
Copy pathKB_Updater.py
More file actions
146 lines (124 loc) · 5.31 KB
/
KB_Updater.py
File metadata and controls
146 lines (124 loc) · 5.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import boto3
import json
import datetime
import time
class DateTimeEncoder(json.JSONEncoder):
    """JSON encoder that serializes datetime objects as ISO-8601 strings."""

    def default(self, obj):
        # Datetimes become their ISO representation; anything else is
        # deferred to the base encoder (which raises TypeError).
        if not isinstance(obj, datetime.datetime):
            return super().default(obj)
        return obj.isoformat()
def get_knowledge_base_id(knowledge_base_name, region_name, bedrock_agent):
    """Return the ID of the knowledge base whose name matches exactly.

    Args:
        knowledge_base_name: Exact knowledge base name to look for.
        region_name: AWS region (unused here; kept for interface compatibility).
        bedrock_agent: boto3 'bedrock-agent' client.

    Returns:
        The matching knowledge base ID string.

    Raises:
        ValueError: If no knowledge base with that name exists.
    """
    # list_knowledge_bases is paginated; the original inspected only the
    # first page and could miss KBs in accounts with many knowledge bases.
    kwargs = {}
    while True:
        response = bedrock_agent.list_knowledge_bases(**kwargs)
        for kb in response['knowledgeBaseSummaries']:
            if kb['name'] == knowledge_base_name:
                return kb['knowledgeBaseId']
        next_token = response.get('nextToken')
        if not next_token:
            break
        kwargs = {'nextToken': next_token}
    raise ValueError(f"Knowledge base '{knowledge_base_name}' not found")
def get_or_create_data_source(knowledge_base_id, language, region_name, bedrock_agent):
    """Find the data source for *language* in the KB, creating one if absent.

    Args:
        knowledge_base_id: Target knowledge base ID.
        language: Language/SDK key embedded in data-source and bucket names.
        region_name: AWS region (unused here; kept for interface compatibility).
        bedrock_agent: boto3 'bedrock-agent' client.

    Returns:
        Tuple of (data_source_id, data_source_name, created) where *created*
        is True only when this call created a new data source.
    """
    # Look for an existing data source whose name contains the language key.
    # list_data_sources is paginated; walk every page so a KB with many
    # sources doesn't miss an existing one and create a duplicate.
    kwargs = {'knowledgeBaseId': knowledge_base_id}
    while True:
        response = bedrock_agent.list_data_sources(**kwargs)
        for ds in response['dataSourceSummaries']:
            if language in ds['name'] and ds['name'] != "default":
                return ds['dataSourceId'], ds['name'], False  # Found existing
        next_token = response.get('nextToken')
        if not next_token:
            break
        kwargs['nextToken'] = next_token
    # Naming mirrors the knowledge-base naming in lambda_handler: these
    # three languages use plain names, everything else is "premium".
    # (The original omitted "coding-standards" here, which paired its
    # premium-named data source/bucket with a non-premium KB name.)
    if language in ["steering-docs", "final-specs", "coding-standards"]:
        ds_name = f"{language}-data-source"
        bucket_name = f"{language}-bucket"
    else:
        ds_name = f"{language}-premium-data-source"
        bucket_name = f"{language}-premium-bucket"
    # No existing source: create one backed by the language's S3 bucket with
    # hierarchical chunking (1500-token parents, 300-token children, 75 overlap).
    response = bedrock_agent.create_data_source(
        knowledgeBaseId=knowledge_base_id,
        name=ds_name,
        dataSourceConfiguration={
            "type": "S3",
            "s3Configuration": {
                "bucketArn": f"arn:aws:s3:::{bucket_name}"
            }
        },
        vectorIngestionConfiguration={
            "chunkingConfiguration": {
                "chunkingStrategy": "HIERARCHICAL",
                "hierarchicalChunkingConfiguration": {
                    "levelConfigurations": [
                        {"maxTokens": 1500},
                        {"maxTokens": 300}
                    ],
                    "overlapTokens": 75
                }
            }
        }
    )
    return response['dataSource']['dataSourceId'], response['dataSource']['name'], True  # Created new
def sync_data_source(knowledge_base_id, data_source_id, region_name, bedrock_agent):
    """Start an ingestion job that syncs the data source into the KB.

    Returns the raw start_ingestion_job response, which carries the
    ingestion job ID used for subsequent status polling. *region_name*
    is unused; it is kept for interface compatibility with the callers.
    """
    return bedrock_agent.start_ingestion_job(
        knowledgeBaseId=knowledge_base_id,
        dataSourceId=data_source_id,
    )
def monitor_ingestion_job(knowledge_base_id, data_source_id, ingestion_job_id, region_name, bedrock_agent):
    """Poll an ingestion job until it reaches a terminal state or times out.

    Args:
        knowledge_base_id: Knowledge base the job belongs to.
        data_source_id: Data source being ingested.
        ingestion_job_id: Job ID returned by start_ingestion_job.
        region_name: AWS region (unused here; kept for interface compatibility).
        bedrock_agent: boto3 'bedrock-agent' client.

    Returns:
        The final get_ingestion_job response on COMPLETE/FAILED/STOPPED,
        or a {"status": "TIMEOUT", ...} dict if the budget is exhausted.
    """
    max_attempts = 100
    poll_interval_seconds = 5
    for _ in range(max_attempts):
        job_status = bedrock_agent.get_ingestion_job(
            knowledgeBaseId=knowledge_base_id,
            dataSourceId=data_source_id,
            ingestionJobId=ingestion_job_id
        )
        status = job_status['ingestionJob']['status']
        print(f"Current status: {status} - {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        if status in ['COMPLETE', 'FAILED', 'STOPPED']:
            return job_status
        time.sleep(poll_interval_seconds)
    # Report the real timeout budget. The original hard-coded "5 minutes"
    # even though 100 attempts x 5 s is ~8.3 minutes.
    timeout_minutes = max_attempts * poll_interval_seconds / 60
    return {"status": "TIMEOUT",
            "message": f"Job monitoring timed out after {timeout_minutes:.1f} minutes"}
def lambda_handler(event, context):
    """Lambda entry point: ensure a data source exists for the requested
    language, sync it into its knowledge base, and report the ingestion
    job outcome plus document statistics in the response body.

    Event keys (both optional): 'language' (default 'python'),
    'region_name' (default 'us-west-2').
    """
    language = event.get('language', 'python')
    region_name = event.get('region_name', 'us-west-2')

    # KB naming: these three languages use plain names, others are "premium".
    plain = language in ["steering-docs", "final-specs", "coding-standards"]
    knowledge_base_name = f"{language}-KB" if plain else f"{language}-premium-KB"

    bedrock_agent = boto3.client('bedrock-agent', region_name=region_name)
    knowledge_base_id = get_knowledge_base_id(knowledge_base_name, region_name, bedrock_agent)

    # Reuse an existing data source for this language or create a new one.
    data_source_id, data_source_name, is_new = get_or_create_data_source(
        knowledge_base_id, language, region_name, bedrock_agent
    )

    results = {
        "data_source": {
            "id": data_source_id,
            "name": data_source_name,
            "is_new": is_new,
        },
        "ingestion_job": None,
        "statistics": None,
    }

    # Kick off the sync and record the job as started.
    print(f"Syncing data source {data_source_name}...")
    sync_result = sync_data_source(knowledge_base_id, data_source_id, region_name, bedrock_agent)
    ingestion_job_id = sync_result['ingestionJob']['ingestionJobId']
    results["ingestion_job"] = {"id": ingestion_job_id, "status": "STARTED"}

    # Block until the job reaches a terminal state (or monitoring times out).
    final_status = monitor_ingestion_job(
        knowledge_base_id, data_source_id, ingestion_job_id, region_name, bedrock_agent
    )
    job_info = final_status.get('ingestionJob', {})
    results["ingestion_job"]["status"] = job_info.get('status', 'UNKNOWN')

    # Surface document counts when the service returned them.
    if 'statistics' in job_info:
        stats = job_info['statistics']
        results["statistics"] = {
            "documents_processed": stats.get('numberOfDocumentsScanned', 0),
            "documents_failed": stats.get('numberOfDocumentsFailed', 0),
            "documents_indexed": stats.get('numberOfNewDocumentsIndexed', 0),
            "documents_modified_indexed": stats.get('numberOfModifiedDocumentsIndexed', 0),
        }

    return {
        'statusCode': 200,
        'body': json.dumps(results, cls=DateTimeEncoder),
    }