Skip to content

Commit f8ec976

Browse files
rohan-pandeyy authored and rahulharpal1603 committed
fix: synchronize session tensor access and model deletion lifecycle
- Snapshot tensor names to avoid close() invalidating session.run() inputs.
- Lock model deletion to prevent delete-while-in-use races.
1 parent cd7e968 commit f8ec976

6 files changed

Lines changed: 89 additions & 37 deletions

File tree

backend/app/models/FaceNet.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,17 @@ def __init__(self, model_path):
2525
self.output_tensor_name: str | None = None
2626
self._lock = threading.Lock()
2727

28-
def get_session(self) -> onnxruntime.InferenceSession:
28+
def get_session(self) -> tuple[onnxruntime.InferenceSession, str, str]:
2929
session = self._session
3030
if session is not None:
31-
return session
31+
input_name = self.input_tensor_name
32+
output_name = self.output_tensor_name
33+
if input_name is None or output_name is None:
34+
raise RuntimeError(
35+
f"Model session for '{self._model_key}' was closed while "
36+
"get_session() was executing."
37+
)
38+
return session, input_name, output_name
3239

3340
with self._lock:
3441
if self._session is None:
@@ -50,19 +57,31 @@ def get_session(self) -> onnxruntime.InferenceSession:
5057
self.model_path, providers=ONNX_util_get_execution_providers()
5158
)
5259
if self._model_key is not None and not self._session_registered:
53-
mark_model_session_active(self._model_key)
60+
try:
61+
mark_model_session_active(self._model_key)
62+
except RuntimeError:
63+
self._session = None
64+
raise
5465
self._session_registered = True
5566
self.input_tensor_name = self._session.get_inputs()[0].name
5667
self.output_tensor_name = self._session.get_outputs()[0].name
5768

5869
session = self._session
70+
input_name = self.input_tensor_name
71+
output_name = self.output_tensor_name
72+
73+
if session is None or input_name is None or output_name is None:
74+
raise RuntimeError(
75+
f"Model session for '{self._model_key}' was closed while "
76+
"get_session() was executing."
77+
)
5978

60-
return session
79+
return session, input_name, output_name
6180

6281
def get_embedding(self, preprocessed_image):
63-
session = self.get_session()
82+
session, input_tensor_name, output_tensor_name = self.get_session()
6483
result = session.run(
65-
[self.output_tensor_name], {self.input_tensor_name: preprocessed_image}
84+
[output_tensor_name], {input_tensor_name: preprocessed_image}
6685
)[0]
6786
embedding = result[0]
6887
return FaceNet_util_normalize_embedding(embedding)

backend/app/models/YOLO.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def __init__(self, path, conf_threshold=0.7, iou_threshold=0.5):
3030
self.iou_threshold = iou_threshold
3131
self._session = None
3232
import threading
33+
3334
self._lock = threading.Lock()
3435

3536
def get_session(self):
@@ -47,7 +48,9 @@ def get_session(self):
4748
if spec["filename"] in self.model_path:
4849
model_key = key
4950
break
50-
model_name = model_key if model_key else os.path.basename(self.model_path)
51+
model_name = (
52+
model_key if model_key else os.path.basename(self.model_path)
53+
)
5154
raise RuntimeError(
5255
f"Model '{model_name}' is not installed. "
5356
"Please install it from Settings → AI Models before using this feature."
@@ -57,7 +60,11 @@ def get_session(self):
5760
self.model_path, providers=ONNX_util_get_execution_providers()
5861
)
5962
if self._model_key is not None and not self._session_registered:
60-
mark_model_session_active(self._model_key)
63+
try:
64+
mark_model_session_active(self._model_key)
65+
except RuntimeError:
66+
self._session = None
67+
raise
6168
self._session_registered = True
6269
# Initialize model info once session is created
6370
self.get_input_details()
@@ -95,9 +102,7 @@ def inference(self, input_tensor, session=None):
95102
start = time.perf_counter()
96103
if session is None:
97104
session = self.get_session()
98-
outputs = session.run(
99-
self.output_names, {self.input_names[0]: input_tensor}
100-
)
105+
outputs = session.run(self.output_names, {self.input_names[0]: input_tensor})
101106
logger.debug("Inference completed in %.4fs", time.perf_counter() - start)
102107
return outputs
103108

backend/app/models/model_registry.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
from typing import TypedDict, Literal
4-
import json
54
import os
65
import sys
76
from platformdirs import user_data_dir
@@ -79,10 +78,10 @@ class ModelSpec(TypedDict):
7978
}
8079

8180
TIER_MODELS: dict[str, list[str]] = {
82-
"nano": ["yolo_nano", "yolo_nano_face"],
83-
"small": ["yolo_small", "yolo_small_face"],
81+
"nano": ["yolo_nano", "yolo_nano_face"],
82+
"small": ["yolo_small", "yolo_small_face"],
8483
"medium": ["yolo_medium", "yolo_medium_face"],
85-
"required": ["facenet"], # Required model; not user-selectable
84+
"required": ["facenet"], # Required model; not user-selectable
8685
}
8786

8887
USER_DATA_MODELS = os.path.join(user_data_dir("PictoPy"), "models")
@@ -91,7 +90,7 @@ class ModelSpec(TypedDict):
9190

