Skip to content

Commit 796bd33

Browse files
authored
Merge pull request #276 from msoedov/feat/research-enhancements
Feat/research enhancements
2 parents 433c999 + bc7fdd7 commit 796bd33

57 files changed

Lines changed: 3561 additions & 1 deletion

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from agentic_security.attack_rules.models import AttackRule, AttackRuleSeverity
2+
from agentic_security.attack_rules.loader import RuleLoader, load_rules_from_directory
3+
from agentic_security.attack_rules.dataset import (
4+
rules_to_dataset,
5+
load_rules_as_dataset,
6+
YAMLRulesDatasetLoader,
7+
)
8+
9+
__all__ = [
10+
"AttackRule",
11+
"AttackRuleSeverity",
12+
"RuleLoader",
13+
"load_rules_from_directory",
14+
"rules_to_dataset",
15+
"load_rules_as_dataset",
16+
"YAMLRulesDatasetLoader",
17+
]
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
from pathlib import Path
2+
3+
from agentic_security.attack_rules.loader import RuleLoader
4+
from agentic_security.attack_rules.models import AttackRule, AttackRuleSeverity
5+
from agentic_security.probe_data.models import ProbeDataset
6+
7+
8+
def rules_to_dataset(
    rules: list[AttackRule],
    name: str = "YAML Rules",
    variables: dict[str, str] | None = None,
) -> ProbeDataset:
    """Convert a list of :class:`AttackRule` objects into a :class:`ProbeDataset`.

    Args:
        rules: Rules whose rendered prompts become the dataset entries.
        name: Human-readable dataset name.
        variables: Optional template variables substituted into each prompt
            via ``AttackRule.render_prompt``.

    Returns:
        A ``ProbeDataset`` with one prompt per rule. ``tokens`` is a rough
        whitespace-based word count across all prompts; ``approx_cost`` is 0.
    """
    prompts = [rule.render_prompt(variables) for rule in rules]
    tokens = sum(len(p.split()) for p in prompts)

    return ProbeDataset(
        dataset_name=name,
        metadata={
            "source": "yaml_rules",
            "rule_count": len(rules),
            # Sorted so metadata is deterministic (set iteration order is not).
            "types": sorted({r.type for r in rules}),
        },
        prompts=prompts,
        tokens=tokens,
        approx_cost=0.0,
    )
27+
28+
29+
def load_rules_as_dataset(
    directory: str | Path,
    types: list[str] | None = None,
    severities: list[str] | None = None,
    recursive: bool = True,
    variables: dict[str, str] | None = None,
) -> ProbeDataset:
    """Load YAML rules from *directory*, filter them, and wrap them in a dataset.

    ``severities`` are given as strings and converted to
    :class:`AttackRuleSeverity` before filtering; ``variables`` are passed
    through to prompt rendering.
    """
    loader = RuleLoader()
    loaded = loader.load_rules_from_directory(directory, recursive)

    severity_enums = (
        [AttackRuleSeverity.from_string(s) for s in severities] if severities else None
    )
    selected = loader.filter_rules(loaded, types=types, severities=severity_enums)

    # A type filter takes precedence over the directory name in the label.
    if types:
        label = f"YAML Rules [{', '.join(types)}]"
    else:
        label = f"YAML Rules ({Path(directory).name})"

    return rules_to_dataset(selected, name=label, variables=variables)
