Skip to content

Commit 214b58e

Browse files
authored
Merge pull request beehive-lab#103 from ArturSkowronski/feat/metal-backend-support
Add Apple Metal backend support
2 parents 6ac9e6b + 237ac97 commit 214b58e

1 file changed

Lines changed: 16 additions & 0 deletions

File tree

llama-tornado

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@ from enum import Enum
1919
class Backend(Enum):
2020
OPENCL = "opencl"
2121
PTX = "ptx"
22+
METAL = "metal"
2223

2324

2425
class LlamaRunner:
@@ -168,6 +169,14 @@ class LlamaRunner:
168169
"ALL-SYSTEM,jdk.incubator.vector,tornado.runtime,tornado.annotation,tornado.drivers.common,tornado.drivers.ptx",
169170
]
170171
)
172+
elif args.backend == Backend.METAL:
173+
module_config.extend(
174+
[
175+
f"@{self.tornado_sdk}/etc/exportLists/metal-exports",
176+
"--add-modules",
177+
"ALL-SYSTEM,jdk.incubator.vector,tornado.runtime,tornado.annotation,tornado.drivers.common,tornado.drivers.metal",
178+
]
179+
)
171180

172181
module_config.extend(
173182
[
@@ -410,6 +419,13 @@ def create_parser() -> argparse.ArgumentParser:
410419
const=Backend.PTX,
411420
help="Use PTX/CUDA backend",
412421
)
422+
hw_group.add_argument(
423+
"--metal",
424+
dest="backend",
425+
action="store_const",
426+
const=Backend.METAL,
427+
help="Use Apple Metal backend (macOS only, requires TornadoVM 4.0+)",
428+
)
413429
hw_group.add_argument("--gpu-memory", default="14GB", help="GPU memory allocation")
414430
hw_group.add_argument("--heap-min", default="20g", help="Minimum JVM heap size")
415431
hw_group.add_argument("--heap-max", default="20g", help="Maximum JVM heap size")

0 commit comments

Comments
 (0)