Skip to content

Commit fb4c852

Browse files
Xyc2016claude
andcommitted
Fix compatibility with PyTorch 2.6+ and matplotlib >=3.6
- Fix torch.load() weights_only default change (PyTorch 2.6+) Add weights_only=False to load official model checkpoints - Fix FigureCanvas.tostring_rgb() removal (matplotlib >=3.6) Replace with buffer_rgba() and update color channel handling - Add missing psutil dependency to requirements.txt Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent b4ed20c commit fb4c852

4 files changed

Lines changed: 18 additions & 21 deletions

File tree

fastsam/prompt.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -166,16 +166,11 @@ def plot_to_result(self,
166166

167167
plt.axis('off')
168168
fig = plt.gcf()
169-
plt.draw()
170-
171-
try:
172-
buf = fig.canvas.tostring_rgb()
173-
except AttributeError:
174-
fig.canvas.draw()
175-
buf = fig.canvas.tostring_rgb()
169+
fig.canvas.draw()
170+
buf = fig.canvas.buffer_rgba()
176171
cols, rows = fig.canvas.get_width_height()
177-
img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 3)
178-
result = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
172+
img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 4)
173+
result = cv2.cvtColor(img_array, cv2.COLOR_RGBA2BGR)
179174
plt.close()
180175
return result
181176

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ scipy>=1.4.1
88
torch>=1.7.0
99
torchvision>=0.8.1
1010
tqdm>=4.64.0
11+
psutil
1112

1213
pandas>=1.1.4
1314
seaborn>=0.11.0

ultralytics/nn/tasks.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import torch
88
import torch.nn as nn
9+
import pickle
910

1011
from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x,
1112
Classify, Concat, Conv, Conv2, ConvTranspose, Detect, DWConv, DWConvTranspose2d,
@@ -515,7 +516,7 @@ def torch_safe_load(weight):
515516
check_suffix(file=weight, suffix='.pt')
516517
file = attempt_download_asset(weight) # search online if missing locally
517518
try:
518-
return torch.load(file, map_location='cpu'), file # load
519+
return torch.load(file, map_location='cpu', weights_only=True), file # load
519520
except ModuleNotFoundError as e: # e.name is missing module name
520521
if e.name == 'models':
521522
raise TypeError(
@@ -530,7 +531,12 @@ def torch_safe_load(weight):
530531
f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'")
531532
check_requirements(e.name) # install missing module
532533

533-
return torch.load(file, map_location='cpu'), file # load
534+
return torch.load(file, map_location='cpu', weights_only=True), file # load
535+
except (pickle.UnpicklingError, RuntimeError) as e:
536+
LOGGER.warning(f"WARNING ⚠️ {weight} requires non-safe pickle loading. "
537+
f"Falling back to weights_only=False. "
538+
f"Only load weights from trusted sources.")
539+
return torch.load(file, map_location='cpu', weights_only=False), file # load
534540

535541

536542
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):

utils/tools.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -177,17 +177,12 @@ def fast_process(
177177
os.makedirs(save_path)
178178
plt.axis("off")
179179
fig = plt.gcf()
180-
plt.draw()
181-
182-
try:
183-
buf = fig.canvas.tostring_rgb()
184-
except AttributeError:
185-
fig.canvas.draw()
186-
buf = fig.canvas.tostring_rgb()
187-
180+
fig.canvas.draw()
181+
buf = fig.canvas.buffer_rgba()
182+
188183
cols, rows = fig.canvas.get_width_height()
189-
img_array = np.fromstring(buf, dtype=np.uint8).reshape(rows, cols, 3)
190-
cv2.imwrite(os.path.join(save_path, result_name), cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR))
184+
img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 4)
185+
cv2.imwrite(os.path.join(save_path, result_name), cv2.cvtColor(img_array, cv2.COLOR_RGBA2BGR))
191186

192187

193188
# CPU post process

0 commit comments

Comments
 (0)