50+
51+
52+
class YAMLRulesDatasetLoader:
    """Load YAML attack rules from one or more directories as probe datasets.

    Directories that do not exist are silently skipped. The configured
    ``types`` and ``severities`` (severity names as strings) filters apply
    to every directory.
    """

    def __init__(
        self,
        directories: list[str | Path] | None = None,
        types: list[str] | None = None,
        severities: list[str] | None = None,
        recursive: bool = True,
    ):
        self.directories = directories or []
        self.types = types
        self.severities = severities
        self.recursive = recursive
        self._loader = RuleLoader()

    def add_directory(self, directory: str | Path):
        """Register an additional rules directory."""
        self.directories.append(directory)

    def add_builtin_rules(self, rules_subdir: str = "rules"):
        """Register the package-bundled rules directory, if it exists."""
        builtin = Path(__file__).parent / rules_subdir
        if builtin.exists():
            self.directories.append(builtin)

    def _severity_enums(self) -> list[AttackRuleSeverity] | None:
        """Convert configured severity strings to enums (None when unset)."""
        if not self.severities:
            return None
        return [AttackRuleSeverity.from_string(s) for s in self.severities]

    def load(self, variables: dict[str, str] | None = None) -> list[ProbeDataset]:
        """Return one dataset per existing directory that yields matching rules."""
        datasets = []
        # Convert once, outside the loop — the filter is the same for all dirs.
        severity_enums = self._severity_enums()

        for directory in self.directories:
            directory = Path(directory)
            if not directory.exists():
                continue  # missing directories are skipped, not an error

            rules = self._loader.load_rules_from_directory(directory, self.recursive)
            filtered = self._loader.filter_rules(
                rules, types=self.types, severities=severity_enums
            )
            if not filtered:
                continue  # no dataset for a directory with no matching rules

            datasets.append(
                rules_to_dataset(
                    filtered,
                    name=f"YAML Rules ({directory.name})",
                    variables=variables,
                )
            )

        return datasets

    def load_merged(self, variables: dict[str, str] | None = None) -> ProbeDataset:
        """Merge rules from all configured directories into a single dataset."""
        all_rules = []
        for directory in self.directories:
            directory = Path(directory)
            if not directory.exists():
                continue
            all_rules.extend(
                self._loader.load_rules_from_directory(directory, self.recursive)
            )

        filtered = self._loader.filter_rules(
            all_rules, types=self.types, severities=self._severity_enums()
        )
        return rules_to_dataset(
            filtered, name="YAML Rules (merged)", variables=variables
        )
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
from pathlib import Path
2+
3+
import yaml
4+
5+
from agentic_security.attack_rules.models import AttackRule, AttackRuleSeverity
6+
from agentic_security.logutils import logger
7+
8+
9+
class RuleValidationError(Exception):
    """Raised when a YAML rule definition fails validation."""

    pass
