Skip to content

Commit d68206b

Browse files
acceptable extraction of production data points from edited image
1 parent 939a8a9 commit d68206b

1 file changed

Lines changed: 48 additions & 16 deletions

File tree

src/geophires_docs/generate_fervo_project_red_2026_docs.py

Lines changed: 48 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77
import pandas as pd
88
from scipy.interpolate import interp1d
9+
from scipy.ndimage import maximum_filter
910

1011
from geophires_docs import _PROJECT_ROOT
1112

@@ -80,37 +81,51 @@ def _extract_red_circles(
8081
plot_mask: np.ndarray,
8182
pixel_to_data,
8283
) -> pd.DataFrame:
83-
img = cv2.imread(str(img_path))
84+
img = cv2.imread(str(img_path), cv2.IMREAD_UNCHANGED)
8485
if img is None:
8586
raise FileNotFoundError(f'Could not load image at {img_path}')
8687

87-
hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
88-
89-
lower_red1 = np.array([0, 50, 50])
90-
upper_red1 = np.array([15, 255, 255])
91-
lower_red2 = np.array([165, 50, 50])
88+
if len(img.shape) == 3 and img.shape[2] == 4:
89+
alpha = img[:, :, 3]
90+
_, mask_alpha = cv2.threshold(alpha, 10, 255, cv2.THRESH_BINARY)
91+
hsv = cv2.cvtColor(img[:, :, :3], cv2.COLOR_BGR2HSV)
92+
else:
93+
mask_alpha = np.ones(img.shape[:2], dtype=np.uint8) * 255
94+
hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
95+
96+
# Widened HSV bounds to capture anti-aliased/faded brush edges
97+
lower_red1 = np.array([0, 20, 20])
98+
upper_red1 = np.array([20, 255, 255])
99+
lower_red2 = np.array([160, 20, 20])
92100
upper_red2 = np.array([180, 255, 255])
93101

94102
mask_red1 = cv2.inRange(hsv, lower_red1, upper_red1)
95103
mask_red2 = cv2.inRange(hsv, lower_red2, upper_red2)
96104
mask_red = cv2.bitwise_or(mask_red1, mask_red2)
105+
mask_red = cv2.bitwise_and(mask_red, mask_alpha)
97106
mask_red = cv2.bitwise_and(mask_red, plot_mask)
98107

99-
y_coords, x_coords = np.where(mask_red > 0)
108+
# Use distance transform to find the central ridge line of the brush strokes
109+
dist_transform = cv2.distanceTransform(mask_red, cv2.DIST_L2, 5)
100110

101-
if len(x_coords) == 0:
102-
_log.warning('No red pixels found in the production data mask.')
103-
return pd.DataFrame(columns=['Time_Years', 'Temperature_C'])
111+
# Find the peaks (ridges) using a small 3x3 max filter
112+
local_max = maximum_filter(dist_transform, size=3) == dist_transform
113+
# Filter out absolute noise
114+
peak_mask = local_max & (dist_transform > 1.0)
104115

105-
df_pixels = pd.DataFrame({'x': x_coords, 'y': y_coords})
116+
y_coords, x_coords = np.where(peak_mask)
117+
centers_px = [(int(x), int(y)) for x, y in zip(x_coords, y_coords)]
106118

107-
bin_size = int(_HOUGH_MIN_DIST_PX)
108-
df_pixels['x_binned'] = (df_pixels['x'] // bin_size) * bin_size + (bin_size // 2)
109-
centerline = df_pixels.groupby('x_binned', as_index=False)[['x', 'y']].mean()
119+
if not centers_px:
120+
_log.warning('No valid pixels found in the production data mask.')
121+
return pd.DataFrame(columns=['Time_Years', 'Temperature_C'])
110122

111-
_log.info(f'Red-marker detection: Extracted {len(centerline)} binned centerline points from edited mask.')
123+
# Space the extracted points evenly along the detected ridge
124+
deduped_centers_px = _dedupe_centers(centers_px, min_dist_px=_HOUGH_MIN_DIST_PX)
112125

113-
production_data = [pixel_to_data(row['x'], row['y']) for _, row in centerline.iterrows()]
126+
_log.info(f'Red-marker detection: Extracted {len(deduped_centers_px)} topological ridge points from edited mask.')
127+
128+
production_data = [pixel_to_data(cx, cy) for cx, cy in deduped_centers_px]
114129
df_prod = pd.DataFrame(production_data, columns=['Time_Years', 'Temperature_C'])
115130
return df_prod.sort_values('Time_Years').reset_index(drop=True)
116131

@@ -154,6 +169,23 @@ def _extract_black_dashed_line(hsv: np.ndarray, plot_mask: np.ndarray, pixel_to_
154169
return df_model
155170

156171

172+
def _dedupe_centers(centers_px: list[tuple[int, int]], min_dist_px: float) -> list[tuple[int, int]]:
173+
if not centers_px:
174+
return []
175+
176+
accepted: list[tuple[int, int]] = []
177+
min_dist_sq = min_dist_px * min_dist_px
178+
for cx, cy in centers_px:
179+
duplicate = False
180+
for ax, ay in accepted:
181+
if (cx - ax) * (cx - ax) + (cy - ay) * (cy - ay) < min_dist_sq:
182+
duplicate = True
183+
break
184+
if not duplicate:
185+
accepted.append((cx, cy))
186+
return accepted
187+
188+
157189
def _regenerate_graph_from_csv(
158190
production_csv_path: Path,
159191
model_csv_path: Path,

0 commit comments

Comments
 (0)