Skip to content

Commit cc2cae1

Browse files
Xyc2016claude
andcommitted
Fix compatibility with PyTorch 2.6+ and matplotlib >=3.6
- Fix torch.load() weights_only default change (PyTorch 2.6+) Add weights_only=False to load official model checkpoints - Fix FigureCanvas.tostring_rgb() removal (matplotlib >=3.6) Replace with buffer_rgba() and update color channel handling - Add missing psutil dependency to requirements.txt Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent b4ed20c commit cc2cae1

4 files changed

Lines changed: 12 additions & 11 deletions

File tree

fastsam/prompt.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -169,13 +169,13 @@ def plot_to_result(self,
169169
plt.draw()
170170

171171
try:
172-
buf = fig.canvas.tostring_rgb()
172+
buf = fig.canvas.buffer_rgba()
173173
except AttributeError:
174174
fig.canvas.draw()
175-
buf = fig.canvas.tostring_rgb()
175+
buf = fig.canvas.buffer_rgba()
176176
cols, rows = fig.canvas.get_width_height()
177-
img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 3)
178-
result = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
177+
img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 4)
178+
result = cv2.cvtColor(img_array, cv2.COLOR_RGBA2BGR)
179179
plt.close()
180180
return result
181181

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ scipy>=1.4.1
88
torch>=1.7.0
99
torchvision>=0.8.1
1010
tqdm>=4.64.0
11+
psutil
1112

1213
pandas>=1.1.4
1314
seaborn>=0.11.0

ultralytics/nn/tasks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -515,7 +515,7 @@ def torch_safe_load(weight):
515515
check_suffix(file=weight, suffix='.pt')
516516
file = attempt_download_asset(weight) # search online if missing locally
517517
try:
518-
return torch.load(file, map_location='cpu'), file # load
518+
return torch.load(file, map_location='cpu', weights_only=False), file # load
519519
except ModuleNotFoundError as e: # e.name is missing module name
520520
if e.name == 'models':
521521
raise TypeError(
@@ -530,7 +530,7 @@ def torch_safe_load(weight):
530530
f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'")
531531
check_requirements(e.name) # install missing module
532532

533-
return torch.load(file, map_location='cpu'), file # load
533+
return torch.load(file, map_location='cpu', weights_only=False), file # load
534534

535535

536536
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):

utils/tools.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -180,14 +180,14 @@ def fast_process(
180180
plt.draw()
181181

182182
try:
183-
buf = fig.canvas.tostring_rgb()
183+
buf = fig.canvas.buffer_rgba()
184184
except AttributeError:
185185
fig.canvas.draw()
186-
buf = fig.canvas.tostring_rgb()
187-
186+
buf = fig.canvas.buffer_rgba()
187+
188188
cols, rows = fig.canvas.get_width_height()
189-
img_array = np.fromstring(buf, dtype=np.uint8).reshape(rows, cols, 3)
190-
cv2.imwrite(os.path.join(save_path, result_name), cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR))
189+
img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 4)
190+
cv2.imwrite(os.path.join(save_path, result_name), cv2.cvtColor(img_array, cv2.COLOR_RGBA2BGR))
191191

192192

193193
# CPU post process

0 commit comments

Comments
 (0)