11+
12+
13+
class RuleLoader:
    """Load and validate YAML attack-rule files.

    Every successfully loaded rule is also accumulated on the instance
    (see :attr:`rules`), so filters can be applied without re-reading disk.
    """

    # Fields every rule document must define with a truthy value.
    REQUIRED_FIELDS = {"name", "prompt"}
    VALID_EXTENSIONS = {".yaml", ".yml"}

    def __init__(self, rules_dir: str | Path | None = None):
        self.rules_dir = Path(rules_dir) if rules_dir else None
        self._rules: list[AttackRule] = []

    def validate_rule_data(self, data: dict, filepath: str | None = None) -> list[str]:
        """Return human-readable validation errors for *data* (empty if valid).

        When *filepath* is given, it is prefixed to every error message.
        """
        errors = []
        for field in self.REQUIRED_FIELDS:
            if field not in data or not data[field]:
                errors.append(f"Missing required field: {field}")

        if "severity" in data and data["severity"]:
            severity = data["severity"]
            # YAML may parse severity as a non-string (e.g. an int); report
            # it as invalid instead of crashing on ``.lower()``.
            if not isinstance(severity, str) or severity.lower() not in {
                "low",
                "medium",
                "high",
            }:
                errors.append(f"Invalid severity: {severity}")

        if filepath:
            errors = [f"{filepath}: {e}" for e in errors]
        return errors

    def load_rule_from_file(self, filepath: str | Path) -> AttackRule | None:
        """Load a single rule file; returns None (and logs) on any failure."""
        filepath = Path(filepath)
        if filepath.suffix.lower() not in self.VALID_EXTENSIONS:
            return None

        try:
            with open(filepath, encoding="utf-8") as f:
                data = yaml.safe_load(f)

            if not isinstance(data, dict):
                logger.warning(f"Invalid YAML structure in {filepath}")
                return None

            errors = self.validate_rule_data(data, str(filepath))
            if errors:
                for error in errors:
                    logger.warning(error)
                return None

            rule = AttackRule.from_dict(data)
            # Record provenance so a prompt can be traced back to its file.
            rule.metadata["source_file"] = str(filepath)
            return rule

        except yaml.YAMLError as e:
            logger.error(f"YAML parsing error in {filepath}: {e}")
            return None
        except Exception as e:
            logger.error(f"Error loading rule from {filepath}: {e}")
            return None

    def load_rule_from_string(self, yaml_content: str) -> AttackRule | None:
        """Parse a rule from raw YAML text; returns None on invalid input."""
        try:
            data = yaml.safe_load(yaml_content)
            if not isinstance(data, dict):
                return None

            errors = self.validate_rule_data(data)
            if errors:
                for error in errors:
                    logger.warning(error)
                return None

            return AttackRule.from_dict(data)
        except yaml.YAMLError as e:
            logger.error(f"YAML parsing error: {e}")
            return None

    def load_rules_from_directory(
        self, directory: str | Path | None = None, recursive: bool = True
    ) -> list[AttackRule]:
        """Load every valid ``.yaml``/``.yml`` rule under *directory*.

        Falls back to ``self.rules_dir`` when *directory* is None. Returns
        an empty list (with a warning) when no usable directory exists.
        Loaded rules are appended to the instance's accumulated rule list.
        """
        directory = Path(directory) if directory else self.rules_dir
        if not directory or not directory.exists():
            logger.warning(f"Rules directory does not exist: {directory}")
            return []

        rules = []
        for ext in [".yaml", ".yml"]:
            glob_pattern = f"**/*{ext}" if recursive else f"*{ext}"
            for filepath in directory.glob(glob_pattern):
                rule = self.load_rule_from_file(filepath)
                if rule:
                    rules.append(rule)

        logger.info(f"Loaded {len(rules)} rules from {directory}")
        self._rules.extend(rules)
        return rules

    def load_multiple_directories(
        self, directories: list[str | Path], recursive: bool = True
    ) -> list[AttackRule]:
        """Load rules from several directories and return them concatenated."""
        all_rules = []
        for directory in directories:
            rules = self.load_rules_from_directory(directory, recursive)
            all_rules.extend(rules)
        return all_rules

    def filter_rules(
        self,
        rules: list[AttackRule] | None = None,
        types: list[str] | None = None,
        severities: list[AttackRuleSeverity] | None = None,
        name_pattern: str | None = None,
    ) -> list[AttackRule]:
        """Filter *rules* (or the accumulated rules) by type/severity/name.

        ``name_pattern`` is a case-insensitive regular expression matched
        anywhere in the rule name. All filters are conjunctive.
        """
        rules = rules if rules is not None else self._rules
        result = rules

        if types:
            result = [r for r in result if r.type in types]

        if severities:
            result = [r for r in result if r.severity in severities]

        if name_pattern:
            # Imported locally: only needed when a pattern filter is used.
            import re

            pattern = re.compile(name_pattern, re.IGNORECASE)
            result = [r for r in result if pattern.search(r.name)]

        return result

    def get_rules_by_type(self, rule_type: str) -> list[AttackRule]:
        """Return accumulated rules whose ``type`` equals *rule_type*."""
        return self.filter_rules(types=[rule_type])

    def get_rules_by_severity(self, severity: AttackRuleSeverity) -> list[AttackRule]:
        """Return accumulated rules with exactly the given severity."""
        return self.filter_rules(severities=[severity])

    @property
    def rules(self) -> list[AttackRule]:
        # All rules accumulated across load calls (shared list, not a copy).
        return self._rules

    @property
    def rule_types(self) -> set[str]:
        # Distinct ``type`` values among accumulated rules.
        return {r.type for r in self._rules}
