-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinference_resnet.py
More file actions
334 lines (268 loc) · 12.2 KB
/
Copy pathinference_resnet.py
File metadata and controls
334 lines (268 loc) · 12.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
"""
ASL Real-Time Inference (Elite ResNet - 250 Words)
===================================================
Uses your TFLite model trained on 94k samples with 71% accuracy.
Opens your webcam, extracts MediaPipe landmarks, and predicts ASL signs in real-time.
Usage:
python inference_resnet.py
Keys:
ESC - Exit
SPACE - Add a space to the text
D - Delete last word
"""
import cv2
import numpy as np
import json
import os
import mediapipe as mp
# ============================================================
# 1. CONFIGURATION
# ============================================================
# Model and label files live next to this script, so resolve them from the
# script's own directory rather than the current working directory.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_PATH, LABEL_PATH = (
    os.path.join(SCRIPT_DIR, file_name)
    for file_name in ("asl_model.tflite", "sign_to_prediction_index_map.json")
)

# Sequence geometry -- must match what the model was trained with.
MAX_FRAMES = 60              # frames per model input window
NUM_FEATURES = 75 * 2        # 75 landmarks x (x, y) = 150 values per frame

# Decision thresholds.
CONFIDENCE_THRESHOLD = 0.35  # minimum model score for a word to be accepted
# Intended number of consecutive agreeing predictions for stability.
# NOTE(review): not referenced anywhere else in this file.
BUFFER_SIZE = 3
# ============================================================
# 2. LOAD MODEL & LABELS
# ============================================================
print("Loading TFLite model...")
# Prefer the lightweight tflite_runtime package and fall back to the
# interpreter bundled with full TensorFlow. Only the import is guarded, so an
# error raised while constructing the interpreter surfaces instead of
# silently triggering the fallback.
try:
    import tflite_runtime.interpreter as tflite
except ImportError:
    import tensorflow as tf
    tflite = tf.lite
interpreter = tflite.Interpreter(model_path=MODEL_PATH)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print(f" Input shape: {input_details[0]['shape']}")
print(f" Output shape: {output_details[0]['shape']}")

# Load the word dictionary: a JSON object mapping sign name -> class index.
with open(LABEL_PATH, 'r', encoding='utf-8') as f:
    sign_map = json.load(f)
# Invert to class index -> sign name. Indices are coerced to int so lookups
# with the integer class indices produced at prediction time still match if
# the JSON happens to store them as strings.
id_to_sign = {int(v): k for k, v in sign_map.items()}
print(f" Loaded {len(sign_map)} signs.")
# ============================================================
# 3. LANDMARK EXTRACTION (Matches Training Pipeline Exactly)
# ============================================================
# Training used: left_hand (21), right_hand (21), pose (first 33 mapped to indices 42-74)
# Total: 75 landmarks × 2 (x, y) = 150 features per frame
# Then velocity is appended → 300 features per frame
def extract_landmarks_from_holistic(results):
    """
    Pack one MediaPipe Holistic result into the fixed 75-landmark layout.

    Each row of the returned (75, 2) float array is an (x, y) pair:
        rows  0-20  left hand
        rows 21-41  right hand
        rows 42-74  pose
    Groups Holistic did not report (falsy attribute) stay NaN, and any
    landmarks beyond a group's row budget are ignored.
    """
    out = np.full((75, 2), np.nan)

    # (landmark group, first output row, maximum landmarks copied)
    layout = (
        (results.left_hand_landmarks, 0, 21),
        (results.right_hand_landmarks, 21, 21),
        (results.pose_landmarks, 42, 33),
    )
    for group, first_row, limit in layout:
        if not group:
            continue
        # zip with range() caps the copy at `limit` landmarks.
        for offset, point in zip(range(limit), group.landmark):
            out[first_row + offset] = [point.x, point.y]
    return out
def _reference_point(data):
    """Return the (x, y) point a single (75, 2) landmark frame is centered on."""
    # Preferred anchor: the nose (pose landmark 0 = row 42).
    nose = data[42]
    if not np.isnan(nose).any():
        return nose
    # Nose missing: fall back to the centroid of the landmarks that exist.
    # Columns with no finite value at all (e.g. all-NaN padding frames) are
    # checked first so np.nanmean is never called on an empty slice, which
    # would emit a RuntimeWarning; such frames are simply not shifted.
    if (~np.isnan(data).all(axis=0)).all():
        centroid = np.nanmean(data, axis=0)
        if not np.isnan(centroid).any():
            return centroid
    return np.array([0.0, 0.0])


