Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ let noCudaCmlxExcludes = [
let cxxSettings: [CXXSetting]
let linkerSettings: [LinkerSetting]
let mlxSwiftExcludes: [String]
let cmlxPlugins: [Target.PluginUsage]

if Context.environment["SPM_CUDA"] != "0" {
// Linux with CUDA
Expand Down Expand Up @@ -140,6 +141,10 @@ let noCudaCmlxExcludes = [
"GPU+Metal.swift",
"MLXArray+Metal.swift",
]

cmlxPlugins = [
.plugin(name: "CudaBuild")
]
} else {
// Linux without CUDA (CPU only)

Expand Down Expand Up @@ -175,6 +180,10 @@ let noCudaCmlxExcludes = [
"MLXFast.swift",
"MLXFastKernel.swift",
]

cmlxPlugins = [
.plugin(name: "CudaBuild")
]
}
#else
// Apple's platforms with Metal
Expand Down Expand Up @@ -212,6 +221,11 @@ let noCudaCmlxExcludes = [
let mlxSwiftExcludes: [String] = [
"GPU+CUDA.swift"
]

let cmlxPlugins: [Target.PluginUsage] = [
.plugin(name: "CudaBuild"),
.plugin(name: "BuildSwiftPMMetalLibrary"),
]
#endif

let cmlx = Target.target(
Expand Down Expand Up @@ -289,9 +303,7 @@ let cmlx = Target.target(
.define("MLX_VERSION", to: "\"0.31.1\""),
],
linkerSettings: linkerSettings,
plugins: [
.plugin(name: "CudaBuild")
],
plugins: cmlxPlugins
)

let package = Package(
Expand Down Expand Up @@ -321,6 +333,10 @@ let package = Package(
],
targets: [
cmlx,
.plugin(
name: "BuildSwiftPMMetalLibrary",
capability: .buildTool()
),
.testTarget(
name: "CmlxTests",
dependencies: ["Cmlx"]
Expand Down
53 changes: 53 additions & 0 deletions Plugins/BuildSwiftPMMetalLibrary/plugin.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import Foundation
import PackagePlugin

@main
struct BuildSwiftPMMetalLibrary: BuildToolPlugin {
func createBuildCommands(context: PluginContext, target: any Target) async throws -> [Command] {
#if os(Linux)
return []
#else
let packageRoot = context.package.directoryURL
let script = packageRoot.appendingPathComponent("tools/build-swiftpm-metallib.sh")
let output = context.pluginWorkDirectoryURL.appendingPathComponent("default.metallib")

return [
.buildCommand(
displayName: "Build SwiftPM default.metallib",
executable: URL(fileURLWithPath: "/bin/bash"),
arguments: [script.path, output.path],
inputFiles: inputFiles(packageRoot: packageRoot, script: script),
outputFiles: [output]
)
]
#endif
}

#if !os(Linux)
private func inputFiles(packageRoot: URL, script: URL) -> [URL] {
let kernelsDirectory = packageRoot.appendingPathComponent(
"Source/Cmlx/mlx/mlx/backend/metal/kernels")
var files = [script]
files.append(contentsOf: recursivelyCollectedMetalInputs(in: kernelsDirectory))
return files
}

private func recursivelyCollectedMetalInputs(in directory: URL) -> [URL] {
let fileManager = FileManager.default
guard
let enumerator = fileManager.enumerator(
at: directory, includingPropertiesForKeys: nil)
else {
return []
}

return enumerator.compactMap { entry -> URL? in
guard let url = entry as? URL else { return nil }
guard url.pathExtension == "metal" || url.pathExtension == "h" else {
return nil
}
return url
}.sorted { $0.path < $1.path }
}
#endif
}
123 changes: 123 additions & 0 deletions tools/build-swiftpm-metallib.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
#!/bin/bash
# Build the default Metal library resource used by SwiftPM Cmlx builds.

set -euo pipefail

if [[ $# -ne 1 ]]; then
echo "usage: $0 OUTPUT_METALLIB" >&2
exit 64
fi

OUTPUT="$1"
SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd)
ROOT_DIR=$(realpath "${SCRIPT_DIR}/..")
KERNELS_DIR="${ROOT_DIR}/Source/Cmlx/mlx/mlx/backend/metal/kernels"

normalize_sdk_name() {
local raw="$1"
raw=$(basename "${raw}")
raw=$(printf '%s' "${raw}" | tr '[:upper:]' '[:lower:]')
case "${raw}" in
macosx*) echo "macosx" ;;
iphoneos*) echo "iphoneos" ;;
iphonesimulator*) echo "iphonesimulator" ;;
appletvos*) echo "appletvos" ;;
appletvsimulator*) echo "appletvsimulator" ;;
xros* | visionos*) echo "xros" ;;
xrsimulator* | visionsimulator*) echo "xrsimulator" ;;
*) echo "${raw}" ;;
esac
}

