-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun_deployment_prep.py
More file actions
143 lines (118 loc) · 4.21 KB
/
run_deployment_prep.py
File metadata and controls
143 lines (118 loc) · 4.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
"""
Deployment Preparation Script
Orchestrates TFLite conversion,
quantization (Int8, FP16, and Mixed Precision),
and C-header generation for ESP32-S3 deployment.
"""
import tensorflow as tf
import os
from config import *
from models.utils import prepare_data
from models.deployment_prep_utils import (
convert_to_tflite,
generate_c_array,
evaluate_tflite_inference
)
from models.analysis import wrap_with_softmax
def process_model_deployment(model_path, model_name, X_calib, X_test, y_test, output_dir):
    """
    Run the full deployment pipeline for a single model architecture.

    For each quantization configuration selected for the model, the loaded
    Keras model is converted to TFLite, written to ``output_dir``, evaluated
    on the test split, and exported as a C header.

    Args:
        model_path: Path of the saved Keras model to load.
        model_name: Short model identifier (e.g. "cnn-s1"). It is used in the
            output file names and the C variable name, and it selects the
            quantization configurations: names containing "s1"
            (case-insensitive) get fp32/fp16/mixed/int8, all others fp32 only.
        X_calib: Calibration data forwarded to ``convert_to_tflite`` for
            every quantization mode.
        X_test: Test inputs forwarded to ``evaluate_tflite_inference``.
        y_test: Test labels forwarded to ``evaluate_tflite_inference``.
        output_dir: Directory receiving the ``.tflite`` and ``.h`` artifacts.
            It is created if it does not exist yet.

    Returns:
        A list with one dict per successfully converted variant, holding the
        keys "Variant", "Size (KB)", "Accuracy" and "Header". The list is
        empty when the model file does not exist.
    """
    print(f"\nProcessing Deployment for: {model_name}")
    print("-" * 40)

    if not os.path.exists(model_path):
        print(f"Skipping {model_name}: Model file not found at {model_path}")
        return []

    # Load and prep model
    keras_model = tf.keras.models.load_model(model_path)
    # NOTE(review): presumably adds a softmax output on top of the loaded
    # model; confirm against models.analysis.wrap_with_softmax.
    deployment_model = wrap_with_softmax(keras_model)

    # The artifacts below are written straight into output_dir, so make sure
    # it exists even when the caller did not create it beforehand.
    os.makedirs(output_dir, exist_ok=True)

    # Define Configurations
    # CNN-S1 gets full quantization suite; CNN-S2 no quantization (per paper)
    if "s1" in model_name.lower():
        configs = [
            ("fp32", "float32"),
            ("fp16", "float16"),
            ("mixed", "mixed"),
            ("int8", "int8"),
        ]
    else:
        configs = [("fp32", "float32")]

    results = []
    for suffix, quant_mode in configs:
        print(f" -> Converting {quant_mode}...")

        # 1. Convert. A failing variant is reported and skipped on purpose so
        # the remaining variants are still produced (best-effort batch run).
        try:
            tflite_content = convert_to_tflite(
                deployment_model,
                quantization=quant_mode,
                calibration_data=X_calib,
            )
        except Exception as e:
            print(f" Error during conversion: {e}")
            continue

        # 2. Save TFLite
        tflite_filename = f"{model_name}_{suffix}.tflite"
        tflite_path = os.path.join(output_dir, tflite_filename)
        with open(tflite_path, "wb") as f:
            f.write(tflite_content)

        # 3. Evaluate (Sanity Check)
        acc = evaluate_tflite_inference(tflite_content, X_test, y_test)

        # 4. Generate C Header (using xxd)
        header_filename = f"{model_name}_{suffix}.h"
        header_path = os.path.join(output_dir, header_filename)
        # Create C variable name: g_modelname_suffix (e.g., g_cnn_s1_int8).
        # Hyphens are not valid in C identifiers, hence the replacement.
        var_name = f"g_{model_name}_{suffix}".replace("-", "_")
        try:
            generate_c_array(tflite_path, header_path, var_name)
            header_status = "Generated"
        except RuntimeError as e:
            # A missing header is not fatal: the .tflite file is still usable.
            header_status = "FAILED (Missing xxd)"
            print(f" [Warning] {e}")

        # Log
        size_kb = os.path.getsize(tflite_path) / 1024
        results.append({
            "Variant": f"{model_name}-{suffix}",
            "Size (KB)": size_kb,
            "Accuracy": acc,
            "Header": header_status,
        })

    return results
def main():
    """
    Prepare deployment artifacts for both student models and print a summary.

    Creates the ``GEN_TFLM`` output directory, loads the datasets through
    ``prepare_data``, runs ``process_model_deployment`` for CNN-S1 and CNN-S2,
    and prints one table row per generated variant.
    """
    os.makedirs(GEN_TFLM, exist_ok=True)

    # Load Data
    print("Loading datasets...")
    # Only three of the nine values returned by prepare_data are needed here.
    (X_train_imu, _, _, _, _, _, X_test_imu, _, y_test) = prepare_data(
        os.path.join(DATA_DIR, X_TRAIN_PATH),
        os.path.join(DATA_DIR, Y_TRAIN_PATH),
        os.path.join(DATA_DIR, X_TEST_PATH),
        os.path.join(DATA_DIR, Y_TEST_PATH),
    )

    # (saved model file name, short model name), processed in this order:
    # CNN-S1 (Standard) first, then CNN-S2 (Efficient/Separable).
    student_models = [
        (STUDENT_S1_KD_NAME, "cnn-s1"),
        (STUDENT_S2_KD_NAME, "cnn-s2"),
    ]

    all_results = []
    for model_file, model_name in student_models:
        all_results.extend(process_model_deployment(
            model_path=os.path.join(MODELS_DIR, model_file),
            model_name=model_name,
            X_calib=X_train_imu,  # Used for Int8 calibration
            X_test=X_test_imu,
            y_test=y_test,
            output_dir=GEN_TFLM,
        ))

    # Final Summary
    print("\n" + "=" * 65)
    print(f"{'Variant':<20} | {'Size (KB)':<10} | {'Accuracy':<10} | {'Header':<10}")
    print("-" * 65)
    for r in all_results:
        # Accuracy is padded to the same width as its column title so the
        # "Header" column lines up with the table header.
        print(f"{r['Variant']:<20} | {r['Size (KB)']:<10.2f} | {r['Accuracy']:<10.2%} | {r['Header']}")
    print("-" * 65)
    print(f"Artifacts saved to: {os.path.abspath(GEN_TFLM)}")
# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()