Skip to content

Commit bc7fdd7

Browse files
committed
fix(pc):
1 parent 8d42a84 commit bc7fdd7

19 files changed

Lines changed: 491 additions & 77 deletions

File tree

agentic_security/attack_rules/dataset.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,9 @@ def load(self, variables: dict[str, str] | None = None) -> list[ProbeDataset]:
8383

8484
severity_enums = None
8585
if self.severities:
86-
severity_enums = [AttackRuleSeverity.from_string(s) for s in self.severities]
86+
severity_enums = [
87+
AttackRuleSeverity.from_string(s) for s in self.severities
88+
]
8789

8890
filtered = self._loader.filter_rules(
8991
rules, types=self.types, severities=severity_enums
@@ -113,10 +115,14 @@ def load_merged(self, variables: dict[str, str] | None = None) -> ProbeDataset:
113115

114116
severity_enums = None
115117
if self.severities:
116-
severity_enums = [AttackRuleSeverity.from_string(s) for s in self.severities]
118+
severity_enums = [
119+
AttackRuleSeverity.from_string(s) for s in self.severities
120+
]
117121

118122
filtered = self._loader.filter_rules(
119123
all_rules, types=self.types, severities=severity_enums
120124
)
121125

122-
return rules_to_dataset(filtered, name="YAML Rules (merged)", variables=variables)
126+
return rules_to_dataset(
127+
filtered, name="YAML Rules (merged)", variables=variables
128+
)

agentic_security/attack_rules/loader.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import os
21
from pathlib import Path
32

43
import yaml
@@ -81,17 +80,15 @@ def load_rule_from_string(self, yaml_content: str) -> AttackRule | None:
8180
return None
8281

8382
def load_rules_from_directory(
84-
self,
85-
directory: str | Path | None = None,
86-
recursive: bool = True
83+
self, directory: str | Path | None = None, recursive: bool = True
8784
) -> list[AttackRule]:
8885
directory = Path(directory) if directory else self.rules_dir
8986
if not directory or not directory.exists():
9087
logger.warning(f"Rules directory does not exist: {directory}")
9188
return []
9289

9390
rules = []
94-
pattern = "**/*.yaml" if recursive else "*.yaml"
91+
# pattern = "**/*.yaml" if recursive else "*.yaml"
9592

9693
for ext in [".yaml", ".yml"]:
9794
glob_pattern = f"**/*{ext}" if recursive else f"*{ext}"
@@ -105,9 +102,7 @@ def load_rules_from_directory(
105102
return rules
106103

107104
def load_multiple_directories(
108-
self,
109-
directories: list[str | Path],
110-
recursive: bool = True
105+
self, directories: list[str | Path], recursive: bool = True
111106
) -> list[AttackRule]:
112107
all_rules = []
113108
for directory in directories:
@@ -133,6 +128,7 @@ def filter_rules(
133128

134129
if name_pattern:
135130
import re
131+
136132
pattern = re.compile(name_pattern, re.IGNORECASE)
137133
result = [r for r in result if pattern.search(r.name)]
138134

@@ -154,8 +150,7 @@ def rule_types(self) -> set[str]:
154150

155151

156152
def load_rules_from_directory(
157-
directory: str | Path,
158-
recursive: bool = True
153+
directory: str | Path, recursive: bool = True
159154
) -> list[AttackRule]:
160155
loader = RuleLoader()
161156
return loader.load_rules_from_directory(directory, recursive)

agentic_security/attack_rules/models.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,20 @@ def from_dict(cls, data: dict[str, Any]) -> "AttackRule":
3838
pass_conditions=data.get("pass_conditions", []),
3939
fail_conditions=data.get("fail_conditions", []),
4040
source=data.get("source"),
41-
metadata={k: v for k, v in data.items() if k not in {
42-
"name", "type", "prompt", "severity",
43-
"pass_conditions", "fail_conditions", "source"
44-
}},
41+
metadata={
42+
k: v
43+
for k, v in data.items()
44+
if k
45+
not in {
46+
"name",
47+
"type",
48+
"prompt",
49+
"severity",
50+
"pass_conditions",
51+
"fail_conditions",
52+
"source",
53+
}
54+
},
4555
)
4656

4757
def to_dict(self) -> dict[str, Any]:

agentic_security/core/security.py

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
"""Security utilities and validation for agentic_security."""
2+
3+
from functools import wraps
4+
from collections.abc import Callable
5+
from urllib.parse import urlparse
6+
import hashlib
7+
import hmac
8+
import os
9+
import re
10+
11+
12+
class SecurityValidator:
13+
"""Input validation and sanitization."""
14+
15+
ALLOWED_URL_SCHEMES = {"http", "https"}
16+
MAX_URL_LENGTH = 2048
17+
MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB
18+
19+
@staticmethod
20+
def validate_url(url: str, allowed_hosts: list[str] | None = None) -> bool:
21+
"""Validate URL for SSRF prevention."""
22+
if len(url) > SecurityValidator.MAX_URL_LENGTH:
23+
return False
24+
25+
try:
26+
parsed = urlparse(url)
27+
28+
if parsed.scheme not in SecurityValidator.ALLOWED_URL_SCHEMES:
29+
return False
30+
31+
if not parsed.netloc:
32+
return False
33+
34+
if parsed.netloc in ["localhost", "127.0.0.1", "0.0.0.0"]:
35+
return False
36+
37+
if parsed.netloc.startswith("169.254."):
38+
return False
39+
40+
if parsed.netloc.startswith("10.") or parsed.netloc.startswith("192.168."):
41+
return False
42+
43+
if allowed_hosts and parsed.netloc not in allowed_hosts:
44+
return False
45+
46+
return True
47+
except Exception:
48+
return False
49+
50+
@staticmethod
51+
def sanitize_filename(filename: str) -> str:
52+
"""Sanitize filename to prevent path traversal."""
53+
filename = os.path.basename(filename)
54+
filename = re.sub(r"[^\w\s.-]", "", filename)
55+
filename = filename.strip()
56+
57+
if not filename or filename in [".", ".."]:
58+
raise ValueError("Invalid filename")
59+
60+
return filename
61+
62+
@staticmethod
63+
def validate_file_size(size: int) -> bool:
64+
"""Validate file size."""
65+
return 0 < size <= SecurityValidator.MAX_FILE_SIZE
66+
67+
@staticmethod
68+
def validate_csv_content(content: str) -> bool:
69+
"""Basic CSV validation."""
70+
if not content or len(content) > SecurityValidator.MAX_FILE_SIZE:
71+
return False
72+
73+
lines = content.split("\n", 2)
74+
if not lines:
75+
return False
76+
77+
return True
78+
79+
80+
class SecretManager:
81+
"""Secure secret handling."""
82+
83+
@staticmethod
84+
def get_secret(key: str, default: str | None = None) -> str | None:
85+
"""Get secret from environment."""
86+
value = os.getenv(key, default)
87+
if value and value.startswith("$"):
88+
env_key = value[1:]
89+
value = os.getenv(env_key, default)
90+
return value
91+
92+
@staticmethod
93+
def hash_secret(secret: str, salt: str | None = None) -> str:
94+
"""Hash a secret value."""
95+
if salt is None:
96+
salt = os.urandom(32).hex()
97+
98+
hashed = hashlib.pbkdf2_hmac("sha256", secret.encode(), salt.encode(), 100000)
99+
return f"{salt}${hashed.hex()}"
100+
101+
@staticmethod
102+
def verify_secret(secret: str, hashed: str) -> bool:
103+
"""Verify a secret against its hash."""
104+
try:
105+
salt, expected = hashed.split("$", 1)
106+
actual = hashlib.pbkdf2_hmac(
107+
"sha256", secret.encode(), salt.encode(), 100000
108+
)
109+
return hmac.compare_digest(actual.hex(), expected)
110+
except Exception:
111+
return False
112+
113+
114+
class RateLimiter:
115+
"""Simple in-memory rate limiter."""
116+
117+
def __init__(self, max_requests: int, window_seconds: int):
118+
self.max_requests = max_requests
119+
self.window_seconds = window_seconds
120+
self._requests: dict[str, list[float]] = {}
121+
122+
def is_allowed(self, key: str) -> bool:
123+
"""Check if request is allowed."""
124+
import time
125+
126+
now = time.time()
127+
128+
if key not in self._requests:
129+
self._requests[key] = []
130+
131+
self._requests[key] = [
132+
ts for ts in self._requests[key] if now - ts < self.window_seconds
133+
]
134+
135+
if len(self._requests[key]) >= self.max_requests:
136+
return False
137+
138+
self._requests[key].append(now)
139+
return True
140+
141+
def reset(self, key: str):
142+
"""Reset rate limit for key."""
143+
self._requests.pop(key, None)
144+
145+
146+
def require_auth(func: Callable) -> Callable:
147+
"""Decorator to require authentication."""
148+
149+
@wraps(func)
150+
async def wrapper(*args, **kwargs):
151+
# TODO: Implement actual auth check
152+
# For now, check if API key is present
153+
api_key = kwargs.get("api_key") or os.getenv("API_KEY")
154+
if not api_key:
155+
from fastapi import HTTPException
156+
157+
raise HTTPException(status_code=401, detail="Authentication required")
158+
return await func(*args, **kwargs)
159+
160+
return wrapper
161+
162+
163+
def sanitize_log_output(data: str | dict) -> str:
164+
"""Remove sensitive data from logs."""
165+
if isinstance(data, dict):
166+
data = str(data)
167+
168+
patterns = [
169+
(r'(api[_-]?key["\s:=]+)["\']?[\w-]+', r"\1***"),
170+
(r'(token["\s:=]+)["\']?[\w-]+', r"\1***"),
171+
(r'(password["\s:=]+)["\']?[\w-]+', r"\1***"),
172+
(r'(secret["\s:=]+)["\']?[\w-]+', r"\1***"),
173+
(r"Bearer\s+[\w-]+", "Bearer ***"),
174+
]
175+
176+
for pattern, replacement in patterns:
177+
data = re.sub(pattern, replacement, data, flags=re.IGNORECASE)
178+
179+
return data

agentic_security/fuzz_chain/chain.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@
88
class FuzzRunnable(Protocol):
99
"""Protocol for objects that can be run in a fuzzing chain."""
1010

11-
async def run(self, **kwargs: Any) -> str:
12-
...
11+
async def run(self, **kwargs: Any) -> str: ...
1312

1413

1514
class FuzzNode:

agentic_security/llm_providers/anthropic_provider.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def __init__(
3636
def _get_client(self) -> Any:
3737
if self._client is None:
3838
import anthropic
39+
3940
kwargs: dict[str, Any] = {"api_key": self.api_key}
4041
if self.base_url:
4142
kwargs["base_url"] = self.base_url
@@ -45,6 +46,7 @@ def _get_client(self) -> Any:
4546
def _get_async_client(self) -> Any:
4647
if self._async_client is None:
4748
import anthropic
49+
4850
kwargs: dict[str, Any] = {"api_key": self.api_key}
4951
if self.base_url:
5052
kwargs["base_url"] = self.base_url
@@ -95,6 +97,7 @@ def _parse_response(self, response: Any) -> LLMResponse:
9597

9698
def _handle_error(self, e: Exception) -> None:
9799
import anthropic
100+
98101
if isinstance(e, anthropic.RateLimitError):
99102
raise LLMRateLimitError(str(e)) from e
100103
if isinstance(e, anthropic.APIError):

agentic_security/llm_providers/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,15 @@ class LLMRateLimitError(LLMProviderError):
2020
@dataclass
2121
class LLMMessage:
2222
"""A message in a chat conversation."""
23+
2324
role: str # "system", "user", or "assistant"
2425
content: str
2526

2627

2728
@dataclass
2829
class LLMResponse:
2930
"""Response from an LLM provider."""
31+
3032
content: str
3133
model: str | None = None
3234
finish_reason: str | None = None

agentic_security/llm_providers/factory.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ def _ensure_registered() -> None:
1414
return
1515
from agentic_security.llm_providers.openai_provider import OpenAIProvider
1616
from agentic_security.llm_providers.anthropic_provider import AnthropicProvider
17+
1718
_PROVIDERS["openai"] = OpenAIProvider
1819
_PROVIDERS["anthropic"] = AnthropicProvider
1920

agentic_security/llm_providers/openai_provider.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,17 @@ def __init__(
3636
def _get_client(self) -> Any:
3737
if self._client is None:
3838
import openai
39+
3940
self._client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
4041
return self._client
4142

4243
def _get_async_client(self) -> Any:
4344
if self._async_client is None:
4445
import openai
45-
self._async_client = openai.AsyncOpenAI(api_key=self.api_key, base_url=self.base_url)
46+
47+
self._async_client = openai.AsyncOpenAI(
48+
api_key=self.api_key, base_url=self.base_url
49+
)
4650
return self._async_client
4751

4852
@classmethod
@@ -79,6 +83,7 @@ def _parse_response(self, response: Any) -> LLMResponse:
7983

8084
def _handle_error(self, e: Exception) -> None:
8185
import openai
86+
8287
if isinstance(e, openai.RateLimitError):
8388
raise LLMRateLimitError(str(e)) from e
8489
raise LLMProviderError(str(e)) from e

0 commit comments

Comments
 (0)