-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtest_nidsnet.py
More file actions
55 lines (42 loc) · 1.77 KB
/
test_nidsnet.py
File metadata and controls
55 lines (42 loc) · 1.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import numpy as np
import torch
from fetch_grasp.utils import PROJECT_ROOT, OBJECT_CLASSES
from fetch_grasp.utils.commons import (
read_rgb_image,
write_rgb_image,
write_mask_image,
draw_annotated_image,
extract_masks_from_labels,
write_data_to_yaml,
)
def run_nidsnet_once(nids_mod, image_rgb):
    """Run one NIDS-Net inference pass on an RGB image and annotate the result.

    Args:
        nids_mod: Model object exposing ``step(image_rgb)``. Its return value
            must unpack into exactly two items; the second is expected to be a
            tensor-like label map supporting ``.cpu().numpy()``.
        image_rgb: Input image, passed unchanged to ``nids_mod.step`` and to
            ``draw_annotated_image``.

    Returns:
        Tuple ``(masks, obj_names, labels, labels_vis)``:
        the masks produced by ``extract_masks_from_labels``, the class names
        (looked up in ``OBJECT_CLASSES``) for every non-zero label id in
        ascending id order, the label map as a ``uint8`` NumPy array, and the
        image produced by ``draw_annotated_image``.
    """
    _unused, label_tensor = nids_mod.step(image_rgb)

    # NOTE(review): casting to uint8 wraps any label id above 255; this assumes
    # fewer than 256 labels — confirm against OBJECT_CLASSES.
    label_map = label_tensor.cpu().numpy().astype(np.uint8)

    mask_list = extract_masks_from_labels(label_map)

    # Label id 0 is skipped (presumably background). np.unique yields the ids
    # in ascending order.
    present_ids = [int(label_id) for label_id in np.unique(label_map) if label_id != 0]
    # NOTE(review): assumes extract_masks_from_labels returns masks in this same
    # ascending-id order so masks and names line up — verify.
    class_names = [OBJECT_CLASSES[label_id] for label_id in present_ids]

    vis_image = draw_annotated_image(image_rgb, masks=mask_list, labels=class_names)
    return mask_list, class_names, label_map, vis_image
def initialize_nidsnet():
    """Build a NIDS-Net model from the template features bundled with its wrapper.

    The wrapper's ``feat_dict["features"]`` is converted to a float tensor and
    reshaped to ``(-1, 42, 1024)``. The tensor is moved to the current CUDA
    device and passed to ``NIDS`` as template features. The adapter is enabled,
    with weights loaded from ``weight_adapter_path``.

    Returns:
        The constructed ``NIDS`` model instance.
    """
    # The wrapper is imported inside the function, so it is loaded only when a
    # model is actually built.
    from fetch_grasp.wrappers.nidsnet import NIDS, feat_dict, weight_adapter_path

    raw_features = feat_dict["features"]
    template_feats = torch.Tensor(raw_features)
    # NOTE(review): 42 and 1024 are presumably (templates per object, feature
    # dim); the leading -1 would then be the object count — TODO confirm.
    template_feats = template_feats.view(-1, 42, 1024)
    template_feats = template_feats.cuda()

    return NIDS(
        template_features=template_feats,
        use_adapter=True,
        adapter_path=weight_adapter_path,
    )
def main():
    """Run NIDS-Net once on the saved demo image and write its outputs.

    Reads ``cam_K.txt`` and ``color_image.png`` from ``PROJECT_ROOT/demo/ros``.
    Writes the following files into the same directory:

    * ``mask_image_nidsnet.png`` — the ``uint8`` label map,
    * ``mask_image_nidsnet_vis.png`` — the annotated visualization,
    * ``nidsnet_class_names.yaml`` — a ``{label_id: class_name}`` mapping for
      every non-zero label id present in the prediction.
    """
    # Directories / files (all built with pathlib for consistency).
    save_dir = PROJECT_ROOT / "demo/ros"
    cam_K_file = save_dir / "cam_K.txt"
    color_file = save_dir / "color_image.png"

    # Initialize NIDS-Net
    nids_model = initialize_nidsnet()

    # NOTE(review): cam_K is not used below. It is still loaded so the script
    # fails early when cam_K.txt is missing or malformed; drop it if that check
    # is not wanted.
    cam_K = np.loadtxt(cam_K_file, dtype=np.float32)  # noqa: F841
    rgb = read_rgb_image(color_file)

    # Run NIDS-Net inference
    _masks, _obj_names, labels, labels_vis = run_nidsnet_once(nids_model, rgb)

    # Label id 0 is excluded (presumably background).
    class_names = {int(i): OBJECT_CLASSES[int(i)] for i in np.unique(labels) if i != 0}

    # Writers receive plain string paths, matching the original f-string paths.
    write_mask_image(str(save_dir / "mask_image_nidsnet.png"), labels)
    write_rgb_image(str(save_dir / "mask_image_nidsnet_vis.png"), labels_vis)
    write_data_to_yaml(str(save_dir / "nidsnet_class_names.yaml"), class_names)


if __name__ == "__main__":
    main()