Skip to content

Commit 9aa98a3

Browse files
No public description
PiperOrigin-RevId: 875366955
1 parent b61e63a commit 9aa98a3

7 files changed

Lines changed: 357 additions & 18 deletions

File tree

Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
# Copyright 2026 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Copyright 2026 The TensorFlow Authors. All Rights Reserved.
16+
#
17+
# Licensed under the Apache License, Version 2.0 (the "License");
18+
# you may not use this file except in compliance with the License.
19+
# You may obtain a copy of the License at
20+
#
21+
# http://www.apache.org/licenses/LICENSE-2.0
22+
#
23+
# Unless required by applicable law or agreed to in writing, software
24+
# distributed under the License is distributed on an "AS IS" BASIS,
25+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26+
# See the License for the specific language governing permissions and
27+
# limitations under the License.
28+
29+
"""Extract properties from each object mask and detect its color."""
30+
31+
from typing import List, Tuple, TypeVar
32+
33+
import numpy as np
34+
import numpy.typing as npt
35+
from skimage import color as skimage_color
36+
from sklearn import cluster as sklearn_cluster
37+
from sklearn import neighbors as sklearn_neighbors
38+
import webcolors
39+
40+
DType = TypeVar('DType', bound=np.generic)
41+
# Color representation as numpy array of 3 elements of float64
42+
# Those values could be in different scales like
43+
# RGB ([0.0,255.0], [0.0,255.0], [0.0 to 255.0])
44+
# LAB ([0.0,100], [-128,127], [-128,127])
45+
# NColor = Annotated[npt.NDArray[DType], Literal[3]][np.float64]
46+
NColor = np.ndarray
47+
48+
49+
PROPERTIES = [
50+
'area',
51+
'bbox',
52+
'convex_area',
53+
'bbox_area',
54+
'major_axis_length',
55+
'minor_axis_length',
56+
'eccentricity',
57+
'centroid',
58+
]
59+
60+
GENERIC_COLORS = [
61+
('black', '#000000'),
62+
('green', '#008000'),
63+
('green', '#00ff00'), # lime
64+
('green', '#3cb371'), # mediumseagreen
65+
('green', '#2E8B57'), # seagreen
66+
('green', '#8FBC8B'), # darkseagreen
67+
('green', '#adff2f'), # olive
68+
('green', '#008080'), # Teal
69+
('green', '#808000'),
70+
('blue', '#000080'), # navy
71+
('blue', '#00008b'), # darkblue
72+
('blue', '#4682b4'), # steelblue
73+
('blue', '#40E0D0'), # turquoise
74+
('blue', '#00FFFF'), # cyan
75+
('blue', '#00ffff'), # aqua
76+
('blue', '#6495ED'), # cornflowerBlue
77+
('blue', '#4169E1'), # royalBlue
78+
('blue', '#87CEFA'), # lightSkyBlue
79+
('blue', '#4682B4'), # steelBlue
80+
('blue', '#B0C4DE'), # lightSteelBlue
81+
('blue', '#87CEEB'), # skyblue
82+
('blue', '#0000CD'), # mediumBlue
83+
('blue', '#0000ff'),
84+
('purple', '#800080'),
85+
('purple', '#9370db'), # mediumpurple
86+
('purple', '#8B008B'), # darkMagenta
87+
('purple', '#4B0082'), # indigo
88+
('red', '#ff0000'),
89+
('red', '#B22222'), # fireBrick
90+
('red', '#DC143C'), # fireBrick
91+
('red', '#8B0000'), # crimson
92+
('red', '#CD5C5C'), # indianred
93+
('red', '#F08080'), # lightCoral
94+
('red', '#FA8072'), # salmon
95+
('red', '#E9967A'), # darkSalmon
96+
('red', '#FFA07A'), # lightSalmon
97+
('gray', '#c0c0c0'), # silver,
98+
('gray', '#a9a9a9'), # +darkgray
99+
('gray', '#708090'), # +slategray
100+
('blue', '#778899'), # +lightslategray
101+
('white', '#ffffff'),
102+
('white', '#F5F5DC'), # beige
103+
('white', '#FFFAFA'), # snow
104+
('white', '#F0F8FF'), # aliceBlue
105+
('white', '#FFE4E1'), # mistyRose
106+
('yellow', '#ffff00'),
107+
('yellow', '#ffffe0'), # lightyellow
108+
('yellow', '#8B8000'), # darkyellow,
109+
('orange', '#ffa500'),
110+
('orange', '#ff8c00'), # darkorange
111+
('pink', '#ffc0cb'),
112+
('pink', '#ff00ff'), # fuchsia
113+
('pink', '#C71585'), # mediumVioletRed
114+
('pink', '#DB7093'), # paleVioletRed
115+
('pink', '#FFB6C1'), # lightPink
116+
('pink', '#FF69B4'), # hotPink
117+
('pink', '#FF1493'), # deepPink
118+
('pink', '#BC8F8F'), # rosybrown
119+
('brown', '#a52a2a'),
120+
('brown', '#8b4513'), # saddlebrown
121+
('brown', '#f4a460'), # sandybrown
122+
('brown', '#800000'), # maroon
123+
]
124+
125+
126+
def find_dominant_color(
127+
image: np.ndarray, black_threshold: int = 50
128+
) -> Tuple[int, int, int]:
129+
"""Determines the dominant color in a given image.
130+
131+
Args:
132+
image: An array representation of the image.
133+
black_threshold: The intensity threshold below which pixels are considered
134+
'black' or near-black.
135+
136+
Returns:
137+
The dominant RGB color in the format (R, G, B).
138+
"""
139+
pixels = image.reshape(-1, 3)
140+
141+
# Filter out black pixels based on the threshold
142+
non_black_pixels = pixels[(pixels > black_threshold).any(axis=1)]
143+
144+
if non_black_pixels.size:
145+
kmeans = sklearn_cluster.KMeans(
146+
n_clusters=1, n_init=10, random_state=0
147+
).fit(non_black_pixels)
148+
dominant_color = kmeans.cluster_centers_[0].astype(int)
149+
else:
150+
dominant_color = np.array([0, 0, 0], dtype=int)
151+
return tuple(dominant_color)
152+
153+
154+
def rgb_int_to_lab(rgb_int_color: Tuple[int, int, int]) -> NColor:
155+
"""Convert RGB color to LAB color space.
156+
157+
Args:
158+
rgb_int_color: RGB tuple color e.g. (128,128,128)
159+
160+
Returns:
161+
Numpy array of 3 elements that contains LAB color space.
162+
"""
163+
return skimage_color.rgb2lab(
164+
(rgb_int_color[0] / 255, rgb_int_color[1] / 255, rgb_int_color[2] / 255)
165+
)
166+
167+
168+
def color_distance(
169+
a: Tuple[int, int, int], b: Tuple[int, int, int]
170+
) -> np.ndarray:
171+
"""The color distance following the ciede2000 formula.
172+
173+
See: https://en.wikipedia.org/wiki/Color_difference#CIEDE2000
174+
175+
Args:
176+
a: Color a
177+
b: Color b
178+
179+
Returns:
180+
The distance between color a and b
181+
"""
182+
return skimage_color.deltaE_ciede2000(a, b, kC=0.6)
183+
184+
185+
def build_color_lab_list(
186+
generic_colors: List[Tuple[str, str]],
187+
) -> Tuple[npt.NDArray[np.str_], List[NColor]]:
188+
"""Get Simple colors names and lab values.
189+
190+
Args:
191+
generic_colors: List of colors in this format (color_name, rgb_value in hex)
192+
e.g. [ ('black', '#000000'), ('green', '#008000'), ]
193+
194+
Returns:
195+
Numpy array of strings that contains color names
196+
['black', 'green']
197+
List of color lab values in the format of Numpy array of 3 elements
198+
e.g.
199+
[
200+
np.array([0., 0., 0.]),
201+
np.array([ 46.2276577 , -51.69868348, 49.89707556])
202+
]
203+
"""
204+
names: list[str] = []
205+
lab_values = []
206+
for color_name, color_hex in generic_colors:
207+
names.append(color_name)
208+
hex_color = webcolors.hex_to_rgb(color_hex)
209+
lab_values.append(rgb_int_to_lab(hex_color))
210+
color_names = np.array(names)
211+
return color_names, lab_values
212+
213+
214+
def get_generic_color_name(
215+
rgb_colors: List[Tuple[int, int, int]],
216+
generic_colors: List[Tuple[str, str]] | None = None,
217+
) -> List[str]:
218+
"""Retrieves generic names of given RGB colors.
219+
220+
Estimates the closest matching color name.
221+
222+
Args:
223+
rgb_colors: A list of RGB values for which to retrieve the name.
224+
generic_colors: A list of color names and their RGB values in hex.
225+
226+
Returns:
227+
The list of closest color names.
228+
229+
Example: get_generic_color_name([(255, 0, 0), (0,0,0)])
230+
['red','black']
231+
"""
232+
names, rgb_simple_colors = build_color_lab_list(
233+
generic_colors or GENERIC_COLORS
234+
)
235+
tree = sklearn_neighbors.BallTree(rgb_simple_colors, metric=color_distance)
236+
rgb_query = [*map(rgb_int_to_lab, rgb_colors)]
237+
_, index = tree.query(rgb_query)
238+
return [x[0] for x in names[index]]
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# Copyright 2026 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Copyright 2026 The TensorFlow Authors. All Rights Reserved.
16+
#
17+
# Licensed under the Apache License, Version 2.0 (the "License");
18+
# you may not use this file except in compliance with the License.
19+
# You may obtain a copy of the License at
20+
#
21+
# http://www.apache.org/licenses/LICENSE-2.0
22+
#
23+
# Unless required by applicable law or agreed to in writing, software
24+
# distributed under the License is distributed on an "AS IS" BASIS,
25+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26+
# See the License for the specific language governing permissions and
27+
# limitations under the License.
28+
29+
import unittest
30+
import numpy as np
31+
from official.projects.waste_identification_ml.Deploy.detr_cloud_deployment.client import color_extraction
32+
33+
34+
class ColorExtractionTest(unittest.TestCase):
35+
36+
def test_find_dominant_color_with_non_black_pixels(self):
37+
# Create an image with a clear dominant color (Red)
38+
image = np.zeros((10, 10, 3), dtype=np.uint8)
39+
image[0:5, 0:5] = [255, 0, 0] # Top-left quarter is Red
40+
image[5:10, 5:10] = [100, 0, 0] # Bottom-right quarter is dark Red
41+
dominant_color = color_extraction.find_dominant_color(image)
42+
self.assertEqual(dominant_color, (177, 0, 0))
43+
44+
def test_find_dominant_color_with_only_black_pixels(self):
45+
image = np.zeros((10, 10, 3), dtype=np.uint8)
46+
dominant_color = color_extraction.find_dominant_color(
47+
image, black_threshold=50
48+
)
49+
self.assertEqual(dominant_color, (0, 0, 0))
50+
51+
def test_rgb_int_to_lab(self):
52+
rgb = (255, 255, 255)
53+
lab = color_extraction.rgb_int_to_lab(rgb)
54+
# White in LAB is approx (100, 0, 0)
55+
self.assertIsInstance(lab, np.ndarray)
56+
self.assertEqual(lab.shape, (3,))
57+
np.testing.assert_allclose(lab, [100.0, 0.0, 0.0], atol=1e-2)
58+
59+
def test_color_distance(self):
60+
color_a = (100, 0, 0) # LAB
61+
color_b = (100, 0, 0) # LAB
62+
distance = color_extraction.color_distance(color_a, color_b)
63+
self.assertEqual(distance, 0.0)
64+
color_c = (0, 0, 0)
65+
distance_diff = color_extraction.color_distance(color_a, color_c)
66+
self.assertGreater(distance_diff, 0.0)
67+
68+
def test_build_color_lab_list(self):
69+
generic_colors = [('black', '#000000'), ('white', '#ffffff')]
70+
names, lab_values = color_extraction.build_color_lab_list(generic_colors)
71+
np.testing.assert_array_equal(names, ['black', 'white'])
72+
self.assertEqual(len(lab_values), 2)
73+
np.testing.assert_allclose(lab_values[0], [0.0, 0.0, 0.0], atol=1e-2)
74+
np.testing.assert_allclose(lab_values[1], [100.0, 0.0, 0.0], atol=1e-2)
75+
76+
def test_get_generic_color_name(self):
77+
rgb_colors = [(255, 0, 0), (0, 0, 255)] # Red, Blue
78+
generic_colors = [
79+
('red', '#ff0000'),
80+
('blue', '#0000ff'),
81+
('green', '#00ff00'),
82+
]
83+
names = color_extraction.get_generic_color_name(rgb_colors, generic_colors)
84+
self.assertEqual(names, ['red', 'blue'])
85+
86+
87+
if __name__ == '__main__':
88+
unittest.main()

