Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The swift-tools-version declares the minimum version of Swift required to build this package.
// Copyright © 2024 Apple Inc.

import CompilerPluginSupport
import PackageDescription

#if os(Linux)
Expand Down Expand Up @@ -232,7 +233,8 @@ let package = Package(
],
dependencies: [
// for Complex type
.package(url: "https://github.com/apple/swift-numerics", from: "1.0.0")
.package(url: "https://github.com/apple/swift-numerics", from: "1.0.0"),
.package(url: "https://github.com/swiftlang/swift-syntax.git", from: "602.0.0"),
],
targets: [
cmlx,
Expand All @@ -246,12 +248,23 @@ let package = Package(
dependencies: [
"Cmlx",
.product(name: "Numerics", package: "swift-numerics"),
"MLXMacrosPlugin",
],
exclude: mlxSwiftExcludes,
swiftSettings: [
.enableExperimentalFeature("StrictConcurrency")
]
),
.macro(
name: "MLXMacrosPlugin",
dependencies: [
.product(name: "SwiftCompilerPlugin", package: "swift-syntax"),
.product(name: "SwiftSyntax", package: "swift-syntax"),
.product(name: "SwiftSyntaxBuilder", package: "swift-syntax"),
.product(name: "SwiftSyntaxMacros", package: "swift-syntax"),
.product(name: "SwiftDiagnostics", package: "swift-syntax"),
]
),
.target(
name: "MLXRandom",
dependencies: ["MLX"],
Expand Down Expand Up @@ -301,6 +314,13 @@ let package = Package(
"MLX", "MLXNN", "MLXOptimizers",
]
),
.testTarget(
name: "MLXMacrosTests",
dependencies: [
"MLXMacrosPlugin",
.product(name: "SwiftSyntaxMacrosTestSupport", package: "swift-syntax"),
]
),

// ------
// Example programs
Expand Down
37 changes: 37 additions & 0 deletions Source/MLX/Documentation.docc/Organization/initialization.md
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,43 @@ When creating using an array or sequence you can also control the shape:
let v1 = MLXArray(0 ..< 12, [3, 4])
```

### Macro Literals

You can also create arrays from nested literals with the `#MLXArray` expression macro:

```swift
import MLX

let a = #MLXArray([[1, 2], [3, 4]])
let b = #MLXArray([[1, 2], [3, 4]], dtype: .int16)
let c = #MLXArray([[[0.1, 0.2], [0.3, 0.4]]], dtype: .float16)
let d = #MLXArray([[true, false], [false, true]])
let e = #MLXArray([[0, 1], [1, 0]], dtype: .bool)
```

This is especially convenient for small constants in model code and tests.
The macro requires rectangular nested arrays with boolean, integer, or floating-point literals.
Mixing boolean and numeric literals in the same nested literal is not supported.
Mixed numeric literals are promoted to floating-point behavior: if any element is a
floating-point literal, the entire literal is treated as floating-point.

When `dtype` is a known integer dtype (for example `.int16`, `.int64`, `.uint8`) or `.float32`,
the expansion emits typed Swift literals directly and avoids a trailing `.asType(...)` cast.
For dynamic dtype expressions, or dtypes that do not map cleanly to a Swift literal type
(for example `.float16`, `.bfloat16`, `.complex64`), the macro emits a base array and applies
`.asType(...)`.

If `dtype` is an integer type and the literal contains floating-point values,
the macro still allows expansion but emits a warning because conversion may truncate.

```swift
// promoted to floating-point because of 2.5
let mixed = #MLXArray([[1, 2.5], [3, 4]])
```

Boolean literals are supported directly (`true` / `false`). For `dtype: .bool`,
integer literals are accepted only when each element is `0` or `1`.

### Random Value Arrays

See also `MLXRandom` for creating arrays with random data.
Expand Down
21 changes: 21 additions & 0 deletions Source/MLX/MLXMacros.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Copyright © 2026 Apple Inc.

/// Construct an ``MLXArray`` from a nested literal.
///
/// Examples:
///
/// ```swift
/// let a = #MLXArray([[1, 2, 3], [4, 5, 6]])
/// let b = #MLXArray([[1, 2, 3], [4, 5, 6]], dtype: .int16)
/// let c = #MLXArray([[0.1, 0.2], [0.3, 0.4]], dtype: .float16)
/// ```
@freestanding(expression)
public macro MLXArray(_ literal: Any) -> MLXArray =
#externalMacro(
module: "MLXMacrosPlugin", type: "MLXLiteralMacro")

/// Construct an ``MLXArray`` from a nested literal and cast to `dtype`.
@freestanding(expression)
public macro MLXArray(_ literal: Any, dtype: DType) -> MLXArray =
#externalMacro(
module: "MLXMacrosPlugin", type: "MLXLiteralMacro")
Loading
Loading