150+
151+
152+
def load_rules_from_directory(
    directory: str | Path, recursive: bool = True
) -> list[AttackRule]:
    """Convenience wrapper: load rules from *directory* with a fresh loader."""
    return RuleLoader().load_rules_from_directory(directory, recursive)
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
from dataclasses import dataclass, field
2+
from enum import Enum
3+
from typing import Any
4+
5+
6+
class AttackRuleSeverity(Enum):
    """Severity levels for attack rules; values are lowercase names."""

    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"

    @classmethod
    def from_string(cls, value: str) -> "AttackRuleSeverity":
        """Parse a severity name leniently (case-insensitive).

        Unknown names and non-string values fall back to ``MEDIUM`` rather
        than raising, so a malformed rule file degrades gracefully.
        """
        # YAML may yield a non-string (int, None); the bare ``.lower()``
        # call would raise AttributeError, so guard first.
        if not isinstance(value, str):
            return cls.MEDIUM
        try:
            return cls(value.lower())
        except ValueError:
            return cls.MEDIUM
17+
18+
19+
@dataclass
class AttackRule:
    """A single attack prompt plus metadata, loaded from a YAML mapping.

    Fields mirror the YAML schema; keys not listed here end up in
    ``metadata`` (the loader also stores ``source_file`` there).
    """

    name: str
    type: str
    prompt: str
    severity: AttackRuleSeverity = AttackRuleSeverity.MEDIUM
    pass_conditions: list[str] = field(default_factory=list)
    fail_conditions: list[str] = field(default_factory=list)
    source: str | None = None
    metadata: dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "AttackRule":
        """Build a rule from a parsed YAML mapping.

        Raises:
            KeyError: if ``name`` or ``prompt`` is missing.
        """
        severity = AttackRuleSeverity.from_string(data.get("severity", "medium"))
        # Keys consumed by explicit fields; everything else becomes metadata.
        known = {
            "name",
            "type",
            "prompt",
            "severity",
            "pass_conditions",
            "fail_conditions",
            "source",
        }
        return cls(
            name=data["name"],
            type=data.get("type", "unknown"),
            prompt=data["prompt"],
            severity=severity,
            # ``or []`` guards against an explicit ``null`` in the YAML,
            # which ``get(key, [])`` would pass through as None.
            pass_conditions=data.get("pass_conditions") or [],
            fail_conditions=data.get("fail_conditions") or [],
            source=data.get("source"),
            metadata={k: v for k, v in data.items() if k not in known},
        )

    def to_dict(self) -> dict[str, Any]:
        """Serialize back to a YAML-ready mapping; empty fields are omitted."""
        result = {
            "name": self.name,
            "type": self.type,
            "prompt": self.prompt,
            "severity": self.severity.value,
        }
        if self.pass_conditions:
            result["pass_conditions"] = self.pass_conditions
        if self.fail_conditions:
            result["fail_conditions"] = self.fail_conditions
        if self.source:
            result["source"] = self.source
        if self.metadata:
            # NOTE(review): metadata keys that collide with core keys would
            # overwrite them here; from_dict never produces such keys, but
            # externally-set metadata could — confirm this is acceptable.
            result.update(self.metadata)
        return result

    def render_prompt(self, variables: dict[str, str] | None = None) -> str:
        """Substitute template variables into the prompt text.

        Supports both ``{key}`` and ``{{ key }}`` placeholder styles;
        returns the prompt unchanged when no variables are given.
        """
        if not variables:
            return self.prompt
        result = self.prompt
        for key, value in variables.items():
            result = result.replace(f"{{{key}}}", value)
            result = result.replace(f"{{{{ {key} }}}}", value)
        return result

0 commit comments

Comments
 (0)