@@ -156,17 +156,60 @@ def _try_cuda(self) -> bool:
156156 return False
157157
158158 def _try_rocm (self ) -> bool :
159- """Detect AMD GPU via rocm-smi or /opt/rocm."""
159+ """Detect AMD GPU via amd-smi (preferred) or rocm-smi."""
160+ has_amd_smi = shutil .which ("amd-smi" ) is not None
160161 has_rocm_smi = shutil .which ("rocm-smi" ) is not None
161162 has_rocm_dir = Path ("/opt/rocm" ).is_dir ()
162163
163- if not (has_rocm_smi or has_rocm_dir ):
164+ if not (has_amd_smi or has_rocm_smi or has_rocm_dir ):
164165 return False
165166
166167 self .backend = "rocm"
167168 self .device = "cuda" # ROCm exposes as CUDA in PyTorch
168169
169- if has_rocm_smi :
170+ # Strategy 1: amd-smi static --json (ROCm 6.3+/7.x, richest output)
171+ if has_amd_smi :
172+ try :
173+ result = subprocess .run (
174+ ["amd-smi" , "static" , "--json" ],
175+ capture_output = True , text = True , timeout = 10 ,
176+ )
177+ if result .returncode == 0 :
178+ import json as _json
179+ data = _json .loads (result .stdout )
180+ # amd-smi may return {"gpu_data": [...]} or a bare list
181+ gpu_list = data .get ("gpu_data" , data ) if isinstance (data , dict ) else data
182+ if isinstance (gpu_list , list ) and len (gpu_list ) > 0 :
183+ # Pick GPU with most VRAM (discrete > iGPU)
184+ def _vram_mb (g ):
185+ vram = g .get ("vram" , {}).get ("size" , {})
186+ if isinstance (vram , dict ):
187+ return int (vram .get ("value" , 0 ))
188+ return 0
189+
190+ best_gpu = max (gpu_list , key = _vram_mb )
191+ best_idx = gpu_list .index (best_gpu )
192+ asic = best_gpu .get ("asic" , {})
193+ vram = best_gpu .get ("vram" , {}).get ("size" , {})
194+
195+ self .gpu_name = asic .get ("market_name" , "AMD GPU" )
196+ self .gpu_memory_mb = int (vram .get ("value" , 0 )) if isinstance (vram , dict ) else 0
197+ self .detection_details ["amd_smi" ] = {
198+ "gpu_index" : best_idx ,
199+ "gfx_version" : asic .get ("target_graphics_version" , "" ),
200+ "total_gpus" : len (gpu_list ),
201+ }
202+
203+ # Pin to discrete GPU if multiple GPUs present
204+ if len (gpu_list ) > 1 :
205+ os .environ ["HIP_VISIBLE_DEVICES" ] = str (best_idx )
206+ os .environ ["ROCR_VISIBLE_DEVICES" ] = str (best_idx )
207+ _log (f"Multi-GPU: pinned to GPU { best_idx } ({ self .gpu_name } )" )
208+ except (subprocess .TimeoutExpired , FileNotFoundError , ValueError , Exception ) as e :
209+ _log (f"amd-smi probe failed: { e } " )
210+
211+ # Strategy 2: rocm-smi fallback (legacy ROCm <6.3)
212+ if not self .gpu_name and has_rocm_smi :
170213 try :
171214 result = subprocess .run (
172215 ["rocm-smi" , "--showproductname" , "--csv" ],
@@ -186,7 +229,6 @@ def _try_rocm(self) -> bool:
186229 capture_output = True , text = True , timeout = 10 ,
187230 )
188231 if result .returncode == 0 :
189- # Parse total VRAM
190232 for line in result .stdout .strip ().split ("\n " )[1 :]:
191233 parts = line .split ("," )
192234 if len (parts ) >= 2 :
@@ -296,11 +338,22 @@ def _fallback_cpu(self):
296338
297339 _log ("No GPU detected, using CPU backend" )
298340
def _check_rocm_runtime(self):
    """Confirm that the installed onnxruntime offers a ROCm-capable provider.

    Imports ``onnxruntime`` and inspects ``get_available_providers()``.
    Either ``ROCmExecutionProvider`` or ``MIGraphXExecutionProvider``
    counts as success.

    Returns:
        True when one of the accepted providers is available.

    Raises:
        ImportError: when onnxruntime cannot be imported, or when it is
            importable but neither accepted provider is listed.
    """
    import onnxruntime

    accepted = ("ROCmExecutionProvider", "MIGraphXExecutionProvider")
    providers = onnxruntime.get_available_providers()

    # Guard clause: no accepted provider -> log a fix hint and signal failure.
    if not any(name in providers for name in accepted):
        _log(f"onnxruntime providers: { providers } — ROCmExecutionProvider not found")
        _log("Fix: pip uninstall onnxruntime && pip install onnxruntime-rocm")
        raise ImportError("ROCmExecutionProvider not available")

    _log(f"onnxruntime ROCm providers: { providers } ")
    return True
351+
299352 def _check_framework (self ) -> bool :
300353 """Check if the optimized inference runtime is importable."""
301354 checks = {
302355 "cuda" : lambda : __import__ ("tensorrt" ),
303- "rocm" : lambda : __import__ ( "onnxruntime" ),
356+ "rocm" : lambda : self . _check_rocm_runtime ( ),
304357 "mps" : lambda : __import__ ("coremltools" ),
305358 "intel" : lambda : __import__ ("openvino" ),
306359 "cpu" : lambda : __import__ ("onnxruntime" ),
0 commit comments