Skip to content

Commit 9389011

Browse files
Xyc2016claude
andcommitted
Fix compatibility with PyTorch 2.6+ and matplotlib >=3.6
- Fix torch.load() weights_only default change (PyTorch 2.6+) Add weights_only=False to load official model checkpoints - Fix FigureCanvas.tostring_rgb() removal (matplotlib >=3.6) Replace with buffer_rgba() and update color channel handling - Add missing psutil dependency to requirements.txt Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent b4ed20c commit 9389011

4 files changed

Lines changed: 15 additions & 13 deletions

File tree

fastsam/prompt.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -169,15 +169,16 @@ def plot_to_result(self,
169169
plt.draw()
170170

171171
try:
172-
buf = fig.canvas.tostring_rgb()
172+
buf = fig.canvas.buffer_rgba()
173+
cols, rows = fig.canvas.get_width_height()
173174
except AttributeError:
174175
fig.canvas.draw()
175-
buf = fig.canvas.tostring_rgb()
176-
cols, rows = fig.canvas.get_width_height()
177-
img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 3)
178-
result = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
176+
buf = fig.canvas.buffer_rgba()
177+
cols, rows = fig.canvas.get_width_height()
178+
img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 4)
179+
img_array = cv2.cvtColor(img_array, cv2.COLOR_RGBA2BGR)
179180
plt.close()
180-
return result
181+
return img_array
181182

182183
# Remark for refactoring: IMO a function should do one thing only, storing the image and plotting should be seperated and do not necessarily need to be class functions but standalone utility functions that the user can chain in his scripts to have more fine-grained control.
183184
def plot(self,

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ scipy>=1.4.1
88
torch>=1.7.0
99
torchvision>=0.8.1
1010
tqdm>=4.64.0
11+
psutil
1112

1213
pandas>=1.1.4
1314
seaborn>=0.11.0

ultralytics/nn/tasks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -515,7 +515,7 @@ def torch_safe_load(weight):
515515
check_suffix(file=weight, suffix='.pt')
516516
file = attempt_download_asset(weight) # search online if missing locally
517517
try:
518-
return torch.load(file, map_location='cpu'), file # load
518+
return torch.load(file, map_location='cpu', weights_only=False), file # load
519519
except ModuleNotFoundError as e: # e.name is missing module name
520520
if e.name == 'models':
521521
raise TypeError(
@@ -530,7 +530,7 @@ def torch_safe_load(weight):
530530
f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'")
531531
check_requirements(e.name) # install missing module
532532

533-
return torch.load(file, map_location='cpu'), file # load
533+
return torch.load(file, map_location='cpu', weights_only=False), file # load
534534

535535

536536
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):

utils/tools.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -180,14 +180,14 @@ def fast_process(
180180
plt.draw()
181181

182182
try:
183-
buf = fig.canvas.tostring_rgb()
183+
buf = fig.canvas.buffer_rgba()
184184
except AttributeError:
185185
fig.canvas.draw()
186-
buf = fig.canvas.tostring_rgb()
187-
186+
buf = fig.canvas.buffer_rgba()
187+
188188
cols, rows = fig.canvas.get_width_height()
189-
img_array = np.fromstring(buf, dtype=np.uint8).reshape(rows, cols, 3)
190-
cv2.imwrite(os.path.join(save_path, result_name), cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR))
189+
img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 4)
190+
cv2.imwrite(os.path.join(save_path, result_name), cv2.cvtColor(img_array, cv2.COLOR_RGBA2BGR))
191191

192192

193193
# CPU post process

0 commit comments

Comments
 (0)