-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtest_combined_dataset.py
More file actions
98 lines (80 loc) · 4.5 KB
/
test_combined_dataset.py
File metadata and controls
98 lines (80 loc) · 4.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
"""Test the combined dataset loading functionality."""
from dataset_loader import load_dataset_as_dspy_examples, load_combined_dataset
def test_dataset_loading():
    """Exercise the dataset loaders end to end and report progress on stdout.

    The check runs in seven printed stages: per-file example counts, the
    attribute layout of a loaded example, the development-mode split sizes,
    the full-mode split sizes, a spot check of the interleaved training
    order, the canned SQL attached to violation examples, and finally a
    dump of one sample per category.

    Raises:
        AssertionError: If any count, attribute or SQL string differs from
            the expected value.
    """
    banner = "=" * 80
    print(banner)
    print("Testing Combined Dataset Loading")
    print(banner)

    # Stage 1: each JSONL file is loaded on its own.
    print("\n1. Loading individual datasets as DSPy examples...")
    legit_examples = load_dataset_as_dspy_examples("legitimate.jsonl")
    policy_examples = load_dataset_as_dspy_examples("content_policy_violation.jsonl")
    readonly_examples = load_dataset_as_dspy_examples("read_only_violation.jsonl")

    legit_count = len(legit_examples)
    policy_count = len(policy_examples)
    readonly_count = len(readonly_examples)
    print(f" Legitimate examples: {legit_count}")
    print(f" Content policy violations: {policy_count}")
    print(f" Read-only violations: {readonly_count}")

    # The fixture files are expected to hold 100 / 20 / 20 records.
    assert legit_count == 100, f"Expected 100 legitimate examples, got {legit_count}"
    assert policy_count == 20, f"Expected 20 policy violations, got {policy_count}"
    assert readonly_count == 20, f"Expected 20 read-only violations, got {readonly_count}"
    print(" ✓ All datasets loaded with expected sizes")

    # Stage 2: a loaded example must expose both the question and the SQL.
    print("\n2. Testing DSPy Example structure...")
    first_legit = legit_examples[0]
    for field_name in ('natural_language_query', 'sql_query'):
        assert hasattr(first_legit, field_name), f"Missing {field_name} field"
    print(f" Example query: {first_legit.natural_language_query[:60]}...")
    print(f" Example SQL: {first_legit.sql_query[:60]}...")
    print(" ✓ DSPy Examples have correct structure")

    # Stage 3: development mode should yield 15 examples per split.
    print("\n3. Testing development mode (5+5+5 train, 5+5+5 val)...")
    dev_train, dev_val = load_combined_dataset(development_mode=True)
    assert len(dev_train) == 15, f"Expected 15 training examples, got {len(dev_train)}"
    assert len(dev_val) == 15, f"Expected 15 validation examples, got {len(dev_val)}"
    print(" ✓ Development mode split correct")

    # Stage 4: full mode.
    print("\n4. Testing full mode (interleaved training)...")
    full_train, full_val = load_combined_dataset(development_mode=False)
    # Where the expected totals come from:
    #   legitimate            -> 75 train / 25 val
    #   policy violations     -> 15 train /  5 val
    #   read-only violations  -> 15 train /  5 val
    #   training is interleaved (leg, pol, ro, leg, pol, ro, ...): 75 x 3 = 225
    #   validation is the plain sum: 25 + 5 + 5 = 35
    assert len(full_train) == 225, f"Expected 225 training examples, got {len(full_train)}"
    assert len(full_val) == 35, f"Expected 35 validation examples, got {len(full_val)}"
    print(" ✓ Full mode split correct")

    # Stage 5: spot-check the first three (leg, pol, ro) triples.
    print("\n5. Testing interleaving pattern...")
    for start in range(0, min(9, len(full_train)), 3):
        # The category is not inspected directly here; the check only
        # requires that neighbouring entries inside a triple are unequal.
        for left, right in ((start, start + 1), (start + 1, start + 2)):
            assert full_train[left] != full_train[right], "Training should interleave different types"
    print(" ✓ Interleaving pattern correct")

    # Stage 6: violation examples carry a fixed canned SQL response.
    print("\n6. Testing violation responses...")
    first_policy = policy_examples[0]
    first_readonly = readonly_examples[0]
    assert first_policy.sql_query == "SELECT 'query violates content policy'", \
        f"Policy violation has wrong SQL: {first_policy.sql_query}"
    assert first_readonly.sql_query == "SELECT 'database is read-only'", \
        f"Read-only violation has wrong SQL: {first_readonly.sql_query}"
    print(" ✓ All violation responses correct")

    # Stage 7: print one sample per category for eyeballing.
    print("\n7. Sample examples:")
    print("\n Legitimate:")
    print(f" Q: {legit_examples[0].natural_language_query}")
    # Only the legitimate SQL is truncated; violation SQL is printed whole.
    print(f" SQL: {legit_examples[0].sql_query[:80]}...")
    violation_samples = (
        ("\n Content Policy Violation:", policy_examples[0]),
        ("\n Read-Only Violation:", readonly_examples[0]),
    )
    for heading, sample in violation_samples:
        print(heading)
        print(f" Q: {sample.natural_language_query}")
        print(f" SQL: {sample.sql_query}")

    print("\n" + banner)
    print("All tests passed! ✓")
    print(banner)
# Allow the checks to be run directly as a script, without a test runner.
if __name__ == "__main__":
    test_dataset_loading()