9291
def ensure_model_exports_directory() -> None:
9392
"""Create the active model exports directory if it does not exist."""
94-
if getattr(sys, 'frozen', False):
93+
if getattr(sys, "frozen", False):
9594
os.makedirs(USER_DATA_MODELS, exist_ok=True)
9695
else:
9796
os.makedirs(LOCAL_ONNX_EXPORTS, exist_ok=True)
@@ -100,11 +99,11 @@ def ensure_model_exports_directory() -> None:
10099
def get_model_path(key: str) -> str:
101100
filename = MODEL_REGISTRY[key]["filename"]
102101
ensure_model_exports_directory()
103-
102+
104103
# In production (compiled by PyInstaller), use the platform-appropriate user data directory.
105-
if getattr(sys, 'frozen', False):
104+
if getattr(sys, "frozen", False):
106105
return os.path.normpath(os.path.join(USER_DATA_MODELS, filename))
107-
106+
108107
# In development, strictly use the local repo folder
109108
return os.path.normpath(os.path.join(LOCAL_ONNX_EXPORTS, filename))
110109

@@ -115,4 +114,3 @@ def get_model_key_from_path(model_path: str) -> str | None:
115114
if spec["filename"].lower() == target_filename:
116115
return key
117116
return None
118-

backend/app/models/session_registry.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,16 @@
33
import threading
44

55
_active_sessions: dict[str, int] = {}
6+
_models_pending_deletion: set[str] = set()
67
_registry_lock = threading.Lock()
78

89

910
def mark_model_session_active(model_key: str) -> None:
1011
with _registry_lock:
12+
if model_key in _models_pending_deletion:
13+
raise RuntimeError(
14+
f"Model '{model_key}' is being deleted; cannot start a new session."
15+
)
1116
_active_sessions[model_key] = _active_sessions.get(model_key, 0) + 1
1217

1318

@@ -23,3 +28,17 @@ def mark_model_session_inactive(model_key: str) -> None:
2328
def get_active_session_count(model_key: str) -> int:
2429
with _registry_lock:
2530
return _active_sessions.get(model_key, 0)
31+
32+
33+
def try_mark_model_for_deletion(model_key: str) -> int | None:
34+
with _registry_lock:
35+
count = _active_sessions.get(model_key, 0)
36+
if count > 0:
37+
return count
38+
_models_pending_deletion.add(model_key)
39+
return None
40+
41+
42+
def release_model_deletion_mark(model_key: str) -> None:
43+
with _registry_lock:
44+
_models_pending_deletion.discard(model_key)

backend/app/routes/models.py

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,17 @@
22
import json
33
import uuid
44
import asyncio
5-
from typing import Dict, Optional
5+
from typing import Dict
66
from dataclasses import dataclass, field
77
from datetime import datetime, timezone
88
from fastapi import APIRouter, HTTPException, status
99
from pydantic import BaseModel
1010
from fastapi.responses import StreamingResponse
1111
from app.models.model_registry import MODEL_REGISTRY, TIER_MODELS, get_model_path
12-
from app.models.session_registry import get_active_session_count
12+
from app.models.session_registry import (
13+
try_mark_model_for_deletion,
14+
release_model_deletion_mark,
15+
)
1316
from app.utils.hardware_detect import get_hardware_info
1417
from app.utils.model_downloader import ensure_model
1518
import logging
@@ -113,8 +116,10 @@ async def delete_model(model_key: str):
113116
)
114117

115118
path = get_model_path(model_key)
116-
active_session_count = get_active_session_count(model_key)
117-
if active_session_count > 0:
119+
120+
# Check no sessions are active and reserve the model for deletion.
121+
active_session_count = try_mark_model_for_deletion(model_key)
122+
if active_session_count is not None:
118123
raise HTTPException(
119124
status_code=status.HTTP_409_CONFLICT,
120125
detail=(
@@ -123,21 +128,27 @@ async def delete_model(model_key: str):
123128
),
124129
)
125130

126-
if os.path.exists(path):
127-
try:
128-
await asyncio.to_thread(os.remove, path)
131+
try:
132+
if os.path.exists(path):
133+
try:
134+
await asyncio.to_thread(os.remove, path)
135+
return {
136+
"success": True,
137+
"message": f"Model {model_key} deleted successfully.",
138+
}
139+
except Exception as e:
140+
logger.error(f"Failed to delete model {model_key}: {e}")
141+
raise HTTPException(
142+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
143+
detail=f"Failed to delete model: {str(e)}",
144+
)
145+
else:
129146
return {
130147
"success": True,
131-
"message": f"Model {model_key} deleted successfully.",
148+
"message": f"Model {model_key} already not present.",
132149
}
133-
except Exception as e:
134-
logger.error(f"Failed to delete model {model_key}: {e}")
135-
raise HTTPException(
136-
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
137-
detail=f"Failed to delete model: {str(e)}",
138-
)
139-
else:
140-
return {"success": True, "message": f"Model {model_key} already not present."}
150+
finally:
151+
release_model_deletion_mark(model_key)
141152

142153

143154
@router.post("/setup")

backend/app/utils/model_bootstrap.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,4 @@ async def _ensure_ai_tagging_models_async() -> None:
3131

3232

3333
def ensure_ai_tagging_models() -> None:
34-
asyncio.run(_ensure_ai_tagging_models_async())
34+
asyncio.run(_ensure_ai_tagging_models_async())

0 commit comments

Comments
 (0)