official/projects/waste_identification_ml/Deploy/detr_cloud_deployment/client/inference_pipeline.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,11 @@
2727
# limitations under the License.
2828
"""Pipeline to run the prediction on the images folder with Triton server."""
2929

30-
import os
30+
import cuml.accel # pylint: disable=g-bad-import-order, g-import-not-at-top
3131

32+
cuml.accel.install() # pylint: disable=g-bad-import-order, g-import-not-at-top
33+
34+
import os # pylint: disable=g-bad-import-order, g-import-not-at-top
3235
from absl import app
3336
from absl import flags
3437
from big_query_ops import BigQueryManager
@@ -141,7 +144,7 @@ def main(_) -> None:
141144
)
142145

143146
# Continue to next image if no objects detected
144-
if not results["class_names"].any():
147+
if results["class_names"].size == 0:
145148
logger.info(f"No objects detected in {os.path.basename(image_path)}")
146149
continue
147150

official/projects/waste_identification_ml/Deploy/detr_cloud_deployment/client/requirements.sh

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,7 @@ source myenv/bin/activate
3434

3535
echo "Activated python environment, installing dependencies."
3636

37-
pip install --no-cache-dir natsort absl-py opencv-python pandas pandas-gbq \
38-
google-cloud-bigquery google-auth trackpy google-cloud-storage \
39-
scikit-image scikit-learn webcolors==1.13 ffmpeg-python tritonclient[all] \
40-
supervision==0.26.1 pillow==12.0.0
37+
pip install -r requirements.txt
4138

