|
| 1 | +from pathlib import Path |
| 2 | + |
| 3 | +import yaml |
| 4 | + |
| 5 | +from agentic_security.attack_rules.models import AttackRule, AttackRuleSeverity |
| 6 | +from agentic_security.logutils import logger |
| 7 | + |
| 8 | + |
class RuleValidationError(Exception):
    """Exception type for attack-rule definitions that fail validation.

    NOTE(review): nothing in the visible part of this module raises it; it is
    presumably provided for callers that want to fail hard on invalid rules --
    confirm against the rest of the package.
    """
| 11 | + |
| 12 | + |
class RuleLoader:
    """Load, validate and filter ``AttackRule`` objects defined in YAML.

    Rules can be read from a single file, from a YAML string, or from one or
    more directories. Rules loaded from directories are accumulated on the
    instance and exposed through :attr:`rules`.
    """

    # Keys that must be present and truthy in every rule mapping.
    REQUIRED_FIELDS = {"name", "prompt"}
    # File suffixes (compared case-insensitively) treated as rule files.
    VALID_EXTENSIONS = {".yaml", ".yml"}
    # Accepted severity values, compared case-insensitively.
    _VALID_SEVERITIES = frozenset({"low", "medium", "high"})

    def __init__(self, rules_dir: str | Path | None = None):
        """Create a loader.

        Args:
            rules_dir: Default directory used by
                :meth:`load_rules_from_directory` when none is passed.
        """
        self.rules_dir = Path(rules_dir) if rules_dir else None
        self._rules: list[AttackRule] = []

    def validate_rule_data(self, data: dict, filepath: str | None = None) -> list[str]:
        """Check a parsed rule mapping and return the problems found.

        Args:
            data: Mapping parsed from a rule document.
            filepath: Optional source path; when given, every message is
                prefixed with ``"<filepath>: "``.

        Returns:
            A list of error messages, empty when the mapping is valid.
        """
        # Sort so the message order does not depend on set iteration order.
        errors = [
            f"Missing required field: {field}"
            for field in sorted(self.REQUIRED_FIELDS)
            if not data.get(field)
        ]

        severity = data.get("severity")
        if severity:
            # YAML may yield a non-string (e.g. ``severity: 3``); report it as
            # invalid instead of failing on ``.lower()``.
            if (
                not isinstance(severity, str)
                or severity.lower() not in self._VALID_SEVERITIES
            ):
                errors.append(f"Invalid severity: {severity}")

        if filepath:
            errors = [f"{filepath}: {e}" for e in errors]
        return errors

    def _build_rule(self, data: dict, filepath: str | None = None) -> AttackRule | None:
        """Validate ``data`` and build a rule, logging each validation error.

        Returns ``None`` when validation fails. Exceptions raised by
        ``AttackRule.from_dict`` are left for the caller to handle.
        """
        errors = self.validate_rule_data(data, filepath)
        if errors:
            for error in errors:
                logger.warning(error)
            return None
        return AttackRule.from_dict(data)

    def load_rule_from_file(self, filepath: str | Path) -> AttackRule | None:
        """Load a single rule from a YAML file.

        Args:
            filepath: Path of the file to read (decoded as UTF-8).

        Returns:
            The rule, with ``metadata["source_file"]`` set to the path, or
            ``None`` when the suffix is not in :attr:`VALID_EXTENSIONS`, the
            document is not a mapping, validation fails, or any error occurs
            while reading or building the rule (errors are logged).
        """
        filepath = Path(filepath)
        if filepath.suffix.lower() not in self.VALID_EXTENSIONS:
            return None

        try:
            with open(filepath, encoding="utf-8") as f:
                data = yaml.safe_load(f)

            if not isinstance(data, dict):
                logger.warning(f"Invalid YAML structure in {filepath}")
                return None

            rule = self._build_rule(data, str(filepath))
            if rule is not None:
                rule.metadata["source_file"] = str(filepath)
            return rule

        except yaml.YAMLError as e:
            logger.error(f"YAML parsing error in {filepath}: {e}")
            return None
        except Exception as e:
            # Best-effort loader: one bad file must not abort a bulk load.
            logger.error(f"Error loading rule from {filepath}: {e}")
            return None

    def load_rule_from_string(self, yaml_content: str) -> AttackRule | None:
        """Load a single rule from YAML text.

        Args:
            yaml_content: YAML document describing one rule.

        Returns:
            The rule, or ``None`` when the document is not a mapping,
            validation fails, or parsing/building raises (errors are logged).
        """
        try:
            data = yaml.safe_load(yaml_content)
            if not isinstance(data, dict):
                logger.warning("Invalid YAML structure in rule string")
                return None
            return self._build_rule(data)
        except yaml.YAMLError as e:
            logger.error(f"YAML parsing error: {e}")
            return None
        except Exception as e:
            # Mirror load_rule_from_file: report the failure and return None.
            logger.error(f"Error loading rule from string: {e}")
            return None

    def load_rules_from_directory(
        self, directory: str | Path | None = None, recursive: bool = True
    ) -> list[AttackRule]:
        """Load every rule file found in a directory.

        Args:
            directory: Directory to scan; falls back to ``self.rules_dir``.
            recursive: Also scan sub-directories when true.

        Returns:
            The rules loaded by this call, in sorted path order. They are also
            appended to the loader's accumulated rules, so calling this twice
            on the same directory accumulates the same rules twice.
        """
        directory = Path(directory) if directory else self.rules_dir
        if not directory or not directory.exists():
            logger.warning(f"Rules directory does not exist: {directory}")
            return []
        if not directory.is_dir():
            logger.warning(f"Rules path is not a directory: {directory}")
            return []

        # Match on the lower-cased suffix, as load_rule_from_file does, so the
        # same files are accepted on case-sensitive filesystems. Sorting makes
        # the load order deterministic.
        glob_pattern = "**/*" if recursive else "*"
        candidates = sorted(
            path
            for path in directory.glob(glob_pattern)
            if path.suffix.lower() in self.VALID_EXTENSIONS and path.is_file()
        )

        rules = []
        for filepath in candidates:
            rule = self.load_rule_from_file(filepath)
            if rule is not None:
                rules.append(rule)

        logger.info(f"Loaded {len(rules)} rules from {directory}")
        self._rules.extend(rules)
        return rules

    def load_multiple_directories(
        self, directories: list[str | Path], recursive: bool = True
    ) -> list[AttackRule]:
        """Load rules from several directories, in the order given.

        Args:
            directories: Directories to scan.
            recursive: Also scan sub-directories when true.

        Returns:
            All rules loaded by this call, concatenated per directory.
        """
        all_rules: list[AttackRule] = []
        for directory in directories:
            all_rules.extend(self.load_rules_from_directory(directory, recursive))
        return all_rules

    def filter_rules(
        self,
        rules: list[AttackRule] | None = None,
        types: list[str] | None = None,
        severities: list[AttackRuleSeverity] | None = None,
        name_pattern: str | None = None,
    ) -> list[AttackRule]:
        """Return the rules that satisfy every supplied criterion.

        Args:
            rules: Rules to filter; defaults to the loader's accumulated rules.
            types: Keep rules whose ``type`` is in this list.
            severities: Keep rules whose ``severity`` is in this list.
            name_pattern: Regular expression searched case-insensitively in
                each rule's ``name``.

        Returns:
            A new list; the input list and the loader's internal list are
            never returned directly, so callers may mutate the result.

        Raises:
            re.error: If ``name_pattern`` is not a valid regular expression.
        """
        # Copy up front so an unfiltered call does not alias internal state.
        result = list(rules if rules is not None else self._rules)

        if types:
            result = [r for r in result if r.type in types]

        if severities:
            result = [r for r in result if r.severity in severities]

        if name_pattern:
            import re

            pattern = re.compile(name_pattern, re.IGNORECASE)
            result = [r for r in result if pattern.search(r.name)]

        return result

    def get_rules_by_type(self, rule_type: str) -> list[AttackRule]:
        """Return the accumulated rules whose ``type`` equals ``rule_type``."""
        return self.filter_rules(types=[rule_type])

    def get_rules_by_severity(self, severity: AttackRuleSeverity) -> list[AttackRule]:
        """Return the accumulated rules whose ``severity`` equals ``severity``."""
        return self.filter_rules(severities=[severity])

    @property
    def rules(self) -> list[AttackRule]:
        """The loader's accumulated rules (the live internal list, not a copy)."""
        return self._rules

    @property
    def rule_types(self) -> set[str]:
        """The distinct ``type`` values among the accumulated rules."""
        return {r.type for r in self._rules}
| 150 | + |
| 151 | + |
def load_rules_from_directory(
    directory: str | Path, recursive: bool = True
) -> list[AttackRule]:
    """Load the rules found in ``directory`` using a throwaway ``RuleLoader``.

    Args:
        directory: Directory to scan for rule files.
        recursive: Passed through to ``RuleLoader.load_rules_from_directory``.

    Returns:
        The list returned by ``RuleLoader.load_rules_from_directory``.
    """
    return RuleLoader().load_rules_from_directory(directory, recursive)
0 commit comments