Skip to content

Commit ba82360

Browse files
authored
Merge pull request #125 from SAP/re-organize-backend
Re organize backend
2 parents e7efc62 + 73cdcc7 commit ba82360

5 files changed

Lines changed: 354 additions & 304 deletions

File tree

backend-agent/app/routes.py

Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
import json
2+
import os
3+
4+
from flask import request, jsonify, send_file, abort
5+
from sqlalchemy import select
6+
7+
from app.db.models import Attack, ModelAttackScore, TargetModel, db
8+
from app.utils import send_intro, verify_api_key
9+
from attack_result import SuiteResult
10+
from services import run_all_attacks
11+
from status import status
12+
13+
14+
def register_routes(app, sock, agent=None, callbacks=None):
15+
# ----------------------
16+
# Health endpoints
17+
# ----------------------
18+
@app.route("/health")
19+
def check_health():
20+
"""
21+
Health route is used in the CI to test that the installation was
22+
successful.
23+
"""
24+
return jsonify({'status': 'ok'})
25+
26+
# ----------------------
27+
# Attacks endpoints
28+
# ----------------------
29+
@app.route('/run_all', methods=['POST'])
30+
def execute_all_attacks():
31+
"""
32+
Run all attacks. Used for automation.
33+
Expected JSON body:
34+
{
35+
"target": "string"
36+
}
37+
"""
38+
verify_api_key()
39+
data = request.get_json()
40+
target_model = data.get('target') if data else None
41+
if not target_model:
42+
return jsonify({'error': 'target parameter is required'}), 400
43+
# Call the service to run all attacks
44+
result = run_all_attacks(
45+
target=target_model
46+
)
47+
return jsonify(result), 200 if result.get('success') else 500
48+
49+
@app.route('/api/attacks', methods=['GET'])
50+
def get_attacks():
51+
"""
52+
Endpoint to retrieve all attacks with their weights.
53+
Returns a JSON object with attack names and their weights.
54+
"""
55+
try:
56+
attacks = db.session.query(Attack).all()
57+
attack_list = [
58+
{'name': attack.name, 'weight': attack.weight}
59+
for attack in attacks
60+
]
61+
return jsonify(attack_list), 200
62+
except Exception as e:
63+
return jsonify({'error': str(e)}), 500
64+
65+
@app.route('/api/attacks', methods=['PUT'])
66+
def update_attack_weights():
67+
"""
68+
Update weights for multiple attacks.
69+
Expects a JSON object like: {"artPrompt": 2, "codeAttack": 1, ...}
70+
"""
71+
verify_api_key()
72+
try:
73+
weights = request.get_json()
74+
if not isinstance(weights, dict):
75+
return jsonify({'error': 'Invalid payload format'}), 400
76+
77+
for name, weight in weights.items():
78+
attack = db.session.query(Attack).filter_by(name=name).first()
79+
if attack:
80+
attack.weight = float(weight)
81+
else:
82+
return jsonify({'error': f'Attack not found: {name}'}), 404
83+
84+
db.session.commit()
85+
return jsonify({'message': 'Weights updated successfully'}), 200
86+
87+
except Exception as e:
88+
db.session.rollback()
89+
return jsonify({'error': str(e)}), 500
90+
91+
# ----------------------
92+
# Reports endpoints
93+
# ----------------------
94+
@app.route('/download_report')
95+
def download_report():
96+
"""
97+
This route allows to download attack suite reports by specifying
98+
their name.
99+
"""
100+
name = request.args.get('name')
101+
format = request.args.get('format', 'md')
102+
103+
# Ensure that a name is provided
104+
if not name:
105+
abort(400)
106+
# Ensure that only allowed chars are in the filename
107+
# (e.g. no path traversal)
108+
if not all([c in SuiteResult.FILENAME_ALLOWED_CHARS for c in name]):
109+
abort(400)
110+
111+
results = SuiteResult.load_from_name(name)
112+
113+
generated_name = name + '_generated'
114+
path = os.path.join(SuiteResult.DEFAULT_OUTPUT_PATH, generated_name)
115+
result_path = results.to_file(path, format)
116+
return send_file(
117+
result_path,
118+
mimetype=SuiteResult.get_mime_type(format)
119+
)
120+
121+
@app.route('/api/heatmap', methods=['GET'])
122+
def get_heatmap():
123+
"""
124+
Endpoint to retrieve heatmap data showing model score
125+
against various attacks.
126+
127+
Queries the database for total attacks and successes per target model
128+
and attack combination.
129+
Calculates attack success rate and returns structured data for
130+
visualization.
131+
132+
Returns:
133+
JSON response with:
134+
- models: List of target models and their attack success rate
135+
per attack.
136+
- attacks: List of attack names and their associated weights.
137+
138+
HTTP Status Codes:
139+
200: Data successfully retrieved.
140+
500: Internal server error during query execution.
141+
"""
142+
try:
143+
query = (
144+
select(
145+
ModelAttackScore.total_number_of_attack,
146+
ModelAttackScore.total_success,
147+
TargetModel.name.label('attack_model_name'),
148+
Attack.name.label('attack_name'),
149+
Attack.weight.label('attack_weight')
150+
)
151+
.join(TargetModel, ModelAttackScore.target_model_id == TargetModel.id) # noqa: E501
152+
.join(Attack, ModelAttackScore.attack_id == Attack.id)
153+
)
154+
155+
scores = db.session.execute(query).all()
156+
all_models = {}
157+
all_attacks = {}
158+
159+
for score in scores:
160+
model_name = score.attack_model_name
161+
attack_name = score.attack_name
162+
163+
if attack_name not in all_attacks:
164+
all_attacks[attack_name] = score.attack_weight
165+
166+
if model_name not in all_models:
167+
all_models[model_name] = {
168+
'name': model_name,
169+
'scores': {},
170+
}
171+
172+
# Compute attack success rate for this model/attack
173+
success_ratio = (
174+
round((score.total_success / score.total_number_of_attack) * 100) # noqa: E501
175+
if score.total_number_of_attack else 0
176+
)
177+
178+
all_models[model_name]['scores'][attack_name] = success_ratio
179+
180+
return jsonify({
181+
'models': list(all_models.values()),
182+
'attacks': [
183+
{'name': name, 'weight': weight}
184+
for name, weight in sorted(all_attacks.items())
185+
]
186+
})
187+
except Exception as e:
188+
return jsonify({'error': str(e)}), 500
189+
190+
# ----------------------
191+
# WebSocket endpoints
192+
# ----------------------
193+
@sock.route('/agent')
194+
def query_agent(sock):
195+
"""
196+
Websocket route for the frontend to send prompts to the agent and
197+
receive responses as well as status updates.
198+
199+
Messages received are in this JSON format:
200+
{
201+
"type":"message",
202+
"data":"Start the vulnerability scan",
203+
"key":"secretapikey"
204+
}
205+
"""
206+
# Verify API key from headers before establishing session
207+
verify_api_key()
208+
if not agent:
209+
sock.send(json.dumps({
210+
'type': 'message',
211+
'data': 'Agent is disabled on this deployment.'
212+
}))
213+
return
214+
status.sock = sock
215+
# Intro is sent after connecting successfully
216+
send_intro(sock)
217+
while True:
218+
try:
219+
data_raw = sock.receive()
220+
data = json.loads(data_raw)
221+
assert 'data' in data
222+
query = data['data']
223+
status.clear_report()
224+
response = agent.invoke(
225+
{'input': query},
226+
config=callbacks or {}
227+
)
228+
ai_response = response['output']
229+
formatted_output = {
230+
'type': 'message',
231+
'data': (
232+
f'{ai_response}'
233+
)
234+
}
235+
sock.send(json.dumps(formatted_output))
236+
except json.JSONDecodeError:
237+
sock.send(json.dumps({
238+
'type': 'error',
239+
'data': 'Invalid JSON format'
240+
}))
241+
except Exception as e:
242+
sock.send(json.dumps({
243+
'type': 'error',
244+
'data': f'Error: {str(e)}'
245+
}))