4239
# Clone TensorFlow Model Garden if the 'models' directory does not exist
4340
if [ ! -d "models" ]; then
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
--extra-index-url https://pypi.nvidia.com
2+
3+
natsort==8.4.0
4+
absl-py==2.4.0
5+
opencv-python==4.13.0.92
6+
pandas==2.3.3
7+
pandas-gbq==0.33.0
8+
google-cloud-bigquery==3.40.1
9+
google-auth==2.48.0
10+
google-cloud-storage==3.9.0
11+
scikit-image==0.25.2
12+
scikit-learn==1.7.2
13+
webcolors==1.13
14+
ffmpeg-python==0.2.0
15+
tritonclient[all]==2.65.0
16+
supervision==0.26.1
17+
pillow==12.0.0
18+
trackpy==0.7
19+
cupy-cuda12x[cuda_dlls]
20+
cupy-cuda12x[ctk]
21+
cuml-cu12==25.12.*

official/projects/waste_identification_ml/Deploy/detr_cloud_deployment/client/triton_server_inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def predict(
249249
raw_outputs, confidence_threshold, max_boxes
250250
)
251251

252-
if results['labels'].any():
252+
if results['labels'].size != 0:
253253
# Scale to output dimensions
254254
results = self._scale_bbox_and_masks(results, output_dims)
255255

0 commit comments

Comments
 (0)