def process_frame_buffer(frame_buffer):
    """
    Convert a sequence of per-frame landmark arrays into one model input.

    Each element of ``frame_buffer`` is a (75, 2) array as produced by
    ``extract_landmarks_from_holistic`` (NaN where a landmark is missing).
    The sequence is cut to its first ``MAX_FRAMES`` frames, or padded at the
    end with all-NaN frames up to ``MAX_FRAMES``. Every frame is shifted so
    its reference point (the nose, else the landmark centroid, else no shift)
    is at the origin, remaining NaNs become 0, and the frame-to-frame
    difference ("velocity", zero for the first frame) is appended.

    Returns:
        float32 array of shape (MAX_FRAMES, 300).

    The caller's ``frame_buffer`` is never modified.
    """
    # Private copy: padding must not grow the caller's list in place.
    frames = list(frame_buffer[:MAX_FRAMES])
    frames.extend(np.full((75, 2), np.nan) for _ in range(MAX_FRAMES - len(frames)))

    rows = [(data - _reference_point(data)).flatten() for data in frames]  # each (150,)
    X = np.nan_to_num(np.array(rows, dtype=np.float32))  # (MAX_FRAMES, 150)

    # Compute velocity (exactly like training): prepending the first row makes
    # the first frame's velocity zero.
    if len(X) > 1:
        velocity = np.diff(X, axis=0, prepend=X[:1])
    else:
        velocity = np.zeros_like(X)
    return np.concatenate([X, velocity], axis=-1)  # (MAX_FRAMES, 300)
def predict(input_data):
    """
    Classify one preprocessed sequence with the loaded TFLite interpreter.

    ``input_data`` is a single unbatched sequence (the (60, 300) array built
    by ``process_frame_buffer``); a leading batch axis is added and the values
    are cast to float32 before inference.

    Returns:
        (best_word, best_score, top_3) where ``top_3`` lists up to three
        (word, score) pairs in descending score order. Class indices absent
        from the label map are reported as "???".
    """
    in_index = input_details[0]['index']
    out_index = output_details[0]['index']
    batch = np.expand_dims(input_data, axis=0).astype(np.float32)

    interpreter.set_tensor(in_index, batch)
    interpreter.invoke()
    scores = interpreter.get_tensor(out_index)[0]

    # Indices of the three highest scores, best first.
    ranked = np.argsort(scores)[::-1][:3]
    top_3 = [(id_to_sign.get(idx, "???"), float(scores[idx])) for idx in ranked]
    best_word, best_conf = top_3[0]
    return best_word, best_conf, top_3
# ============================================================
# 4. MEDIAPIPE SETUP
# ============================================================
_solutions = mp.solutions
mp_holistic = _solutions.holistic
mp_drawing = _solutions.drawing_utils
mp_drawing_styles = _solutions.drawing_styles

# One module-wide Holistic tracker, created at import time.
holistic = mp_holistic.Holistic(
    model_complexity=1,
    min_tracking_confidence=0.5,
    min_detection_confidence=0.5,
)
# ============================================================
# 5. MAIN LOOP
# ============================================================
def _draw_skeleton(frame, results):
    """Draw the detected hand and pose landmarks onto ``frame`` in place."""
    if results.left_hand_landmarks:
        mp_drawing.draw_landmarks(frame, results.left_hand_landmarks, mp_holistic.HAND_CONNECTIONS)
    if results.right_hand_landmarks:
        mp_drawing.draw_landmarks(frame, results.right_hand_landmarks, mp_holistic.HAND_CONNECTIONS)
    if results.pose_landmarks:
        mp_drawing.draw_landmarks(frame, results.pose_landmarks, mp_holistic.POSE_CONNECTIONS)


def _draw_ui(display_frame, recording, n_frames, current_display,
             current_confidence, top_3_display, text):
    """Draw the header bar, REC marker, prediction panels and text bar in place."""
    h, w, _ = display_frame.shape

    # Header bar with key help.
    cv2.rectangle(display_frame, (0, 0), (w, 40), (30, 30, 30), -1)
    cv2.putText(display_frame, "ASL ResNet (250 Words) | SPACE=Space D=Delete ESC=Exit",
                (10, 28), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (200, 200, 200), 1)

    if recording:
        cv2.circle(display_frame, (w - 30, 25), 10, (0, 0, 255), -1)
        cv2.putText(display_frame, f"REC ({n_frames} frames)",
                    (w - 180, 32), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)

    # Current best guess: green when it would pass the acceptance threshold,
    # orange otherwise.
    if current_display:
        color = (0, 200, 0) if current_confidence >= CONFIDENCE_THRESHOLD else (0, 165, 255)
        cv2.rectangle(display_frame, (10, h - 200), (350, h - 120), (40, 40, 40), -1)
        cv2.rectangle(display_frame, (10, h - 200), (350, h - 120), color, 2)
        cv2.putText(display_frame, current_display.upper(),
                    (20, h - 145), cv2.FONT_HERSHEY_SIMPLEX, 1.5, color, 3)
        cv2.putText(display_frame, f"{current_confidence*100:.1f}%",
                    (20, h - 125), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (180, 180, 180), 1)

    # Top-3 list with score bars (bar length = score * 200 px).
    if top_3_display:
        cv2.rectangle(display_frame, (10, h - 115), (350, h - 45), (40, 40, 40), -1)
        for i, (word, score) in enumerate(top_3_display):
            y_pos = h - 100 + (i * 22)
            bar_width = int(score * 200)
            cv2.rectangle(display_frame, (120, y_pos - 12), (120 + bar_width, y_pos + 4), (0, 150, 0), -1)
            cv2.putText(display_frame, f"{i+1}. {word}", (15, y_pos),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (220, 220, 220), 1)

    # Accumulated text, showing only the tail once it exceeds 60 characters.
    cv2.rectangle(display_frame, (0, h - 40), (w, h), (50, 50, 50), -1)
    display_text = text if len(text) <= 60 else "..." + text[-57:]
    cv2.putText(display_frame, f"Text: {display_text}",
                (10, h - 12), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)


