@@ -93,6 +93,7 @@ def __init__(
9393 sample_rate = 44100 ,
9494 use_soundfile = False ,
9595 use_autocast = False ,
96+ use_directml = False ,
9697 mdx_params = {"hop_length" : 1024 , "segment_size" : 256 , "overlap" : 0.25 , "batch_size" : 1 , "enable_denoise" : False },
9798 vr_params = {"batch_size" : 1 , "window_size" : 512 , "aggression" : 5 , "enable_tta" : False , "enable_post_process" : False , "post_process_threshold" : 0.2 , "high_end_process" : False },
9899 demucs_params = {"segment_size" : "Default" , "shifts" : 2 , "overlap" : 0.25 , "segments_enabled" : True },
@@ -179,6 +180,7 @@ def __init__(
179180
180181 self .use_soundfile = use_soundfile
181182 self .use_autocast = use_autocast
183+ self .use_directml = use_directml
182184
183185 # These are parameters which users may want to configure so we expose them to the top-level Separator class,
184186 # even though they are specific to a single model architecture
@@ -246,20 +248,24 @@ def log_onnxruntime_packages(self):
246248 onnxruntime_gpu_package = self .get_package_distribution ("onnxruntime-gpu" )
247249 onnxruntime_silicon_package = self .get_package_distribution ("onnxruntime-silicon" )
248250 onnxruntime_cpu_package = self .get_package_distribution ("onnxruntime" )
251+ onnxruntime_dml_package = self .get_package_distribution ("onnxruntime-directml" )
249252
250253 if onnxruntime_gpu_package is not None :
251254 self .logger .info (f"ONNX Runtime GPU package installed with version: { onnxruntime_gpu_package .version } " )
252255 if onnxruntime_silicon_package is not None :
253256 self .logger .info (f"ONNX Runtime Silicon package installed with version: { onnxruntime_silicon_package .version } " )
254257 if onnxruntime_cpu_package is not None :
255258 self .logger .info (f"ONNX Runtime CPU package installed with version: { onnxruntime_cpu_package .version } " )
259+ if onnxruntime_dml_package is not None :
260+ self .logger .info (f"ONNX Runtime DirectML package installed with version: { onnxruntime_dml_package .version } " )
256261
257262 def setup_torch_device (self , system_info ):
258263 """
259264 This method sets up the PyTorch and/or ONNX Runtime inferencing device, using GPU hardware acceleration if available.
260265 """
261266 hardware_acceleration_enabled = False
262267 ort_providers = ort .get_available_providers ()
268+ has_torch_dml_installed = self .get_package_distribution ("torch_directml" )
263269
264270 self .torch_device_cpu = torch .device ("cpu" )
265271
@@ -269,6 +275,11 @@ def setup_torch_device(self, system_info):
269275 elif hasattr (torch .backends , "mps" ) and torch .backends .mps .is_available () and system_info .processor == "arm" :
270276 self .configure_mps (ort_providers )
271277 hardware_acceleration_enabled = True
278+ elif self .use_directml and has_torch_dml_installed :
279+ import torch_directml
280+ if torch_directml .is_available ():
281+ self .configure_dml (ort_providers )
282+ hardware_acceleration_enabled = True
272283
273284 if not hardware_acceleration_enabled :
274285 self .logger .info ("No hardware acceleration could be configured, running in CPU mode" )
@@ -302,6 +313,21 @@ def configure_mps(self, ort_providers):
302313 else :
303314 self .logger .warning ("CoreMLExecutionProvider not available in ONNXruntime, so acceleration will NOT be enabled" )
304315
316+ def configure_dml (self , ort_providers ):
317+ """
318+ This method configures the DirectML device for PyTorch and ONNX Runtime, if available.
319+ """
320+ import torch_directml
321+ self .logger .info ("DirectML is available in Torch, setting Torch device to DirectML" )
322+ self .torch_device_dml = torch_directml .device ()
323+ self .torch_device = self .torch_device_dml
324+
325+ if "DmlExecutionProvider" in ort_providers :
326+ self .logger .info ("ONNXruntime has DmlExecutionProvider available, enabling acceleration" )
327+ self .onnx_execution_provider = ["DmlExecutionProvider" ]
328+ else :
329+ self .logger .warning ("DmlExecutionProvider not available in ONNXruntime, so acceleration will NOT be enabled" )
330+
305331 def get_package_distribution (self , package_name ):
306332 """
307333 This method returns the package distribution for a given package name if installed, or None otherwise.
0 commit comments