Skip to content

Commit 2485390

Browse files
committed
fix: detect snemi labels in check_files
1 parent 18f1af2 commit 2485390

2 files changed

Lines changed: 91 additions & 15 deletions

File tree

server_api/main.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,35 @@ def save_upload_to_tempfile(upload: UploadFile) -> pathlib.Path:
295295
return temp_path
296296

297297

298+
def _is_probable_label_volume(image_array) -> bool:
299+
import numpy as np
300+
301+
if not np.issubdtype(image_array.dtype, np.integer):
302+
return False
303+
304+
unique_values = np.unique(image_array)
305+
num_unique = len(unique_values)
306+
if num_unique == 0:
307+
return False
308+
309+
if num_unique == 2 and np.array_equal(unique_values, np.array([0, 1])):
310+
return True
311+
if num_unique == 2 and np.array_equal(unique_values, np.array([0, 255])):
312+
return True
313+
if num_unique < 50:
314+
return True
315+
316+
max_value = int(unique_values[-1])
317+
if max_value > 255 and num_unique <= 4096:
318+
return True
319+
320+
dtype_info = np.iinfo(image_array.dtype)
321+
if dtype_info.max > 255 and num_unique <= 1024:
322+
return True
323+
324+
return False
325+
326+
298327
@app.post("/neuroglancer")
299328
async def neuroglancer(req: Request):
300329
import neuroglancer
@@ -491,23 +520,9 @@ async def check_files(req: Request):
491520
print(f"Failed to read file: {e}")
492521
return {"error": f"Failed to open image: {str(e)}"}
493522

494-
# Heuristic for label detection:
495-
# 1. Must be integer type
496-
# 2. Low number of unique values (e.g. < 50) relative to size
497-
# 3. Or explicit binary (0, 255) or (0, 1)
498-
499523
unique_values = np.unique(image_array)
500524
num_unique = len(unique_values)
501-
is_integer = np.issubdtype(image_array.dtype, np.integer)
502-
503-
is_label = False
504-
if is_integer:
505-
if num_unique < 50:
506-
is_label = True
507-
elif np.array_equal(unique_values, np.array([0, 255])) or np.array_equal(
508-
unique_values, np.array([0, 1])
509-
):
510-
is_label = True
525+
is_label = _is_probable_label_volume(image_array)
511526

512527
if is_label:
513528
print(

tests/test_check_files_route.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import pathlib
2+
import unittest
3+
4+
from fastapi.testclient import TestClient
5+
6+
from server_api.main import app as server_api_app
7+
8+
9+
SNEMI_ROOT = pathlib.Path("/Users/adamg/seg.bio/testing_data/snemi")
10+
TEST_INPUT_PATH = SNEMI_ROOT / "image" / "test-input.tif"
11+
TRAIN_LABELS_PATH = SNEMI_ROOT / "seg" / "train-labels.tif"
12+
13+
14+
@unittest.skipUnless(TEST_INPUT_PATH.exists(), "SNEMI fixture is not available")
15+
class CheckFilesRouteTests(unittest.TestCase):
16+
def setUp(self):
17+
self.client = TestClient(server_api_app)
18+
19+
def test_check_files_marks_snemi_image_as_not_label(self):
20+
response = self.client.post(
21+
"/check_files",
22+
json={
23+
"filePath": str(TEST_INPUT_PATH),
24+
"folderPath": str(TEST_INPUT_PATH.parent),
25+
"name": TEST_INPUT_PATH.name,
26+
},
27+
)
28+
29+
self.assertEqual(response.status_code, 200)
30+
self.assertEqual(response.json(), {"label": False})
31+
32+
def test_check_files_marks_snemi_labels_as_label(self):
33+
response = self.client.post(
34+
"/check_files",
35+
json={
36+
"filePath": str(TRAIN_LABELS_PATH),
37+
"folderPath": str(TRAIN_LABELS_PATH.parent),
38+
"name": TRAIN_LABELS_PATH.name,
39+
},
40+
)
41+
42+
self.assertEqual(response.status_code, 200)
43+
self.assertEqual(response.json(), {"label": True})
44+
45+
def test_check_files_returns_error_payload_for_missing_path(self):
46+
missing_path = SNEMI_ROOT / "image" / "does-not-exist.tif"
47+
response = self.client.post(
48+
"/check_files",
49+
json={
50+
"filePath": str(missing_path),
51+
"folderPath": str(missing_path.parent),
52+
"name": missing_path.name,
53+
},
54+
)
55+
56+
self.assertEqual(response.status_code, 200)
57+
self.assertIn("error", response.json())
58+
59+
60+
if __name__ == "__main__":
61+
unittest.main()

0 commit comments

Comments
 (0)