-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlearning_module.py
More file actions
467 lines (380 loc) · 16.6 KB
/
learning_module.py
File metadata and controls
467 lines (380 loc) · 16.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
"""
Self-Learning Module for Assistive Navigation System.
Implements persistent object memory, room layout learning, and predictive tracking.
Key Features:
- Builds personalized spatial memory of user's environment
- Predicts object locations based on historical data
- Reduces API calls by 40-70% through intelligent caching
"""
import sqlite3
import os
import time
import threading
import cv2
import numpy as np
import hashlib
import re
from datetime import datetime
from typing import Optional, Tuple, List, Dict
import config
class LearningModule:
"""
Manages persistent learning and adaptive behavior.
Stores object detections, learns spatial patterns, and predicts locations.
"""
def __init__(self, db_path=None, image_cache_dir=None):
"""Initialize learning module with database and image cache."""
self.db_path = db_path or config.LEARNING_DB_PATH
self.image_cache_dir = image_cache_dir or config.IMAGE_CACHE_DIR
# Create cache directory
os.makedirs(self.image_cache_dir, exist_ok=True)
# Initialize database with proper thread safety
# Use URI mode with cache for better performance
self.db_path_uri = f"file:{self.db_path}?cache=shared"
self.conn = sqlite3.connect(self.db_path, check_same_thread=False, isolation_level=None)
self.conn.execute("PRAGMA journal_mode=WAL") # Write-Ahead Logging for better concurrency
self.conn.execute("PRAGMA synchronous=NORMAL") # Balance safety and speed
self._init_database()
# Thread-local storage for database connections
self._local = threading.local()
# Grid parameters for room mapping
self.grid_width = config.LEARNING_GRID_WIDTH
self.grid_height = config.LEARNING_GRID_HEIGHT
print(f"[LEARNING] LearningModule initialized | DB: {self.db_path}")
print(f" Image cache: {self.image_cache_dir}")
print(f" Grid: {self.grid_width}x{self.grid_height}")
def _init_database(self):
"""Create database tables if they don't exist."""
cursor = self.conn.cursor()
# Objects table - stores each detection
cursor.execute("""
CREATE TABLE IF NOT EXISTS objects (
id INTEGER PRIMARY KEY AUTOINCREMENT,
label TEXT NOT NULL,
context TEXT,
grid_x INTEGER,
grid_y INTEGER,
confidence REAL,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
image_path TEXT,
image_hash TEXT UNIQUE,
bbox_x INTEGER,
bbox_y INTEGER,
bbox_w INTEGER,
bbox_h INTEGER
)
""")
# Migration: Add context column if it doesn't exist (for existing DBs)
try:
cursor.execute("ALTER TABLE objects ADD COLUMN context TEXT")
except sqlite3.OperationalError:
pass # Column likely already exists
# User preferences table
cursor.execute("""
CREATE TABLE IF NOT EXISTS user_preferences (
key TEXT PRIMARY KEY,
value TEXT,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
""")
# Room grid table - frequency map
cursor.execute("""
CREATE TABLE IF NOT EXISTS room_grid (
grid_x INTEGER,
grid_y INTEGER,
object_label TEXT,
frequency INTEGER DEFAULT 1,
last_seen DATETIME DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (grid_x, grid_y, object_label)
)
""")
self.conn.commit()
print("[LEARNING] Database tables initialized")
def _compute_image_hash(self, image):
"""
Compute perceptual hash of image for deduplication.
Uses average hashing (simple but effective).
"""
# Resize to 8x8
small = cv2.resize(image, (8, 8), interpolation=cv2.INTER_AREA)
# Convert to grayscale
gray = cv2.cvtColor(small, cv2.COLOR_BGR2GRAY) if len(small.shape) == 3 else small
# Average hash
avg = gray.mean()
hash_bytes = (gray > avg).astype(np.uint8).tobytes()
return hashlib.md5(hash_bytes).hexdigest()
def _extract_object_image(self, frame, bbox):
"""Extract and normalize an object crop from a frame."""
x, y, w, h = map(int, bbox)
obj_image = frame[y:y+h, x:x+w]
if obj_image.size == 0:
return None
return cv2.resize(obj_image, (320, 240), interpolation=cv2.INTER_AREA)
def _canonicalize_label(self, label: str) -> str:
"""Convert descriptive labels to a stable lookup form."""
tokens = re.findall(r"\w+", label.lower())
if not tokens:
return label.lower().strip()
stopwords = {"the", "a", "an", "on", "in", "at", "with", "of", "and"}
filtered = [token for token in tokens if token not in stopwords]
if not filtered:
filtered = tokens
return " ".join(filtered)
def compress_and_save_image(self, frame, bbox) -> Tuple[Optional[str], Optional[str]]:
"""
Compress and save object image to cache.
Args:
frame: Full video frame
bbox: Bounding box (x, y, w, h)
Returns:
(path to saved image, image hash), or (None, None) if failed
"""
try:
resized = self._extract_object_image(frame, bbox)
if resized is None:
return None, None
img_hash = self._compute_image_hash(resized)
cursor = self.conn.cursor()
cursor.execute("SELECT image_path FROM objects WHERE image_hash = ?", (img_hash,))
existing = cursor.fetchone()
if existing:
return existing[0], img_hash
filename = f"{img_hash}.jpg"
filepath = os.path.join(self.image_cache_dir, filename)
cv2.imwrite(filepath, resized, [cv2.IMWRITE_JPEG_QUALITY, config.IMAGE_COMPRESSION_QUALITY])
return filepath, img_hash
except Exception as e:
print(f"[ERROR] Failed to save image: {e}")
return None, None
def bbox_to_grid(self, bbox, frame_width, frame_height) -> Tuple[int, int]:
"""Convert bounding box to grid coordinates."""
x, y, w, h = bbox
center_x = x + w / 2
center_y = y + h / 2
grid_x = int((center_x / frame_width) * self.grid_width)
grid_y = int((center_y / frame_height) * self.grid_height)
# Clamp to grid bounds
grid_x = max(0, min(self.grid_width - 1, grid_x))
grid_y = max(0, min(self.grid_height - 1, grid_y))
return (grid_x, grid_y)
def save_detection(self, frame, label, bbox, confidence, frame_width, frame_height, context=None):
"""
Save object detection to database and cache.
Args:
frame: Video frame
label: Object label
bbox: Bounding box (x, y, w, h)
confidence: Detection confidence
frame_width, frame_height: Frame dimensions
context: Semantic context (e.g., "on table")
"""
try:
# Convert to grid coordinates
grid_x, grid_y = self.bbox_to_grid(bbox, frame_width, frame_height)
# Save compressed image
image_path, image_hash = self.compress_and_save_image(frame, bbox)
# Insert into objects table
cursor = self.conn.cursor()
try:
cursor.execute("""
INSERT INTO objects (label, context, grid_x, grid_y, confidence, image_path, image_hash,
bbox_x, bbox_y, bbox_w, bbox_h)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""", (label, context, grid_x, grid_y, confidence, image_path, image_hash,
int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])))
self.conn.commit()
except sqlite3.IntegrityError:
# Duplicate image hash - already saved
pass
# Update room grid frequency
self._update_room_grid(label, grid_x, grid_y)
print(f"[LEARNING] Saved: {label} at grid({grid_x},{grid_y}) conf={confidence:.2f}")
except Exception as e:
print(f"[ERROR] Failed to save detection: {e}")
def _update_room_grid(self, label, grid_x, grid_y):
"""Update room grid frequency map."""
cursor = self.conn.cursor()
cursor.execute("""
INSERT INTO room_grid (grid_x, grid_y, object_label, frequency, last_seen)
VALUES (?, ?, ?, 1, CURRENT_TIMESTAMP)
ON CONFLICT(grid_x, grid_y, object_label)
DO UPDATE SET
frequency = frequency + 1,
last_seen = CURRENT_TIMESTAMP
""", (grid_x, grid_y, label))
self.conn.commit()
def get_likely_location(self, label) -> Optional[Tuple[int, int, float]]:
"""
Get most likely grid location for an object based on history.
Args:
label: Object label to search for
Returns:
(grid_x, grid_y, probability) or None if no history
"""
cursor = self.conn.cursor()
canonical_label = self._canonicalize_label(label)
# Get total frequency for this object
cursor.execute("""
SELECT SUM(frequency)
FROM room_grid
WHERE lower(object_label) = lower(?)
OR lower(object_label) LIKE lower(?)
OR lower(object_label) LIKE lower(?)
""", (label, f"%{canonical_label}%", f"%{label.lower()}%"))
total = cursor.fetchone()[0]
if not total or total == 0:
return None
# Get most frequent location
cursor.execute("""
SELECT grid_x, grid_y, frequency, last_seen
FROM room_grid
WHERE lower(object_label) = lower(?)
OR lower(object_label) LIKE lower(?)
OR lower(object_label) LIKE lower(?)
ORDER BY frequency DESC, last_seen DESC
LIMIT 1
""", (label, f"%{canonical_label}%", f"%{label.lower()}%"))
result = cursor.fetchone()
if not result:
return None
grid_x, grid_y, freq, last_seen = result
probability = freq / total
print(f"[LEARNING] Prediction: {label} likely at grid({grid_x},{grid_y}) | prob={probability:.1%}")
return (grid_x, grid_y, probability)
def get_search_region(self, grid_x, grid_y, frame_width, frame_height, expand=1) -> Tuple[int, int, int, int]:
"""
Convert grid coordinates to frame region for targeted search.
Args:
grid_x, grid_y: Grid coordinates
frame_width, frame_height: Frame dimensions
expand: Number of grid cells to expand search region
Returns:
(x, y, w, h) region in frame coordinates
"""
cell_w = frame_width / self.grid_width
cell_h = frame_height / self.grid_height
# Expand by N cells in each direction
x_min = max(0, grid_x - expand) * cell_w
y_min = max(0, grid_y - expand) * cell_h
x_max = min(self.grid_width, grid_x + expand + 1) * cell_w
y_max = min(self.grid_height, grid_y + expand + 1) * cell_h
return (int(x_min), int(y_min), int(x_max - x_min), int(y_max - y_min))
def set_preference(self, key, value):
"""Save user preference."""
cursor = self.conn.cursor()
cursor.execute("""
INSERT INTO user_preferences (key, value, updated_at)
VALUES (?, ?, CURRENT_TIMESTAMP)
ON CONFLICT(key)
DO UPDATE SET value = ?, updated_at = CURRENT_TIMESTAMP
""", (key, value, value))
self.conn.commit()
def get_preference(self, key, default=None):
"""Get user preference."""
cursor = self.conn.cursor()
cursor.execute("SELECT value FROM user_preferences WHERE key = ?", (key,))
result = cursor.fetchone()
return result[0] if result else default
def get_stats(self) -> Dict:
"""Get learning statistics."""
cursor = self.conn.cursor()
cursor.execute("SELECT COUNT(*) FROM objects")
total_detections = cursor.fetchone()[0]
cursor.execute("SELECT COUNT(DISTINCT label) FROM objects")
unique_labels = cursor.fetchone()[0]
cursor.execute("SELECT COUNT(*) FROM room_grid")
grid_cells = cursor.fetchone()[0]
# Calculate cache size
cache_size = sum(os.path.getsize(os.path.join(self.image_cache_dir, f))
for f in os.listdir(self.image_cache_dir)
if f.endswith('.jpg'))
return {
"total_detections": total_detections,
"unique_labels": unique_labels,
"grid_cells_mapped": grid_cells,
"cache_size_mb": cache_size / (1024 * 1024),
"cached_images": len([f for f in os.listdir(self.image_cache_dir) if f.endswith('.jpg')])
}
def close(self):
"""Close database connection."""
if self.conn:
self.conn.close()
print("[LEARNING] Learning database closed")
# === MEMORY RECALL FEATURE ===
def recall_object(self, label) -> Optional[Dict]:
"""
Recall where an object was last seen.
Args:
label: Object label to search for
Returns:
Dict with location info or None if never seen
"""
cursor = self.conn.cursor()
canonical_label = self._canonicalize_label(label)
cursor.execute("""
SELECT label, context, grid_x, grid_y, timestamp, confidence
FROM objects
WHERE lower(label) = lower(?)
OR lower(label) LIKE lower(?)
OR lower(label) LIKE lower(?)
ORDER BY timestamp DESC
LIMIT 1
""", (label, f"%{canonical_label}%", f"%{label.lower()}%"))
result = cursor.fetchone()
if not result:
return None
label, context, grid_x, grid_y, timestamp, confidence = result
# Calculate time ago
from datetime import datetime
then = datetime.fromisoformat(timestamp)
now = datetime.now()
delta = now - then
# Human-readable time
if delta.total_seconds() < 60:
time_ago = f"{int(delta.total_seconds())} seconds ago"
elif delta.total_seconds() < 3600:
time_ago = f"{int(delta.total_seconds() / 60)} minutes ago"
elif delta.total_seconds() < 86400:
time_ago = f"{int(delta.total_seconds() / 3600)} hours ago"
else:
time_ago = f"{int(delta.total_seconds() / 86400)} days ago"
# Get location description
location_desc = self._grid_to_description(grid_x, grid_y)
# Append context to location description if available
if context:
location_desc += f" ({context})"
return {
"label": label,
"context": context,
"grid_x": grid_x,
"grid_y": grid_y,
"location_desc": location_desc,
"time_ago": time_ago,
"timestamp": timestamp,
"confidence": confidence
}
def _grid_to_description(self, grid_x, grid_y) -> str:
"""Convert grid coordinates to human-readable location."""
# Horizontal position
if grid_x < self.grid_width / 3:
h_pos = "left side"
elif grid_x > 2 * self.grid_width / 3:
h_pos = "right side"
else:
h_pos = "center"
# Vertical position
if grid_y < self.grid_height / 3:
v_pos = "top"
elif grid_y > 2 * self.grid_height / 3:
v_pos = "bottom"
else:
v_pos = "middle"
# Combine
if h_pos == "center" and v_pos == "middle":
return "center of view"
elif h_pos == "center":
return f"{v_pos} center"
elif v_pos == "middle":
return f"{h_pos}"
else:
return f"{v_pos} {h_pos}"