-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathconfig.py
More file actions
63 lines (54 loc) · 1.66 KB
/
config.py
File metadata and controls
63 lines (54 loc) · 1.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
"""Shared configuration constants for the CAPP teacher/student training pipeline.

Holds dataset locations, output directory / artifact names, dataset shape
parameters, and the hyperparameters for the teacher model, the plain student
models (S1, S2), and knowledge-distillation (KD) training.
"""
import os

# Dataset paths.
# DATA_DIR can be overridden without editing this file by setting the
# CAPP_DATA_DIR environment variable; an unset or empty variable falls back to
# the default below (the original author's local data location).
# An example relative layout used previously:
#   "../Datasets/CAPP Dataset/SubjectDependent50PercentOverlap_6sW_Insole_wrist_r"
_DEFAULT_DATA_DIR = "/run/media/rune/Bulk Storage/inmoresentum stuff/capp_temp/"
DATA_DIR = os.environ.get("CAPP_DATA_DIR") or _DEFAULT_DATA_DIR
# Split file names; presumably resolved relative to DATA_DIR by the loader —
# confirm against the data-loading code.
X_TRAIN_PATH = "X_train.txt"
Y_TRAIN_PATH = "y_train.txt"
X_TEST_PATH = "X_test.txt"
Y_TEST_PATH = "y_test.txt"

# Output Directory Names
MODELS_DIR = "trained_models"
RESULTS_DIR = "classification_results"
GEN_TFLM = "gen_deployment_artifacts"

# Model Artifact Filenames
TEACHER_MODEL_NAME = "teacher_dual_cross_tap.keras"
# Student S1 (Standard CNN)
STUDENT_S1_NO_KD_NAME = "student_s1_no_kd.keras"
STUDENT_S1_KD_NAME = "student_s1_kd.keras"
# Student S2 (Efficient Separable CNN)
STUDENT_S2_NO_KD_NAME = "student_s2_no_kd.keras"
STUDENT_S2_KD_NAME = "student_s2_kd.keras"

# Dataset parameters
NUM_CLASSES = 21
TIMESTEPS = 300  # 6 seconds at 50Hz
OVERLAP_SAMPLES = 150  # half of TIMESTEPS, i.e. 50% window overlap
IMU_CHANNELS = 9
PRESSURE_CHANNELS = 16

# Teacher model hyperparameters
TEACHER_LEARNING_RATE = 1e-4
TEACHER_WEIGHT_DECAY = 1e-4
TEACHER_BATCH_SIZE = 32
TEACHER_EPOCHS = 10
TEACHER_PATIENCE = 5  # For ReduceLROnPlateau
TEACHER_LR_FACTOR = 0.5

# Student model hyperparameters without KD (Shared for S1 and S2)
STUDENT_LEARNING_RATE = 1e-4
STUDENT_WEIGHT_DECAY = 1e-4
STUDENT_BATCH_SIZE = 32
STUDENT_EPOCHS = 40
STUDENT_PATIENCE = 5  # For ReduceLROnPlateau
STUDENT_LR_FACTOR = 0.5

# Knowledge Distillation hyperparameters
KD_TEMPERATURE = 3.0
KD_ALPHA = 0.75  # Weight for hard labels (ground truth)
KD_LEARNING_RATE = 5e-5
KD_WEIGHT_DECAY = 1e-4
KD_BATCH_SIZE = 32
KD_EPOCHS = 400
KD_MAX_PATIENCE = 200  # For early stopping

# Validation split
VAL_SPLIT = 0.15
RANDOM_SEED = 42