Skip to content

Commit de321c5

Browse files
authored
implement direct ml (#211)
* add directml gpu support * fix bug that doesn't allow models to be loaded on gpu * fix lock and import
1 parent fd37114 commit de321c5

6 files changed

Lines changed: 85 additions & 7 deletions

File tree

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,10 @@ or
490490
```sh
491491
poetry install --extras "gpu"
492492
```
493+
or
494+
```sh
495+
poetry install --extras "dml"
496+
```
493497
494498
### Running the Command-Line Interface Locally
495499

audio_separator/separator/architectures/mdxc_separator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,10 @@ def load_model(self):
105105
else:
106106
self.logger.debug("Loading TFC_TDF_net model...")
107107
self.model_run = TFC_TDF_net(self.model_data_cfgdict, device=self.torch_device)
108-
self.model_run.load_state_dict(torch.load(self.model_path, map_location=self.torch_device))
108+
self.logger.debug("Loading model onto cpu")
109+
# For some reason loading the state onto a hardware accelerated devices causes issues,
110+
# so we load it onto CPU first then move it to the device
111+
self.model_run.load_state_dict(torch.load(self.model_path, map_location="cpu"))
109112
self.model_run.to(self.torch_device).eval()
110113

111114
except RuntimeError as e:

audio_separator/separator/architectures/vr_separator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def separate(self, audio_file_path, custom_output_names=None):
144144
self.logger.debug("Determining model capacity...")
145145
self.model_run = nets.determine_model_capacity(self.model_params.param["bins"] * 2, nn_arch_size)
146146

147-
self.model_run.load_state_dict(torch.load(self.model_path, map_location=self.torch_device_cpu))
147+
self.model_run.load_state_dict(torch.load(self.model_path, map_location="cpu"))
148148
self.model_run.to(self.torch_device)
149149
self.logger.debug("Model loaded and moved to device.")
150150

audio_separator/separator/separator.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def __init__(
9393
sample_rate=44100,
9494
use_soundfile=False,
9595
use_autocast=False,
96+
use_directml=False,
9697
mdx_params={"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": False},
9798
vr_params={"batch_size": 1, "window_size": 512, "aggression": 5, "enable_tta": False, "enable_post_process": False, "post_process_threshold": 0.2, "high_end_process": False},
9899
demucs_params={"segment_size": "Default", "shifts": 2, "overlap": 0.25, "segments_enabled": True},
@@ -179,6 +180,7 @@ def __init__(
179180

180181
self.use_soundfile = use_soundfile
181182
self.use_autocast = use_autocast
183+
self.use_directml = use_directml
182184

183185
# These are parameters which users may want to configure so we expose them to the top-level Separator class,
184186
# even though they are specific to a single model architecture
@@ -246,20 +248,24 @@ def log_onnxruntime_packages(self):
246248
onnxruntime_gpu_package = self.get_package_distribution("onnxruntime-gpu")
247249
onnxruntime_silicon_package = self.get_package_distribution("onnxruntime-silicon")
248250
onnxruntime_cpu_package = self.get_package_distribution("onnxruntime")
251+
onnxruntime_dml_package = self.get_package_distribution("onnxruntime-directml")
249252

250253
if onnxruntime_gpu_package is not None:
251254
self.logger.info(f"ONNX Runtime GPU package installed with version: {onnxruntime_gpu_package.version}")
252255
if onnxruntime_silicon_package is not None:
253256
self.logger.info(f"ONNX Runtime Silicon package installed with version: {onnxruntime_silicon_package.version}")
254257
if onnxruntime_cpu_package is not None:
255258
self.logger.info(f"ONNX Runtime CPU package installed with version: {onnxruntime_cpu_package.version}")
259+
if onnxruntime_dml_package is not None:
260+
self.logger.info(f"ONNX Runtime DirectML package installed with version: {onnxruntime_dml_package.version}")
256261

257262
def setup_torch_device(self, system_info):
258263
"""
259264
This method sets up the PyTorch and/or ONNX Runtime inferencing device, using GPU hardware acceleration if available.
260265
"""
261266
hardware_acceleration_enabled = False
262267
ort_providers = ort.get_available_providers()
268+
has_torch_dml_installed = self.get_package_distribution("torch_directml")
263269

264270
self.torch_device_cpu = torch.device("cpu")
265271

@@ -269,6 +275,11 @@ def setup_torch_device(self, system_info):
269275
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available() and system_info.processor == "arm":
270276
self.configure_mps(ort_providers)
271277
hardware_acceleration_enabled = True
278+
elif self.use_directml and has_torch_dml_installed:
279+
import torch_directml
280+
if torch_directml.is_available():
281+
self.configure_dml(ort_providers)
282+
hardware_acceleration_enabled = True
272283

273284
if not hardware_acceleration_enabled:
274285
self.logger.info("No hardware acceleration could be configured, running in CPU mode")
@@ -302,6 +313,21 @@ def configure_mps(self, ort_providers):
302313
else:
303314
self.logger.warning("CoreMLExecutionProvider not available in ONNXruntime, so acceleration will NOT be enabled")
304315

316+
def configure_dml(self, ort_providers):
317+
"""
318+
This method configures the DirectML device for PyTorch and ONNX Runtime, if available.
319+
"""
320+
import torch_directml
321+
self.logger.info("DirectML is available in Torch, setting Torch device to DirectML")
322+
self.torch_device_dml = torch_directml.device()
323+
self.torch_device = self.torch_device_dml
324+
325+
if "DmlExecutionProvider" in ort_providers:
326+
self.logger.info("ONNXruntime has DmlExecutionProvider available, enabling acceleration")
327+
self.onnx_execution_provider = ["DmlExecutionProvider"]
328+
else:
329+
self.logger.warning("DmlExecutionProvider not available in ONNXruntime, so acceleration will NOT be enabled")
330+
305331
def get_package_distribution(self, package_name):
306332
"""
307333
This method returns the package distribution for a given package name if installed, or None otherwise.

poetry.lock

Lines changed: 47 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,15 @@ librosa = ">=0.10"
3737
samplerate = "0.1.0"
3838
six = ">=1.16"
3939
torch = ">=2.3"
40+
torch_directml = {version = "*", optional = true}
4041
tqdm = "*"
4142
pydub = ">=0.25"
4243
audioop-lts = { version = ">=0.2.1", python = "^3.13" }
4344
onnx-weekly = { version = "*" }
4445
onnx2torch-py313 = ">=1.6"
4546
onnxruntime = { version = ">=1.17", optional = true }
4647
onnxruntime-gpu = { version = ">=1.17", optional = true }
48+
onnxruntime-directml = { version = ">=1.17", optional = true } # haven't tested different versions, but gonna assume 1.17, the same as others
4749
julius = ">=0.2"
4850
diffq-fixed = { version = ">=0.2", platform = "win32" }
4951
diffq = { version = ">=0.2", platform = "!=win32" }
@@ -58,6 +60,7 @@ scipy = "^1.13.0"
5860
[tool.poetry.extras]
5961
cpu = ["onnxruntime"]
6062
gpu = ["onnxruntime-gpu"]
63+
dml = ["onnxruntime-directml", "torch_directml"]
6164

6265
[tool.poetry.scripts]
6366
audio-separator = 'audio_separator.utils.cli:main'

0 commit comments

Comments
 (0)