-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsetup_data.py
More file actions
76 lines (63 loc) · 2.15 KB
/
setup_data.py
File metadata and controls
76 lines (63 loc) · 2.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#!/usr/bin/env python3
"""Setup script for data directories and environment verification."""
import json
from pathlib import Path
def create_directories():
    """Create the data and figure directories used by the project.

    All paths are relative, so they are created under the current working
    directory. Missing parents are created as needed and directories that
    already exist are left alone; a "Created" line is printed for every
    path either way.
    """
    wanted = (
        "data/prompts",
        "data/results/behavioral",
        "data/results/binding",
        "data/results/causal",
        "data/results/few_shot",
        "data/tokenization",
        "figures",
    )
    for rel_path in wanted:
        target = Path(rel_path)
        target.mkdir(parents=True, exist_ok=True)
        print(f" Created: {rel_path}")
def verify_prompts():
    """Report whether the pilot prompt file is present and how many prompts it has.

    Looks for ``data/prompts/pilot_terms.jsonl`` relative to the current
    working directory. Prints an "OK" line with the number of non-blank
    lines (one prompt per line), or a "MISSING" line when no regular file
    exists at that path.
    """
    prompt_file = Path("data/prompts/pilot_terms.jsonl")

    # is_file() rather than exists(): a directory at this path cannot be
    # opened, so it is reported as missing instead of crashing in open().
    if not prompt_file.is_file():
        print(f" MISSING: {prompt_file}")
        return

    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with prompt_file.open(encoding="utf-8") as handle:
        # Blank lines (e.g. trailing newlines) carry no record; skip them.
        count = sum(1 for line in handle if line.strip())
    print(f" OK: {prompt_file} ({count} prompts)")
def check_dependencies():
    """Print, for each required package, whether it can be imported.

    Each entry pairs the importable module name with the name used to
    install it via pip (the two differ for ``transformer_lens``). For a
    module that imports, an "OK" line with its ``__version__`` is printed
    ("?" when the attribute is absent); for one that raises ImportError,
    a "MISSING" line with the pip command is printed. Nothing is returned
    and a missing package does not stop the remaining checks.
    """
    # Local import keeps this function self-contained.
    import importlib

    required = (
        ("torch", "torch"),
        ("transformers", "transformers"),
        ("transformer_lens", "transformer-lens"),
        ("numpy", "numpy"),
        ("scipy", "scipy"),
        ("pandas", "pandas"),
        ("tqdm", "tqdm"),
        ("matplotlib", "matplotlib"),
    )
    for module_name, pip_name in required:
        try:
            module = importlib.import_module(module_name)
            # Kept inside the try: attribute access on a lazily-loading
            # package may itself raise ImportError.
            version = getattr(module, "__version__", "?")
        except ImportError:
            print(f" MISSING: pip install {pip_name}")
        else:
            print(f" OK: {pip_name} ({version})")
def check_gpu():
    """Print the name and total memory of CUDA device 0, if one is available.

    When torch reports no CUDA device, a CPU-mode notice is printed
    instead. Any exception raised while importing or querying torch is
    caught and printed as a failure message rather than propagated.
    """
    try:
        import torch

        cuda = torch.cuda
        if not cuda.is_available():
            print(" No CUDA GPU (CPU mode will be slow)")
            return

        device_name = cuda.get_device_name(0)
        # total_memory is divided by 1e9 for the "GB" figure shown below.
        total_gb = cuda.get_device_properties(0).total_memory / 1e9
        print(f" GPU: {device_name} ({total_gb:.1f} GB)")
    except Exception as err:
        print(f" GPU check failed: {err}")
def main():
    """Run the four setup stages in order and print suggested next commands.

    Prints a banner, then for each stage prints its heading and calls the
    corresponding function: directory creation, dependency check, GPU
    check, and prompt-file verification.
    """
    rule = "=" * 60
    print(rule)
    print("Attention Binding - Reproducibility Setup")
    print(rule)

    # Each stage pairs its heading with the function that performs it.
    stages = (
        ("\n[1/4] Creating directories...", create_directories),
        ("\n[2/4] Checking dependencies...", check_dependencies),
        ("\n[3/4] Checking GPU...", check_gpu),
        ("\n[4/4] Verifying prompts...", verify_prompts),
    )
    for heading, action in stages:
        print(heading)
        action()

    print("\nSetup complete! Next:")
    print(" python src/tokenization_audit.py")
    print(" python src/eval_behavior.py 160m step120000")
# Run the setup steps only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()