backend-agent/app/utils.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import json
2+
import os
3+
4+
from flask import request, abort
5+
6+
7+
def send_intro(sock):
8+
"""
9+
Sends the intro via the websocket connection.
10+
11+
The intro is meant as a short tutorial on how to use the agent.
12+
Also it includes meaningful suggestions for prompts that should
13+
result in predictable behavior for the agent, e.g.
14+
"Start the vulnerability scan".
15+
"""
16+
intro_file = 'data/intro.txt'
17+
try:
18+
with open(intro_file, 'r') as f:
19+
intro = f.read()
20+
except FileNotFoundError:
21+
intro = "Welcome! (intro file missing)"
22+
sock.send(json.dumps({'type': 'message', 'data': intro}))
23+
24+
25+
def verify_api_key():
26+
"""
27+
Verifies the API key from the request headers against the env variable.
28+
If no API key is configured, access is allowed.
29+
If API key is configured but missing/invalid, request is rejected.
30+
"""
31+
if os.getenv('API_KEY'):
32+
provided_key = request.headers.get('X-API-Key')
33+
if provided_key != os.getenv('API_KEY'):
34+
abort(403)

backend-agent/cli.py

Lines changed: 4 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import os
44
import sys
55
from argparse import ArgumentParser, Namespace
6-
from pathlib import Path
76
from typing import Callable
87

98
from attack import AttackSpecification, AttackSuite
@@ -17,6 +16,7 @@
1716
test as test_textattack,
1817
)
1918
from llm import LLM
19+
from services import run_all_attacks
2020
from status import Trace
2121

2222
# Library-free Subcommand utilities from
@@ -385,31 +385,9 @@ def run(args):
385385
required=True),
386386
])
387387
def run_all(args):
388-
"""Run all LLM attacks with specified target and evaluation models."""
389-
default_spec_path = Path('data/all/default.json')
390-
try:
391-
with default_spec_path.open("r") as f:
392-
spec = json.load(f)
393-
except FileNotFoundError:
394-
print(f'File not found: {args.file}', file=sys.stderr)
395-
return
396-
except json.JSONDecodeError as e:
397-
print(f'Invalid JSON format: {e}', file=sys.stderr)
398-
return
399-
except PermissionError:
400-
print(f'Permission denied reading file: {args.file}', file=sys.stderr)
401-
return
402-
if 'attacks' in spec:
403-
suite = AttackSuite.from_dict(spec)
404-
suite.set_target(args.target)
405-
results = suite.run()
406-
result_return = {'success': True, 'results': results}
407-
else:
408-
result_return = {
409-
'success': False,
410-
'error': 'JSON is invalid. No attacks run.'
411-
}
412-
return result_return
388+
"""Run all LLM attacks with specified target."""
389+
result = run_all_attacks(target=args.target)
390+
return result
413391

414392

415393
@subcommand()

0 commit comments

Comments
 (0)