Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions agentic_security/probe_data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,13 +248,13 @@ def load_jailbreak_v28k() -> ProbeDataset:
@cache_to_disk()
def load_local_csv() -> ProbeDataset:
"""Load prompts from local CSV files."""
csv_files = [f for f in os.listdir(".") if f.endswith(".csv")]
csv_files = [f for f in os.listdir("./datasets") if f.endswith(".csv")]
logger.info(f"Found {len(csv_files)} CSV files: {csv_files}")

prompts = []
for file in csv_files:
try:
df = pd.read_csv(file)
df = pd.read_csv(os.path.join("./datasets", file), encoding_errors="ignore")
if "prompt" in df.columns:
prompts.extend(df["prompt"].tolist())
else:
Expand All @@ -270,7 +270,7 @@ def load_csv(file: str) -> ProbeDataset:
"""Load prompts from local CSV files."""
prompts = []
try:
df = pd.read_csv(file)
df = pd.read_csv(os.path.join("./datasets", file), encoding_errors="ignore")
prompts = df["prompt"].tolist()
if "prompt" in df.columns:
prompts.extend(df["prompt"].tolist())
Expand All @@ -284,14 +284,14 @@ def load_csv(file: str) -> ProbeDataset:
@cache_to_disk(1)
def load_local_csv_files() -> list[ProbeDataset]:
"""Load prompts from local CSV files and return a list of ProbeDataset objects."""
csv_files = [f for f in os.listdir(".") if f.endswith(".csv")]
csv_files = [f for f in os.listdir("./datasets") if f.endswith(".csv")]
logger.info(f"Found {len(csv_files)} CSV files: {csv_files}")

datasets = []

for file in csv_files:
try:
df = pd.read_csv(file)
df = pd.read_csv(os.path.join("./datasets", file), encoding_errors="ignore")
if "prompt" in df.columns:
prompts = df["prompt"].tolist()
datasets.append(create_probe_dataset(file, prompts, {"src": file}))
Expand Down Expand Up @@ -336,12 +336,18 @@ def transform(self) -> Iterator[ProbeDataset]:
)
continue

size = len(dataset.prompts)
# Convert async generator (if any) to a list
prompts = list(dataset.prompts) if hasattr(dataset.prompts, '__iter__') else []
size = len(prompts)

if size == 0:
continue

sample_size = max(1, int(size * 0.25)) # At least 1 sample

for name, fn in self.TRANSFORMATIONS.items():
logger.info(f"Applying {name} to {dataset.dataset_name}")
sampled = random.sample(dataset.prompts, min(sample_size, size))
sampled = random.sample(prompts, min(sample_size, size))
try:
transformed = [
f"Follow the instructions in {name}: {fn(str(p))}"
Expand Down
Loading