requested_sdk="${SDK_NAME:-${PLATFORM_NAME:-}}"
if [[ -z "${requested_sdk}" && -n "${SDKROOT:-}" ]]; then
requested_sdk=$(basename "${SDKROOT}")
fi
SDK=$(normalize_sdk_name "${requested_sdk:-macosx}")

case "${SDK}" in
macosx)
DEPLOYMENT_TARGET="${MACOSX_DEPLOYMENT_TARGET:-14.0}"
deployment_flag=("-mmacosx-version-min=${DEPLOYMENT_TARGET}")
;;
iphoneos | iphonesimulator)
DEPLOYMENT_TARGET="${IPHONEOS_DEPLOYMENT_TARGET:-${IOS_DEPLOYMENT_TARGET:-17.0}}"
deployment_flag=("-mios-version-min=${DEPLOYMENT_TARGET}")
;;
appletvos | appletvsimulator)
DEPLOYMENT_TARGET="${TVOS_DEPLOYMENT_TARGET:-17.0}"
deployment_flag=("-mtvos-version-min=${DEPLOYMENT_TARGET}")
;;
xros)
DEPLOYMENT_TARGET="${XROS_DEPLOYMENT_TARGET:-${VISIONOS_DEPLOYMENT_TARGET:-1.0}}"
deployment_flag=("-mtargetos=xros${DEPLOYMENT_TARGET}")
;;
xrsimulator)
DEPLOYMENT_TARGET="${XROS_DEPLOYMENT_TARGET:-${VISIONOS_DEPLOYMENT_TARGET:-1.0}}"
deployment_flag=("-mtargetos=xros${DEPLOYMENT_TARGET}-simulator")
;;
*)
echo "unsupported SDK '${SDK}'" >&2
exit 65
;;
esac

METAL=$(xcrun -sdk "${SDK}" -find metal)
METALLIB=$(xcrun -sdk "${SDK}" -find metallib)
TMP_DIR=$(mktemp -d)
trap 'rm -rf "${TMP_DIR}"' EXIT

metal_version=$(
printf '%s\n' '__METAL_VERSION__' |
"${METAL}" "${deployment_flag[@]}" -E -x metal -P - |
tail -1 |
tr -d '[:space:]'
)
metal_version=${metal_version:-0}

kernels=(
"arg_reduce"
"conv"
"gemv"
"layer_norm"
"random"
"rms_norm"
"rope"
"scaled_dot_product_attention"
)

if (( metal_version >= 320 )); then
kernels+=("fence")
fi

metal_flags=(
-x metal
-Wall
-Wextra
-fno-fast-math
-Wno-c++17-extensions
-Wno-c++20-extensions
"${deployment_flag[@]}"
)

if (( metal_version >= 400 )); then
metal_flags+=(-std=metal4.0)
elif (( metal_version >= 320 )); then
metal_flags+=(-std=metal3.2)
elif (( metal_version >= 310 )); then
metal_flags+=(-std=metal3.1)
elif (( metal_version >= 300 )); then
metal_flags+=(-std=metal3.0)
fi

air_files=()
for kernel in "${kernels[@]}"; do
source="${KERNELS_DIR}/${kernel}.metal"
air="${TMP_DIR}/${kernel}.air"
"${METAL}" "${metal_flags[@]}" -c "${source}" -I"${ROOT_DIR}/Source/Cmlx/mlx" -o "${air}"
air_files+=("${air}")
done

mkdir -p "$(dirname "${OUTPUT}")"
"${METALLIB}" "${air_files[@]}" -o "${TMP_DIR}/default.metallib"
mv "${TMP_DIR}/default.metallib" "${OUTPUT}"