def _delete_last_word(text):
    """Return ``text`` without its last space-separated word ("" if only one word)."""
    text = text.rstrip()
    if " " in text:
        return text[:text.rfind(" ")] + " "
    return ""


def main():
    """
    Run the webcam loop until ESC is pressed or the camera stops delivering frames.

    Each frame, MediaPipe landmarks are extracted. While hands are visible
    (and no cooldown is active), frames are buffered and a live prediction
    is shown. Once more than 10 consecutive frames pass without buffering,
    the buffered sign is classified one final time; if it clears
    CONFIDENCE_THRESHOLD and differs from the previously accepted word, it is
    appended to the text. The camera, OpenCV windows and the Holistic tracker
    are released even if the loop raises.
    """
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        print("ERROR: Cannot open camera!")
        cap.release()
        return
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1280)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 720)

    print("\n" + "=" * 60)
    print(" ASL REAL-TIME INFERENCE (Elite ResNet - 250 Words)")
    print(" Model: 71% Accuracy on 94,477 samples")
    print("=" * 60)
    print(" SPACE = Add space | D = Delete word | ESC = Exit")
    print(" Recording starts when hands/body are detected.")
    print("=" * 60 + "\n")

    # State variables
    frame_buffer = []         # Collects landmark frames (at most MAX_FRAMES)
    recording = False         # Whether we are currently recording a sign
    no_hand_counter = 0       # Consecutive frames that were not buffered
    text = ""                 # The accumulated text
    last_prediction = ""      # Last accepted word (blocks immediate repeats)
    current_display = ""      # What to show as the current prediction
    current_confidence = 0.0
    top_3_display = []
    cooldown = 0              # Frames to ignore after a word is accepted

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # --- THE MIRROR FIX ---
            # The AI processes the RAW frame (correct right/left orientation);
            # only the image shown to the user is mirrored, further below.
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            results = holistic.process(rgb)

            has_hands = (results.left_hand_landmarks is not None or
                         results.right_hand_landmarks is not None)

            if has_hands and cooldown <= 0:
                no_hand_counter = 0
                frame_buffer.append(extract_landmarks_from_holistic(results))
                # Trim to the most recent MAX_FRAMES *before* predicting:
                # process_frame_buffer keeps the first MAX_FRAMES entries, so
                # trimming afterwards would score a window missing the newest frame.
                if len(frame_buffer) > MAX_FRAMES:
                    del frame_buffer[:-MAX_FRAMES]
                recording = True

                # Once we have enough frames, show a live prediction.
                if len(frame_buffer) >= 15:
                    input_data = process_frame_buffer(list(frame_buffer))
                    current_display, current_confidence, top_3_display = predict(input_data)
            else:
                no_hand_counter += 1
                if recording and no_hand_counter > 10:
                    # The sign is considered finished: classify the whole buffer.
                    if len(frame_buffer) >= 15:
                        input_data = process_frame_buffer(list(frame_buffer))
                        word, conf, _ = predict(input_data)
                        if conf >= CONFIDENCE_THRESHOLD and word != last_prediction:
                            text += word + " "
                            last_prediction = word
                            cooldown = 20
                            print(f" ✅ Detected: {word} ({conf*100:.1f}%)")
                    frame_buffer = []
                    recording = False
                    current_display = ""
                    current_confidence = 0.0
                    top_3_display = []

            if cooldown > 0:
                cooldown -= 1

            # Draw the skeleton on the raw frame, then mirror it for the user.
            _draw_skeleton(frame, results)
            display_frame = cv2.flip(frame, 1)
            _draw_ui(display_frame, recording, len(frame_buffer), current_display,
                     current_confidence, top_3_display, text)
            cv2.imshow("ASL Real-Time Inference", display_frame)

            # --- Key Handling ---
            key = cv2.waitKey(1) & 0xFF
            if key == 27:  # ESC
                break
            elif key == 32:  # SPACE
                text += " "
            elif key == ord('d') or key == 8:  # D or Backspace
                text = _delete_last_word(text)
                last_prediction = ""
                print(" 🗑️ Deleted last word.")
    finally:
        # Cleanup runs on normal exit and when an exception escapes the loop.
        cap.release()
        cv2.destroyAllWindows()
        holistic.close()

    print(f"\nFinal Text: {text}")
    print("Goodbye! 👋")


if __name__ == "__main__":
    main()