-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmode_controller.py
More file actions
357 lines (283 loc) · 12.2 KB
/
mode_controller.py
File metadata and controls
357 lines (283 loc) · 12.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
"""
Mode Controller Module for Intelligent Navigation Mode Switching.
Manages different operating modes with context-aware behavior.
"""
import re
import time

try:
    import cv2
    import numpy as np

    CV2_AVAILABLE = True
except ImportError:
    cv2 = None
    np = None
    CV2_AVAILABLE = False

import config
from object_manager import ObjectManager
class ModeController:
    """
    Manage navigation modes and coordinate mode-dependent system behavior.

    The controller owns the current mode name (a key of config.MODE_CONFIGS),
    the target-object label substituted into target-tracking prompts, and an
    ObjectManager whose tracked objects are refreshed by process_detections().
    """

    # Splits a label such as "Phone [on table]" into ("Phone", "on table").
    _CONTEXT_PATTERN = re.compile(r"^(.*?)\[(.*?)\]")
    # Word tokens used for loose label matching ("cell phone" ~ "phone").
    _TOKEN_PATTERN = re.compile(r"\w+")
    # Upper bound of the normalized "box_2d" coordinates in detections (0-1000).
    _NORMALIZED_SCALE = 1000
    # Frame size assumed until set_frame_dimensions() is called.
    _DEFAULT_FRAME_WIDTH = 640
    _DEFAULT_FRAME_HEIGHT = 480
    # Detections narrower or shorter than this many pixels are ignored.
    _MIN_BOX_SIDE = 5
    # IoU above which a detection counts as the same object as a tracked one;
    # also the threshold at which a *lost* object accepts the detected bbox.
    _SAME_OBJECT_IOU = 0.1
    # IoU above which a currently tracked (not lost) object accepts the bbox.
    _TRACKED_UPDATE_IOU = 0.6

    def __init__(self):
        """Initialize in config.DEFAULT_MODE with "phone" as the target object."""
        self.current_mode = config.DEFAULT_MODE
        self.target_object = "phone"  # Default tracking target
        self.object_manager = ObjectManager()
        # Used by the "closest" focus strategy; overwritten by
        # set_frame_dimensions() once the real frame size is known.
        self.frame_width = self._DEFAULT_FRAME_WIDTH
        self.frame_height = self._DEFAULT_FRAME_HEIGHT
        print(f"[MODE] ModeController initialized | Mode: {self.current_mode}")

    def set_mode(self, mode):
        """
        Switch to a new mode.

        Switching to a different mode clears all tracked objects.

        Args:
            mode: Mode name; must be a key of config.MODE_CONFIGS.
        Returns:
            True if the mode is now active (changed, or was already active),
            False if the mode name is invalid.
        """
        if mode not in config.MODE_CONFIGS:
            print(f"[ERROR] Invalid mode: {mode}")
            return False
        if mode == self.current_mode:
            print(f"[MODE] Already in {mode} mode")
            return True
        old_mode = self.current_mode
        self.current_mode = mode
        # Objects tracked under the previous mode are not relevant any more.
        self.object_manager.clear()
        mode_desc = config.MODE_CONFIGS[mode]["description"]
        print(f"[MODE] Mode changed: {old_mode} -> {self.current_mode} ({mode_desc})")
        return True

    def get_mode_config(self):
        """Return the config dict for the current mode ({} if unknown)."""
        return config.MODE_CONFIGS.get(self.current_mode, {})

    def get_detection_prompt(self):
        """
        Get the detection prompt for the current mode.

        Uses the mode's "prompt" entry, falling back to
        config.DETECTION_PROMPT_MULTI_OBJECT when it is missing or None. A
        "{target_object}" placeholder is replaced with self.target_object.

        Returns:
            Formatted prompt string.
        """
        prompt_template = self.get_mode_config().get("prompt")
        if prompt_template is None:
            prompt_template = config.DETECTION_PROMPT_MULTI_OBJECT
        if "{target_object}" not in prompt_template:
            return prompt_template
        try:
            return prompt_template.format(target_object=self.target_object)
        except (KeyError, IndexError, ValueError):
            # The template also contains literal braces (e.g. a JSON example)
            # that str.format cannot parse; substitute only the placeholder.
            return prompt_template.replace(
                "{target_object}", str(self.target_object)
            )

    def set_target_object(self, obj_name):
        """
        Set the target object to track (for navigation mode).

        In navigation mode, tracked objects are cleared so the next detection
        round starts fresh for the new target.

        Args:
            obj_name: Object name (e.g., "phone", "person", "door").
        """
        self.target_object = obj_name
        print(f"[MODE] Target object set to: {obj_name}")
        if self.current_mode == config.NavigationMode.NAVIGATION:
            self.object_manager.clear()

    def should_filter_objects(self):
        """Return True if the current mode's config defines a "filter" key."""
        return "filter" in self.get_mode_config()

    def get_object_filter(self):
        """Return the current mode's label filter ([] if none)."""
        return self.get_mode_config().get("filter", [])

    def get_max_objects(self):
        """Return the current mode's "max_objects" (default config.MAX_TRACKED_OBJECTS)."""
        return self.get_mode_config().get("max_objects", config.MAX_TRACKED_OBJECTS)

    def get_audio_focus_strategy(self):
        """
        Get the audio focus strategy for the current mode.

        Returns:
            The mode's "audio_focus" value, defaulting to "all". Values
            handled by get_primary_object: "target", "closest", "people", "all".
        """
        return self.get_mode_config().get("audio_focus", "all")

    def process_detections(self, detections, frame):
        """
        Merge Gemini detection results into the tracked-object set.

        Each detection is converted to a clamped pixel bbox. If a tracked
        object shares a label token with it, the best-overlapping such object
        is refreshed (see _merge_into_existing); otherwise a new object is added.

        NOTE(review): a label-matching detection whose IoU with every same-label
        object is <= _SAME_OBJECT_IOU is discarded rather than added. This
        suppresses duplicates from stale (latent) detections, but also means a
        second object with the same label is not added while one is tracked.

        Args:
            detections: List of detection dicts from Gemini, each read for
                "box_2d" ([y_min, x_min, y_max, x_max], normalized 0-1000) and
                "label" (optionally "Name [context]").
            frame: Current video frame; only frame.shape[:2] (height, width)
                is read.
        Returns:
            Number of new objects added.
        """
        if not detections:
            return 0
        # Existing objects are intentionally not cleared, so tracked objects
        # persist between detection rounds.
        frame_height, frame_width = frame.shape[:2]
        current_time = time.time()
        added = 0
        for det in detections:
            try:
                bbox = self._detection_to_bbox(
                    det.get("box_2d", []), frame_width, frame_height
                )
                if bbox is None:
                    continue
                label, context = self._split_label_context(det.get("label", "unknown"))
                existing_obj, iou = self._find_matching_object(label, bbox)
                if existing_obj is not None:
                    self._merge_into_existing(
                        existing_obj, bbox, iou, context, current_time
                    )
                    continue
                obj = self.object_manager.add_object(label, bbox, context=context)
                print(
                    f"[MODE] Added object #{obj.id}: {label} at {bbox} (Context: {context})"
                )
                added += 1
            except Exception as e:
                # Best-effort: one malformed detection must not abort the batch.
                print(f"[ERROR] Error processing detection: {e}")
        return added

    def _detection_to_bbox(self, box_2d, frame_width, frame_height):
        """
        Convert a normalized [y_min, x_min, y_max, x_max] box to pixels.

        Returns:
            (x, y, w, h) clamped to the frame, or None if box_2d is malformed
            or the clamped box is smaller than _MIN_BOX_SIDE on either side.
        """
        if not box_2d or len(box_2d) != 4:
            return None
        y_min, x_min, y_max, x_max = box_2d
        scale = self._NORMALIZED_SCALE
        # Clamp both corners *before* taking width/height, so a box that
        # starts off-frame does not keep its off-frame extent.
        x1 = max(0, min(int(x_min * frame_width / scale), frame_width - 1))
        y1 = max(0, min(int(y_min * frame_height / scale), frame_height - 1))
        x2 = max(0, min(int(x_max * frame_width / scale), frame_width))
        y2 = max(0, min(int(y_max * frame_height / scale), frame_height))
        width = max(1, x2 - x1)
        height = max(1, y2 - y1)
        if width < self._MIN_BOX_SIDE or height < self._MIN_BOX_SIDE:
            return None
        return (x1, y1, width, height)

    def _split_label_context(self, label):
        """Split "Phone [on table]" into ("Phone", "on table"); other labels
        are returned unchanged with a None context."""
        match = self._CONTEXT_PATTERN.search(label)
        if not match:
            return label, None
        return match.group(1).strip(), match.group(2).strip()

    def _find_matching_object(self, label, bbox):
        """
        Find the tracked object a detection most likely refers to.

        Labels match when equal (case-insensitive) or when they share any word
        token (e.g. "cell phone" matches "phone"). Among matching objects the
        one with the highest IoU against bbox wins; ties keep list order.

        Returns:
            (obj, iou), or (None, 0.0) if no tracked object's label matches.
        """
        new_label_lower = label.lower()
        new_tokens = set(self._TOKEN_PATTERN.findall(new_label_lower))
        best_obj = None
        best_iou = 0.0
        for obj in self.object_manager.objects:
            existing_lower = obj.label.lower()
            if existing_lower != new_label_lower and not (
                new_tokens & set(self._TOKEN_PATTERN.findall(existing_lower))
            ):
                continue
            iou = self.object_manager.compute_iou(obj.bbox, bbox)
            if best_obj is None or iou > best_iou:
                best_obj, best_iou = obj, iou
        return best_obj, best_iou

    def _merge_into_existing(self, existing_obj, bbox, iou, context, current_time):
        """
        Refresh a tracked object from a matching detection.

        Detections arrive with API latency, so a currently tracked object only
        takes the detected bbox when the overlap is very high
        (IoU > _TRACKED_UPDATE_IOU), which prevents jumping back to a stale
        position. A lost object accepts a loose overlap (IoU > _SAME_OBJECT_IOU)
        so it can be recovered. Verification metadata (last_verified, context,
        lost state) is refreshed whenever IoU > _SAME_OBJECT_IOU.
        """
        bbox_threshold = (
            self._SAME_OBJECT_IOU if existing_obj.is_lost else self._TRACKED_UPDATE_IOU
        )
        if iou > bbox_threshold:
            existing_obj.bbox = bbox
        if iou > self._SAME_OBJECT_IOU:
            existing_obj.last_verified = current_time
            existing_obj.context = context
            existing_obj.is_lost = False
            existing_obj.lost_time = None

    def _convert_bbox(self, box_2d, frame_width=640, frame_height=480):
        """
        Convert a normalized bounding box to (unclamped) pixel coordinates.

        Args:
            box_2d: [y_min, x_min, y_max, x_max] normalized 0-1000.
            frame_width, frame_height: Frame dimensions in pixels.
        Returns:
            (x, y, w, h) in pixels, or None if box_2d is empty or not length 4.
        """
        if not box_2d or len(box_2d) != 4:
            return None
        y_min, x_min, y_max, x_max = box_2d
        x1 = int((x_min / 1000) * frame_width)
        y1 = int((y_min / 1000) * frame_height)
        x2 = int((x_max / 1000) * frame_width)
        y2 = int((y_max / 1000) * frame_height)
        return (x1, y1, x2 - x1, y2 - y1)

    def set_frame_dimensions(self, width, height):
        """Store frame dimensions used by the "closest" focus strategy."""
        self.frame_width = width
        self.frame_height = height

    def get_primary_object(self):
        """
        Get the object to focus audio on, per the current mode's strategy.

        Strategies:
            "target":  first object matching self.target_object, else the
                       first tracked object.
            "closest": ObjectManager.get_closest_object for the frame size.
            "people":  first object labelled "person", else None.
            "all":     first tracked object (callers fetch the rest).
        Returns:
            TrackedObject or None (also None for an unknown strategy).
        """
        objects = self.object_manager.objects
        if not objects:
            return None
        strategy = self.get_audio_focus_strategy()
        if strategy == "target":
            targets = self.object_manager.get_objects_by_label(self.target_object)
            return targets[0] if targets else objects[0]
        if strategy == "closest":
            return self.object_manager.get_closest_object(
                getattr(self, "frame_width", self._DEFAULT_FRAME_WIDTH),
                getattr(self, "frame_height", self._DEFAULT_FRAME_HEIGHT),
            )
        if strategy == "people":
            people = self.object_manager.get_objects_by_label("person")
            return people[0] if people else None
        if strategy == "all":
            return objects[0]
        return None

    def get_mode_description(self):
        """Return the current mode's description (the mode name if absent)."""
        return self.get_mode_config().get("description", self.current_mode)

    def get_main_threat(self):
        """
        Get the tracked object with the highest threat_score.

        Returns:
            TrackedObject or None if nothing is tracked. On ties, the earliest
            object in tracking order is returned.
        """
        objects = self.object_manager.objects
        if not objects:
            return None
        return max(objects, key=lambda o: o.threat_score)

    def check_lost_threats(self, min_threat_score=0.3, max_lost_seconds=2.0):
        """
        Check whether the main threat has been lost for too long.

        When the highest-scoring object is lost, scores above min_threat_score,
        and has been lost for more than max_lost_seconds, it is removed from
        the object manager so it is not re-checked.

        Args:
            min_threat_score: Threat score that must be exceeded to trigger.
            max_lost_seconds: Lost duration (seconds) that must be exceeded.
        Returns:
            True if a re-scan is needed, False otherwise.
        """
        main_threat = self.get_main_threat()
        if main_threat is None or not main_threat.is_lost or main_threat.lost_time is None:
            return False
        elapsed = time.time() - main_threat.lost_time
        if main_threat.threat_score <= min_threat_score or elapsed <= max_lost_seconds:
            return False
        print(
            f"[WARNING] Main threat '{main_threat.label}' (Score: {main_threat.threat_score:.2f}) lost for {elapsed:.1f}s. Triggering re-scan."
        )
        self.object_manager.remove_object(main_threat.id)
        return True