forked from google/neuroglancer
-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathinteractive_inference.py
More file actions
executable file
·117 lines (101 loc) · 4.43 KB
/
interactive_inference.py
File metadata and controls
executable file
·117 lines (101 loc) · 4.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#!/usr/bin/env python
"""Example of displaying interactive image-to-image "inference" results.
shift+mousedown0 triggers the inference result to be computed for the patch
centered around the mouse position, and then displayed in neuroglancer.
In this example, the inference result is actually just a distance transform
computed from the ground truth segmentation, but in actual use the inference
result may be computed using SciPy, Tensorflow, PyTorch, etc.
The cloudvolume library (https://github.com/seung-lab/cloud-volume) is used to
retrieve patches of the ground truth volume.
The zarr library is used to represent the sparse in-memory array containing the
computed inference results that are displayed in neuroglancer.
"""
import argparse
import time
import neuroglancer
import cloudvolume
import zarr
import numpy as np
import scipy.ndimage
class InteractiveInference(object):
    """Neuroglancer viewer that computes a per-patch "inference" on demand.

    Shift+mousedown0 is bound to the ``'inference'`` action, which reads a cube
    of ground-truth segmentation around the mouse position, turns it into a
    clipped distance-to-boundary map, and writes it into an in-memory zarr
    array that is displayed as the ``'inference'`` image layer.
    """

    # Edge length (in voxels) of the cube computed for each click.
    PATCH_SIZE = 128
    # Distances at or above this value (in voxels, as returned by the EDT)
    # map to the top of the colormap.
    MAX_DISTANCE = 5.0

    def __init__(self):
        viewer = self.viewer = neuroglancer.Viewer()
        viewer.actions.add('inference', self._do_inference)
        self.gt_vol = cloudvolume.CloudVolume(
            'https://storage.googleapis.com/neuroglancer-public-data/flyem_fib-25/ground_truth',
            mip=0,
            bounded=True,
            progress=False,
            provenance={})
        self.dimensions = neuroglancer.CoordinateSpace(
            names=['x', 'y', 'z'],
            units='nm',
            scales=self.gt_vol.resolution,
        )
        # Sparse result array sized to the upper corner of the ground-truth
        # bounds; entries stay 0 ("not computed") until a patch is written.
        self.inf_results = zarr.zeros(
            self.gt_vol.bounds.to_list()[3:], chunks=(64, 64, 64), dtype=np.uint8)
        self.inf_volume = neuroglancer.LocalVolume(
            data=self.inf_results, dimensions=self.dimensions)

        with viewer.config_state.txn() as s:
            s.input_event_bindings.data_view['shift+mousedown0'] = 'inference'

        with viewer.txn() as s:
            s.layers['image'] = neuroglancer.ImageLayer(
                source='precomputed://gs://neuroglancer-public-data/flyem_fib-25/image',
            )
            s.layers['ground_truth'] = neuroglancer.SegmentationLayer(
                source='precomputed://gs://neuroglancer-public-data/flyem_fib-25/ground_truth',
            )
            s.layers['ground_truth'].visible = False
            # The shader renders value 0 as fully transparent, so voxels that
            # have not been computed yet do not obscure the image layer.
            s.layers['inference'] = neuroglancer.ImageLayer(
                source=self.inf_volume,
                shader='''
void main() {
float v = toNormalized(getDataValue(0));
vec4 rgba = vec4(0,0,0,0);
if (v != 0.0) {
rgba = vec4(colormapJet(v), 1.0);
}
emitRGBA(rgba);
}
''',
            )

    @staticmethod
    def _patch_slices(center, patch_size, lower, upper):
        """Return slices for a cube of edge ``patch_size`` centred on ``center``.

        Args:
            center: per-axis position (e.g. mouse voxel coordinates).
            patch_size: cube edge length in voxels.
            lower: per-axis inclusive lower bound of the volume.
            upper: per-axis exclusive upper bound of the volume.

        Returns:
            A tuple of ``slice`` objects clipped to ``[lower, upper)``, or
            ``None`` if the clipped cube is empty along any axis.
        """
        # astype(int64) truncates toward zero, like int() did originally.
        start = (np.asarray(center, dtype=np.float64) - patch_size // 2).astype(np.int64)
        end = start + patch_size
        lower = np.asarray(lower, dtype=np.int64)
        upper = np.asarray(upper, dtype=np.int64)
        # Clip so reads stay inside the bounded CloudVolume and writes never
        # use negative (wrap-around) indices into the result array.
        start = np.clip(start, lower, upper)
        end = np.clip(end, lower, upper)
        if np.any(end <= start):
            return None
        return tuple(slice(int(s), int(e)) for s, e in zip(start, end))

    @staticmethod
    def _boundary_mask(gt_data):
        """Return a boolean mask of boundary voxels in a label array.

        A voxel is a boundary voxel if its label is 0, or if it differs from
        a face-adjacent neighbour along any axis (both voxels of a differing
        pair are marked).
        """
        mask = gt_data == 0
        ndim = gt_data.ndim
        for axis in range(ndim):
            lo = [slice(None)] * ndim
            hi = [slice(None)] * ndim
            lo[axis] = slice(None, -1)
            hi[axis] = slice(1, None)
            lo, hi = tuple(lo), tuple(hi)
            differs = gt_data[lo] != gt_data[hi]
            mask[lo] |= differs
            mask[hi] |= differs
        return mask

    @staticmethod
    def _encode_distance(dist, max_distance=5.0):
        """Map distances to uint8 values in [1, 255].

        Distances are clipped to ``max_distance`` and scaled to [0, 254], then
        offset by 1 so that 0 stays reserved for "not computed" (rendered
        transparent by the layer shader).
        """
        scaled = np.minimum(dist, max_distance) / max_distance * 254
        return 1 + scaled.astype(np.uint8)

    def _do_inference(self, action_state):
        """Handle the 'inference' action for the patch under the mouse.

        Does nothing if the mouse is not over the data or the clipped patch is
        empty. Otherwise it writes the encoded distance map into
        ``self.inf_results`` and invalidates the displayed volume.
        """
        pos = action_state.mouse_voxel_coordinates
        if pos is None:
            return
        bounds = self.gt_vol.bounds.to_list()
        slice_expr = self._patch_slices(pos, self.PATCH_SIZE, bounds[:3], bounds[3:])
        if slice_expr is None:
            return
        # Drop the trailing channel axis of the CloudVolume cutout.
        gt_data = self.gt_vol[slice_expr][..., 0]
        boundary_mask = self._boundary_mask(gt_data)
        dist_transform = scipy.ndimage.distance_transform_edt(~boundary_mask)
        self.inf_results[slice_expr] = self._encode_distance(
            dist_transform, self.MAX_DISTANCE)
        self.inf_volume.invalidate()
if __name__ == '__main__':
    # Command-line entry point: configure the neuroglancer web server, create
    # the viewer, print its URL, then keep the process alive so the server
    # keeps serving requests.
    ap = argparse.ArgumentParser()
    ap.add_argument(
        '-a',
        '--bind-address',
        help='Bind address for Python web server. Use 127.0.0.1 (the default) to restrict access '
        'to browsers running on the local machine, use 0.0.0.0 to permit access from remote browsers.')
    ap.add_argument(
        '--static-content-url', help='Obtain the Neuroglancer client code from the specified URL.')
    args = ap.parse_args()
    if args.bind_address:
        neuroglancer.set_server_bind_address(args.bind_address)
    if args.static_content_url:
        neuroglancer.set_static_content_source(url=args.static_content_url)

    inf = InteractiveInference()
    print(inf.viewer)
    # Block forever; the viewer's server and action callbacks do the work.
    while True:
        time.sleep(1000)