From 2de04403d8a41f456d4f47dbd182c98e3a9ab566 Mon Sep 17 00:00:00 2001 From: Robert Sale Date: Tue, 10 Feb 2026 06:23:46 -0800 Subject: [PATCH 1/7] Convenience Initializer Macro for MLXArray --- Package.swift | 22 +- .../Organization/initialization.md | 21 ++ Source/MLX/MLXMacros.swift | 21 ++ Source/MLXMacrosPlugin/MLXLiteralMacro.swift | 236 ++++++++++++++++++ Source/MLXMacrosPlugin/MLXMacrosPlugin.swift | 11 + .../MLXMacrosTests/MLXLiteralMacroTests.swift | 130 ++++++++++ 6 files changed, 440 insertions(+), 1 deletion(-) create mode 100644 Source/MLX/MLXMacros.swift create mode 100644 Source/MLXMacrosPlugin/MLXLiteralMacro.swift create mode 100644 Source/MLXMacrosPlugin/MLXMacrosPlugin.swift create mode 100644 Tests/MLXMacrosTests/MLXLiteralMacroTests.swift diff --git a/Package.swift b/Package.swift index e1b2d2e2..dea27200 100644 --- a/Package.swift +++ b/Package.swift @@ -2,6 +2,7 @@ // The swift-tools-version declares the minimum version of Swift required to build this package. // Copyright © 2024 Apple Inc. +import CompilerPluginSupport import PackageDescription #if os(Linux) @@ -232,7 +233,8 @@ let package = Package( ], dependencies: [ // for Complex type - .package(url: "https://github.com/apple/swift-numerics", from: "1.0.0") + .package(url: "https://github.com/apple/swift-numerics", from: "1.0.0"), + .package(url: "https://github.com/swiftlang/swift-syntax.git", from: "602.0.0"), ], targets: [ cmlx, @@ -246,12 +248,23 @@ let package = Package( dependencies: [ "Cmlx", .product(name: "Numerics", package: "swift-numerics"), + "MLXMacrosPlugin", ], exclude: mlxSwiftExcludes, swiftSettings: [ .enableExperimentalFeature("StrictConcurrency") ] ), + .macro( + name: "MLXMacrosPlugin", + dependencies: [ + .product(name: "SwiftCompilerPlugin", package: "swift-syntax"), + .product(name: "SwiftSyntax", package: "swift-syntax"), + .product(name: "SwiftSyntaxBuilder", package: "swift-syntax"), + .product(name: "SwiftSyntaxMacros", package: "swift-syntax"), + .product(name: "SwiftDiagnostics", package: "swift-syntax"), + ] + ), .target( name: "MLXRandom", dependencies: ["MLX"], @@ -301,6 +314,13 @@ let package = Package( "MLX", "MLXNN", "MLXOptimizers", ] ), + .testTarget( + name: "MLXMacrosTests", + dependencies: [ + "MLXMacrosPlugin", + .product(name: "SwiftSyntaxMacrosTestSupport", package: "swift-syntax"), + ] + ), // ------ // Example programs diff --git a/Source/MLX/Documentation.docc/Organization/initialization.md b/Source/MLX/Documentation.docc/Organization/initialization.md index 5e9e82e2..4c6b6ad9 100644 --- a/Source/MLX/Documentation.docc/Organization/initialization.md +++ b/Source/MLX/Documentation.docc/Organization/initialization.md @@ -141,6 +141,27 @@ When creating using an array or sequence you can also control the shape: let v1 = MLXArray(0 ..< 12, [3, 4]) ``` +### Macro Literals + +You can also create arrays from nested literals with the `#mlx` expression macro: + +```swift +import MLX + +let a = #mlx([[1, 2], [3, 4]]) +let b = #mlx([[1, 2], [3, 4]], dtype: .int16) +let c = #mlx([[[0.1, 0.2], [0.3, 0.4]]], dtype: .float16) +``` + +This is especially convenient for small constants in model code and tests. +The macro requires rectangular nested arrays and numeric literals. + +When `dtype` is a known integer dtype (for example `.int16`, `.int64`, `.uint8`) or `.float32`, +the expansion emits typed Swift literals directly and avoids a trailing `.asType(...)` cast. +For dynamic dtype expressions, or dtypes that do not map cleanly to a Swift literal type +(for example `.float16`, `.bfloat16`, `.complex64`), the macro emits a base array and applies +`.asType(...)`. + ### Random Value Arrays See also `MLXRandom` for creating arrays with random data. diff --git a/Source/MLX/MLXMacros.swift b/Source/MLX/MLXMacros.swift new file mode 100644 index 00000000..be72cab9 --- /dev/null +++ b/Source/MLX/MLXMacros.swift @@ -0,0 +1,21 @@ +// Copyright © 2026 Apple Inc. + +/// Construct an ``MLXArray`` from a nested numeric literal. +/// +/// Examples: +/// +/// ```swift +/// let a = #mlx([[1, 2, 3], [4, 5, 6]]) +/// let b = #mlx([[1, 2, 3], [4, 5, 6]], dtype: .int16) +/// let c = #mlx([[0.1, 0.2], [0.3, 0.4]], dtype: .float16) +/// ``` +@freestanding(expression) +public macro mlx(_ literal: Any) -> MLXArray = + #externalMacro( + module: "MLXMacrosPlugin", type: "MLXLiteralMacro") + +/// Construct an ``MLXArray`` from a nested numeric literal and cast to `dtype`. +@freestanding(expression) +public macro mlx(_ literal: Any, dtype: DType) -> MLXArray = + #externalMacro( + module: "MLXMacrosPlugin", type: "MLXLiteralMacro") diff --git a/Source/MLXMacrosPlugin/MLXLiteralMacro.swift b/Source/MLXMacrosPlugin/MLXLiteralMacro.swift new file mode 100644 index 00000000..4b6ec0a6 --- /dev/null +++ b/Source/MLXMacrosPlugin/MLXLiteralMacro.swift @@ -0,0 +1,236 @@ +// Copyright © 2026 Apple Inc. + +import SwiftDiagnostics +import SwiftSyntax +import SwiftSyntaxBuilder +import SwiftSyntaxMacros + +private enum ScalarKind { + case int + case float + + static func merge(_ lhs: ScalarKind, _ rhs: ScalarKind) -> ScalarKind { + if lhs == .float || rhs == .float { + return .float + } + return .int + } +} + +private struct ParsedLiteral { + var flat: [ExprSyntax] + var shape: [Int] + var kind: ScalarKind +} + +private struct MacroError: Error {} + +private struct MacroMessage: DiagnosticMessage { + let message: String + let diagnosticID: MessageID + let severity: DiagnosticSeverity = .error + + init(_ message: String) { + self.message = message + self.diagnosticID = MessageID(domain: "MLXMacros", id: "mlx_literal") + } +} + +private enum KnownDType: String { + case bool + case uint8 + case uint16 + case uint32 + case uint64 + case int8 + case int16 + case int32 + case int64 + case float16 + case float32 + case bfloat16 + case complex64 + case float64 +} + +public struct MLXLiteralMacro: ExpressionMacro { + public static func expansion( + of node: some FreestandingMacroExpansionSyntax, + in context: some MacroExpansionContext + ) throws -> ExprSyntax { + let args = Array(node.arguments) + guard let literalArg = args.first else { + diagnose("#mlx requires a nested numeric array literal.", at: Syntax(node), in: context) + return "MLXArray([])" + } + + let dtypeExpr: ExprSyntax? + if args.count == 1 { + dtypeExpr = nil + } else if args.count == 2 { + guard args[1].label?.text == "dtype" else { + diagnose( + "#mlx second argument must be labeled 'dtype:'.", + at: Syntax(args[1]), in: context) + return "MLXArray([])" + } + dtypeExpr = args[1].expression + } else { + diagnose( + "#mlx accepts one literal argument and optional dtype:.", at: Syntax(node), + in: context) + return "MLXArray([])" + } + + let parsed: ParsedLiteral + do { + parsed = try parseLiteral(literalArg.expression, context: context) + } catch { + return "MLXArray([])" + } + + let flatSource = parsed.flat.map { $0.description }.joined(separator: ", ") + let shapeSource = parsed.shape.map(String.init).joined(separator: ", ") + let baseExpr: ExprSyntax = + switch parsed.kind { + case .int: + "MLXArray([\(raw: flatSource)], [\(raw: shapeSource)])" + case .float: + "MLXArray(converting: [\(raw: flatSource)], [\(raw: shapeSource)])" + } + + if let dtypeExpr { + if let knownDType = parseKnownDType(dtypeExpr), + let typedExpr = makeTypedExpression(parsed: parsed, dtype: knownDType) + { + return typedExpr + } + return "\(baseExpr).asType(\(dtypeExpr))" + } else { + return baseExpr + } + } + + private static func parseKnownDType(_ expr: ExprSyntax) -> KnownDType? { + guard let member = expr.as(MemberAccessExprSyntax.self) else { + return nil + } + return KnownDType(rawValue: member.declName.baseName.text) + } + + private static func makeTypedExpression(parsed: ParsedLiteral, dtype: KnownDType) -> ExprSyntax? + { + let shapeSource = parsed.shape.map(String.init).joined(separator: ", ") + let typedFlat: String + + switch dtype { + case .int8: + guard parsed.kind == .int else { return nil } + typedFlat = wrap(parsed.flat, with: "Int8") + case .int16: + guard parsed.kind == .int else { return nil } + typedFlat = wrap(parsed.flat, with: "Int16") + case .int32: + guard parsed.kind == .int else { return nil } + typedFlat = wrap(parsed.flat, with: "Int32") + case .int64: + guard parsed.kind == .int else { return nil } + typedFlat = wrap(parsed.flat, with: "Int64") + case .uint8: + guard parsed.kind == .int else { return nil } + typedFlat = wrap(parsed.flat, with: "UInt8") + case .uint16: + guard parsed.kind == .int else { return nil } + typedFlat = wrap(parsed.flat, with: "UInt16") + case .uint32: + guard parsed.kind == .int else { return nil } + typedFlat = wrap(parsed.flat, with: "UInt32") + case .uint64: + guard parsed.kind == .int else { return nil } + typedFlat = wrap(parsed.flat, with: "UInt64") + case .float32: + if parsed.kind == .int { + typedFlat = wrap(parsed.flat, with: "Float") + } else { + typedFlat = wrap(parsed.flat, with: "Float") + } + case .bool, .float16, .bfloat16, .complex64, .float64: + return nil + } + + return "MLXArray([\(raw: typedFlat)], [\(raw: shapeSource)])" + } + + private static func wrap(_ values: [ExprSyntax], with typeName: String) -> String { + values.map { "\(typeName)(\($0))" }.joined(separator: ", ") + } + + private static func parseLiteral( + _ expr: ExprSyntax, context: some MacroExpansionContext + ) throws -> ParsedLiteral { + if let arrayExpr = expr.as(ArrayExprSyntax.self) { + if arrayExpr.elements.isEmpty { + return ParsedLiteral(flat: [], shape: [0], kind: .int) + } + + var children: [ParsedLiteral] = [] + children.reserveCapacity(arrayExpr.elements.count) + + for element in arrayExpr.elements { + children.append(try parseLiteral(element.expression, context: context)) + } + + let firstShape = children[0].shape + if children.dropFirst().contains(where: { $0.shape != firstShape }) { + diagnose( + "#mlx does not support ragged nested arrays.", at: Syntax(expr), in: context) + throw MacroError() + } + + let kind = children.dropFirst().reduce(children[0].kind) { + ScalarKind.merge($0, $1.kind) + } + + return ParsedLiteral( + flat: children.flatMap(\.flat), shape: [children.count] + firstShape, kind: kind) + } + + if isInteger(expr) { + return ParsedLiteral(flat: [expr], shape: [], kind: .int) + } + if isFloat(expr) { + return ParsedLiteral(flat: [expr], shape: [], kind: .float) + } + + diagnose( + "#mlx only supports integer and floating-point literals.", at: Syntax(expr), in: context + ) + throw MacroError() + } + + private static func isInteger(_ expr: ExprSyntax) -> Bool { + if expr.as(IntegerLiteralExprSyntax.self) != nil { + return true + } + if let prefix = expr.as(PrefixOperatorExprSyntax.self) { + return isInteger(prefix.expression) + } + return false + } + + private static func isFloat(_ expr: ExprSyntax) -> Bool { + if expr.as(FloatLiteralExprSyntax.self) != nil { + return true + } + if let prefix = expr.as(PrefixOperatorExprSyntax.self) { + return isFloat(prefix.expression) + } + return false + } + + private static func diagnose( + _ message: String, at node: Syntax, in context: some MacroExpansionContext + ) { + context.diagnose(Diagnostic(node: node, message: MacroMessage(message))) + } +} diff --git a/Source/MLXMacrosPlugin/MLXMacrosPlugin.swift b/Source/MLXMacrosPlugin/MLXMacrosPlugin.swift new file mode 100644 index 00000000..41ee077f --- /dev/null +++ b/Source/MLXMacrosPlugin/MLXMacrosPlugin.swift @@ -0,0 +1,11 @@ +// Copyright © 2026 Apple Inc. + +import SwiftCompilerPlugin +import SwiftSyntaxMacros + +@main +struct MLXMacrosPlugin: CompilerPlugin { + let providingMacros: [Macro.Type] = [ + MLXLiteralMacro.self + ] +} diff --git a/Tests/MLXMacrosTests/MLXLiteralMacroTests.swift b/Tests/MLXMacrosTests/MLXLiteralMacroTests.swift new file mode 100644 index 00000000..22bf3617 --- /dev/null +++ b/Tests/MLXMacrosTests/MLXLiteralMacroTests.swift @@ -0,0 +1,130 @@ +// Copyright © 2026 Apple Inc. + +import SwiftSyntaxMacros +import SwiftSyntaxMacrosTestSupport +import XCTest + +@testable import MLXMacrosPlugin + +private let testMacros: [String: Macro.Type] = [ + "mlx": MLXLiteralMacro.self +] + +final class MLXLiteralMacroTests: XCTestCase { + func testExpandsIntegerLiteral() { + assertMacroExpansion( + "#mlx([[1, 2], [3, 4]])", + expandedSource: "MLXArray([1, 2, 3, 4], [2, 2])", + macros: testMacros + ) + } + + func testExpandsFloatLiteral() { + assertMacroExpansion( + "#mlx([[0.1, 0.2], [0.3, 0.4]])", + expandedSource: "MLXArray(converting: [0.1, 0.2, 0.3, 0.4], [2, 2])", + macros: testMacros + ) + } + + func testExpandsWithDtypeCast() { + assertMacroExpansion( + "#mlx([[1, 2], [3, 4]], dtype: .int16)", + expandedSource: "MLXArray([Int16(1), Int16(2), Int16(3), Int16(4)], [2, 2])", + macros: testMacros + ) + } + + func testExpandsIntegerLiteralWithInt64Dtype() { + assertMacroExpansion( + "#mlx([[1, 2], [3, 4]], dtype: .int64)", + expandedSource: "MLXArray([Int64(1), Int64(2), Int64(3), Int64(4)], [2, 2])", + macros: testMacros + ) + } + + func testExpandsIntegerLiteralWithUInt8Dtype() { + assertMacroExpansion( + "#mlx([[1, 2], [3, 4]], dtype: .uint8)", + expandedSource: "MLXArray([UInt8(1), UInt8(2), UInt8(3), UInt8(4)], [2, 2])", + macros: testMacros + ) + } + + func testExpandsIntegerLiteralWithFloat32Dtype() { + assertMacroExpansion( + "#mlx([[1, 2], [3, 4]], dtype: .float32)", + expandedSource: "MLXArray([Float(1), Float(2), Float(3), Float(4)], [2, 2])", + macros: testMacros + ) + } + + func testFallsBackToAsTypeForFloat64Dtype() { + assertMacroExpansion( + "#mlx([[1.0, 2.0], [3.0, 4.0]], dtype: .float64)", + expandedSource: "MLXArray(converting: [1.0, 2.0, 3.0, 4.0], [2, 2]).asType(.float64)", + macros: testMacros + ) + } + + func testFallsBackToAsTypeForDynamicDtypeExpression() { + assertMacroExpansion( + "#mlx([[1, 2], [3, 4]], dtype: dtypeValue)", + expandedSource: "MLXArray([1, 2, 3, 4], [2, 2]).asType(dtypeValue)", + macros: testMacros + ) + } + + func testExpandsThreeDimensionalIntegerLiteral() { + assertMacroExpansion( + "#mlx([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])", + expandedSource: "MLXArray([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2])", + macros: testMacros + ) + } + + func testExpandsFourDimensionalFloatLiteral() { + assertMacroExpansion( + "#mlx([[[[0.1, 0.2]], [[0.3, 0.4]]], [[[0.5, 0.6]], [[0.7, 0.8]]]])", + expandedSource: + "MLXArray(converting: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], [2, 2, 1, 2])", + macros: testMacros + ) + } + + func testExpandsMixedIntegerFloatLiteralAsFloat() { + assertMacroExpansion( + "#mlx([[1, 2.5], [3, 4.5]])", + expandedSource: "MLXArray(converting: [1, 2.5, 3, 4.5], [2, 2])", + macros: testMacros + ) + } + + func testExpandsDeepLiteralWithFloat16Dtype() { + assertMacroExpansion( + "#mlx([[[1, 2], [3, 4]]], dtype: .float16)", + expandedSource: "MLXArray([1, 2, 3, 4], [1, 2, 2]).asType(.float16)", + macros: testMacros + ) + } + + func testExpandsMixedLiteralWithInt8Dtype() { + assertMacroExpansion( + "#mlx([[1.25, 2], [3.5, 4]], dtype: .int8)", + expandedSource: "MLXArray(converting: [1.25, 2, 3.5, 4], [2, 2]).asType(.int8)", + macros: testMacros + ) + } + + func testRaggedLiteralDiagnostics() { + assertMacroExpansion( + "#mlx([[1, 2], [3]])", + expandedSource: "MLXArray([])", + diagnostics: [ + DiagnosticSpec( + message: "#mlx does not support ragged nested arrays.", line: 1, column: 6) + ], + macros: testMacros + ) + } +} From ceb476bfcfed5a7d67e5c57936086ce6bcf8ab1a Mon Sep 17 00:00:00 2001 From: Robert Sale Date: Tue, 10 Feb 2026 06:34:58 -0800 Subject: [PATCH 2/7] Single float promotes whole literal --- .../MLX/Documentation.docc/Organization/initialization.md | 7 +++++++ Tests/MLXMacrosTests/MLXLiteralMacroTests.swift | 8 ++++++++ 2 files changed, 15 insertions(+) diff --git a/Source/MLX/Documentation.docc/Organization/initialization.md b/Source/MLX/Documentation.docc/Organization/initialization.md index 4c6b6ad9..ffbffcef 100644 --- a/Source/MLX/Documentation.docc/Organization/initialization.md +++ b/Source/MLX/Documentation.docc/Organization/initialization.md @@ -155,6 +155,8 @@ let c = #mlx([[[0.1, 0.2], [0.3, 0.4]]], dtype: .float16) This is especially convenient for small constants in model code and tests. The macro requires rectangular nested arrays and numeric literals. +Mixed numeric literals are promoted to floating-point behavior: if any element is a +floating-point literal, the entire literal is treated as floating-point. When `dtype` is a known integer dtype (for example `.int16`, `.int64`, `.uint8`) or `.float32`, the expansion emits typed Swift literals directly and avoids a trailing `.asType(...)` cast. @@ -162,6 +164,11 @@ For dynamic dtype expressions, or dtypes that do not map cleanly to a Swift lite (for example `.float16`, `.bfloat16`, `.complex64`), the macro emits a base array and applies `.asType(...)`. +```swift +// promoted to floating-point because of 2.5 +let mixed = #mlx([[1, 2.5], [3, 4]]) +``` + ### Random Value Arrays See also `MLXRandom` for creating arrays with random data. diff --git a/Tests/MLXMacrosTests/MLXLiteralMacroTests.swift b/Tests/MLXMacrosTests/MLXLiteralMacroTests.swift index 22bf3617..dde4f680 100644 --- a/Tests/MLXMacrosTests/MLXLiteralMacroTests.swift +++ b/Tests/MLXMacrosTests/MLXLiteralMacroTests.swift @@ -100,6 +100,14 @@ final class MLXLiteralMacroTests: XCTestCase { ) } + func testExpandsSingleFloatElementAsFloatLiteral() { + assertMacroExpansion( + "#mlx([[1, 2], [3, 4.5]])", + expandedSource: "MLXArray(converting: [1, 2, 3, 4.5], [2, 2])", + macros: testMacros + ) + } + func testExpandsDeepLiteralWithFloat16Dtype() { assertMacroExpansion( "#mlx([[[1, 2], [3, 4]]], dtype: .float16)", From c26216fbd46793844f0ee7d18a0bdfb05089fbc3 Mon Sep 17 00:00:00 2001 From: Robert Sale Date: Tue, 10 Feb 2026 06:52:27 -0800 Subject: [PATCH 3/7] Added warning when mixing integer dtype with float literals --- Source/MLXMacrosPlugin/MLXLiteralMacro.swift | 36 ++++++++++++++----- .../MLXMacrosTests/MLXLiteralMacroTests.swift | 24 +++++++++++++ 2 files changed, 52 insertions(+), 8 deletions(-) diff --git a/Source/MLXMacrosPlugin/MLXLiteralMacro.swift b/Source/MLXMacrosPlugin/MLXLiteralMacro.swift index 4b6ec0a6..8e014531 100644 --- a/Source/MLXMacrosPlugin/MLXLiteralMacro.swift +++ b/Source/MLXMacrosPlugin/MLXLiteralMacro.swift @@ -28,10 +28,11 @@ private struct MacroError: Error {} private struct MacroMessage: DiagnosticMessage { let message: String let diagnosticID: MessageID - let severity: DiagnosticSeverity = .error + let severity: DiagnosticSeverity - init(_ message: String) { + init(_ message: String, severity: DiagnosticSeverity = .error) { self.message = message + self.severity = severity self.diagnosticID = MessageID(domain: "MLXMacros", id: "mlx_literal") } } @@ -100,10 +101,17 @@ public struct MLXLiteralMacro: ExpressionMacro { } if let dtypeExpr { - if let knownDType = parseKnownDType(dtypeExpr), - let typedExpr = makeTypedExpression(parsed: parsed, dtype: knownDType) - { - return typedExpr + if let knownDType = parseKnownDType(dtypeExpr) { + if isIntegerDType(knownDType), let floatExpr = parsed.flat.first(where: isFloat) { + diagnose( + "#mlx integer dtype with floating-point literal(s) may truncate values during conversion.", + at: Syntax(floatExpr), + severity: .warning, + in: context) + } + if let typedExpr = makeTypedExpression(parsed: parsed, dtype: knownDType) { + return typedExpr + } } return "\(baseExpr).asType(\(dtypeExpr))" } else { @@ -165,6 +173,15 @@ public struct MLXLiteralMacro: ExpressionMacro { values.map { "\(typeName)(\($0))" }.joined(separator: ", ") } + private static func isIntegerDType(_ dtype: KnownDType) -> Bool { + switch dtype { + case .uint8, .uint16, .uint32, .uint64, .int8, .int16, .int32, .int64: + return true + default: + return false + } + } + private static func parseLiteral( _ expr: ExprSyntax, context: some MacroExpansionContext ) throws -> ParsedLiteral { @@ -229,8 +246,11 @@ public struct MLXLiteralMacro: ExpressionMacro { } private static func diagnose( - _ message: String, at node: Syntax, in context: some MacroExpansionContext + _ message: String, + at node: Syntax, + severity: DiagnosticSeverity = .error, + in context: some MacroExpansionContext ) { - context.diagnose(Diagnostic(node: node, message: MacroMessage(message))) + context.diagnose(Diagnostic(node: node, message: MacroMessage(message, severity: severity))) } } diff --git a/Tests/MLXMacrosTests/MLXLiteralMacroTests.swift b/Tests/MLXMacrosTests/MLXLiteralMacroTests.swift index dde4f680..b657de6e 100644 --- a/Tests/MLXMacrosTests/MLXLiteralMacroTests.swift +++ b/Tests/MLXMacrosTests/MLXLiteralMacroTests.swift @@ -120,6 +120,30 @@ final class MLXLiteralMacroTests: XCTestCase { assertMacroExpansion( "#mlx([[1.25, 2], [3.5, 4]], dtype: .int8)", expandedSource: "MLXArray(converting: [1.25, 2, 3.5, 4], [2, 2]).asType(.int8)", + diagnostics: [ + DiagnosticSpec( + message: + "#mlx integer dtype with floating-point literal(s) may truncate values during conversion.", + line: 1, + column: 8, + severity: .warning) + ], + macros: testMacros + ) + } + + func testWarnsOnIntegerDtypeWithFloatLiteral() { + assertMacroExpansion( + "#mlx([[1, 2.5], [3, 4]], dtype: .int16)", + expandedSource: "MLXArray(converting: [1, 2.5, 3, 4], [2, 2]).asType(.int16)", + diagnostics: [ + DiagnosticSpec( + message: + "#mlx integer dtype with floating-point literal(s) may truncate values during conversion.", + line: 1, + column: 11, + severity: .warning) + ], macros: testMacros ) } From 42b31d1c46c8b6a471595dd0125a025ff3d873e6 Mon Sep 17 00:00:00 2001 From: Robert Sale Date: Tue, 10 Feb 2026 07:29:45 -0800 Subject: [PATCH 4/7] Added support for true/false 0/1 and clarifying code comments --- .../Organization/initialization.md | 5 + Source/MLXMacrosPlugin/MLXLiteralMacro.swift | 132 ++++++++++++++++-- .../MLXMacrosTests/MLXLiteralMacroTests.swift | 44 ++++++ 3 files changed, 169 insertions(+), 12 deletions(-) diff --git a/Source/MLX/Documentation.docc/Organization/initialization.md b/Source/MLX/Documentation.docc/Organization/initialization.md index ffbffcef..b79b95d3 100644 --- a/Source/MLX/Documentation.docc/Organization/initialization.md +++ b/Source/MLX/Documentation.docc/Organization/initialization.md @@ -151,6 +151,8 @@ import MLX let a = #mlx([[1, 2], [3, 4]]) let b = #mlx([[1, 2], [3, 4]], dtype: .int16) let c = #mlx([[[0.1, 0.2], [0.3, 0.4]]], dtype: .float16) +let d = #mlx([[true, false], [false, true]]) +let e = #mlx([[0, 1], [1, 0]], dtype: .bool) ``` This is especially convenient for small constants in model code and tests. @@ -169,6 +171,9 @@ For dynamic dtype expressions, or dtypes that do not map cleanly to a Swift lite let mixed = #mlx([[1, 2.5], [3, 4]]) ``` +Boolean literals are supported directly (`true` / `false`). For `dtype: .bool`, +integer literals are accepted only when each element is `0` or `1`. + ### Random Value Arrays See also `MLXRandom` for creating arrays with random data. diff --git a/Source/MLXMacrosPlugin/MLXLiteralMacro.swift b/Source/MLXMacrosPlugin/MLXLiteralMacro.swift index 8e014531..a367ab26 100644 --- a/Source/MLXMacrosPlugin/MLXLiteralMacro.swift +++ b/Source/MLXMacrosPlugin/MLXLiteralMacro.swift @@ -6,14 +6,23 @@ import SwiftSyntaxBuilder import SwiftSyntaxMacros private enum ScalarKind { + case bool case int case float - static func merge(_ lhs: ScalarKind, _ rhs: ScalarKind) -> ScalarKind { - if lhs == .float || rhs == .float { + static func merge(_ lhs: ScalarKind, _ rhs: ScalarKind) -> ScalarKind? { + switch (lhs, rhs) { + case (.bool, .bool): + return .bool + case (.int, .int): + return .int + case (.float, .float): + return .float + case (.int, .float), (.float, .int): return .float + case (.bool, .int), (.int, .bool), (.bool, .float), (.float, .bool): + return nil } - return .int } } @@ -92,8 +101,13 @@ public struct MLXLiteralMacro: ExpressionMacro { let flatSource = parsed.flat.map { $0.description }.joined(separator: ", ") let shapeSource = parsed.shape.map(String.init).joined(separator: ", ") + // Default lowering path: + // - integer-only literals use MLXArray([Int...], shape) + // - any float literal promotes the whole literal to converting:[Double...] let baseExpr: ExprSyntax = switch parsed.kind { + case .bool: + "MLXArray([\(raw: flatSource)], [\(raw: shapeSource)])" case .int: "MLXArray([\(raw: flatSource)], [\(raw: shapeSource)])" case .float: @@ -102,6 +116,37 @@ public struct MLXLiteralMacro: ExpressionMacro { if let dtypeExpr { if let knownDType = parseKnownDType(dtypeExpr) { + if knownDType == .bool { + switch parsed.kind { + case .bool: + return baseExpr + case .int: + var boolValues: [String] = [] + boolValues.reserveCapacity(parsed.flat.count) + for element in parsed.flat { + guard let value = integerLiteralValue(element), value == 0 || value == 1 + else { + diagnose( + "#mlx dtype .bool only supports integer literals 0 or 1.", + at: Syntax(element), + in: context) + return "MLXArray([])" + } + boolValues.append(value == 1 ? "true" : "false") + } + let boolSource = boolValues.joined(separator: ", ") + return "MLXArray([\(raw: boolSource)], [\(raw: shapeSource)])" + case .float: + diagnose( + "#mlx dtype .bool only supports true/false literals or integer 0/1.", + at: Syntax(dtypeExpr), + in: context) + return "MLXArray([])" + } + } + // Keep explicit integer dtypes permissive but visible: + // if callers write float literals with an integer dtype, emit a warning + // since runtime conversion may truncate. if isIntegerDType(knownDType), let floatExpr = parsed.flat.first(where: isFloat) { diagnose( "#mlx integer dtype with floating-point literal(s) may truncate values during conversion.", @@ -109,10 +154,14 @@ public struct MLXLiteralMacro: ExpressionMacro { severity: .warning, in: context) } + // Fast path for dtypes we can materialize directly as Swift literals. + // This avoids emitting a trailing `.asType(...)` cast op. if let typedExpr = makeTypedExpression(parsed: parsed, dtype: knownDType) { return typedExpr } } + // Fallback for dynamic dtype expressions and dtypes that do not map cleanly + // to a concrete Swift literal representation. return "\(baseExpr).asType(\(dtypeExpr))" } else { return baseExpr @@ -132,6 +181,8 @@ public struct MLXLiteralMacro: ExpressionMacro { let typedFlat: String switch dtype { + case .bool: + return nil case .int8: guard parsed.kind == .int else { return nil } typedFlat = wrap(parsed.flat, with: "Int8") @@ -157,12 +208,12 @@ public struct MLXLiteralMacro: ExpressionMacro { guard parsed.kind == .int else { return nil } typedFlat = wrap(parsed.flat, with: "UInt64") case .float32: - if parsed.kind == .int { - typedFlat = wrap(parsed.flat, with: "Float") - } else { - typedFlat = wrap(parsed.flat, with: "Float") - } - case .bool, .float16, .bfloat16, .complex64, .float64: + // Float32 has a stable, direct Swift literal representation. + // Emit typed elements instead of base+cast for lower graph overhead. + typedFlat = wrap(parsed.flat, with: "Float") + case .float16, .bfloat16, .complex64, .float64: + // These currently rely on base+cast to keep expansion predictable + // across targets and avoid lossy/ambiguous literal synthesis. return nil } @@ -187,6 +238,7 @@ public struct MLXLiteralMacro: ExpressionMacro { ) throws -> ParsedLiteral { if let arrayExpr = expr.as(ArrayExprSyntax.self) { if arrayExpr.elements.isEmpty { + // Keep empty arrays legal and representable at compile time. return ParsedLiteral(flat: [], shape: [0], kind: .int) } @@ -199,19 +251,36 @@ public struct MLXLiteralMacro: ExpressionMacro { let firstShape = children[0].shape if children.dropFirst().contains(where: { $0.shape != firstShape }) { + // MLXArray construction here assumes rectangular nested literals. + // Ragged arrays are rejected early with a macro diagnostic. diagnose( "#mlx does not support ragged nested arrays.", at: Syntax(expr), in: context) throw MacroError() } - let kind = children.dropFirst().reduce(children[0].kind) { - ScalarKind.merge($0, $1.kind) + guard + let kind = children.dropFirst().reduce( + Optional(children[0].kind), + { + partial, next in + guard let partial else { return nil } + return ScalarKind.merge(partial, next.kind) + }) + else { + diagnose( + "#mlx does not support mixing boolean and numeric literals in the same array.", + at: Syntax(expr), + in: context) + throw MacroError() } return ParsedLiteral( flat: children.flatMap(\.flat), shape: [children.count] + firstShape, kind: kind) } + if isBool(expr) { + return ParsedLiteral(flat: [expr], shape: [], kind: .bool) + } if isInteger(expr) { return ParsedLiteral(flat: [expr], shape: [], kind: .int) } @@ -220,16 +289,22 @@ public struct MLXLiteralMacro: ExpressionMacro { } diagnose( - "#mlx only supports integer and floating-point literals.", at: Syntax(expr), in: context + "#mlx only supports boolean, integer, and floating-point literals.", at: Syntax(expr), + in: context ) throw MacroError() } + private static func isBool(_ expr: ExprSyntax) -> Bool { + expr.as(BooleanLiteralExprSyntax.self) != nil + } + private static func isInteger(_ expr: ExprSyntax) -> Bool { if expr.as(IntegerLiteralExprSyntax.self) != nil { return true } if let prefix = expr.as(PrefixOperatorExprSyntax.self) { + // Accept signed integer literals like -3 / +7. return isInteger(prefix.expression) } return false @@ -240,11 +315,44 @@ public struct MLXLiteralMacro: ExpressionMacro { return true } if let prefix = expr.as(PrefixOperatorExprSyntax.self) { + // Accept signed float literals like -3.5 / +1.0e-3. return isFloat(prefix.expression) } return false } + private static func integerLiteralValue(_ expr: ExprSyntax) -> Int? { + if let literal = expr.as(IntegerLiteralExprSyntax.self) { + return parseIntegerToken(literal.literal.text) + } + if let prefix = expr.as(PrefixOperatorExprSyntax.self) { + guard let value = integerLiteralValue(prefix.expression) else { return nil } + switch prefix.operator.text { + case "+": + return value + case "-": + return -value + default: + return nil + } + } + return nil + } + + private static func parseIntegerToken(_ token: String) -> Int? { + let text = String(token.filter { $0 != "_" }) + if text.hasPrefix("0x") || text.hasPrefix("0X") { + return Int(text.dropFirst(2), radix: 16) + } + if text.hasPrefix("0b") || text.hasPrefix("0B") { + return Int(text.dropFirst(2), radix: 2) + } + if text.hasPrefix("0o") || text.hasPrefix("0O") { + return Int(text.dropFirst(2), radix: 8) + } + return Int(text) + } + private static func diagnose( _ message: String, at node: Syntax, diff --git a/Tests/MLXMacrosTests/MLXLiteralMacroTests.swift b/Tests/MLXMacrosTests/MLXLiteralMacroTests.swift index b657de6e..efb382a7 100644 --- a/Tests/MLXMacrosTests/MLXLiteralMacroTests.swift +++ b/Tests/MLXMacrosTests/MLXLiteralMacroTests.swift @@ -27,6 +27,14 @@ final class MLXLiteralMacroTests: XCTestCase { ) } + func testExpandsBooleanLiteral() { + assertMacroExpansion( + "#mlx([[true, false], [false, true]])", + expandedSource: "MLXArray([true, false, false, true], [2, 2])", + macros: testMacros + ) + } + func testExpandsWithDtypeCast() { assertMacroExpansion( "#mlx([[1, 2], [3, 4]], dtype: .int16)", @@ -67,6 +75,14 @@ final class MLXLiteralMacroTests: XCTestCase { ) } + func testExpandsBoolDtypeFromZeroOneLiterals() { + assertMacroExpansion( + "#mlx([[0, 1], [1, 0]], dtype: .bool)", + expandedSource: "MLXArray([false, true, true, false], [2, 2])", + macros: testMacros + ) + } + func testFallsBackToAsTypeForDynamicDtypeExpression() { assertMacroExpansion( "#mlx([[1, 2], [3, 4]], dtype: dtypeValue)", @@ -148,6 +164,34 @@ final class MLXLiteralMacroTests: XCTestCase { ) } + func testRejectsBoolDtypeWithOutOfRangeIntegerLiterals() { + assertMacroExpansion( + "#mlx([[0, 2], [1, 0]], dtype: .bool)", + expandedSource: "MLXArray([])", + diagnostics: [ + DiagnosticSpec( + message: "#mlx dtype .bool only supports integer literals 0 or 1.", + line: 1, + column: 11) + ], + macros: testMacros + ) + } + + func testRejectsBoolDtypeWithFloatLiterals() { + assertMacroExpansion( + "#mlx([[0.0, 1.0]], dtype: .bool)", + expandedSource: "MLXArray([])", + diagnostics: [ + DiagnosticSpec( + message: "#mlx dtype .bool only supports true/false literals or integer 0/1.", + line: 1, + column: 27) + ], + macros: testMacros + ) + } + func testRaggedLiteralDiagnostics() { assertMacroExpansion( "#mlx([[1, 2], [3]])", From 7708504ad74dda2c960b7494eb664fec9f633325 Mon Sep 17 00:00:00 2001 From: Robert Sale Date: Thu, 12 Feb 2026 11:38:01 -0800 Subject: [PATCH 5/7] Renamed macro to #MLXArray --- .../Organization/initialization.md | 20 ++++--- Source/MLX/MLXMacros.swift | 14 ++--- Source/MLXMacrosPlugin/MLXLiteralMacro.swift | 18 +++--- .../MLXMacrosTests/MLXLiteralMacroTests.swift | 60 +++++++++---------- 4 files changed, 58 insertions(+), 54 deletions(-) diff --git a/Source/MLX/Documentation.docc/Organization/initialization.md b/Source/MLX/Documentation.docc/Organization/initialization.md index b79b95d3..2eb1b49c 100644 --- a/Source/MLX/Documentation.docc/Organization/initialization.md +++ b/Source/MLX/Documentation.docc/Organization/initialization.md @@ -143,20 +143,21 @@ let v1 = MLXArray(0 ..< 12, [3, 4]) ### Macro Literals -You can also create arrays from nested literals with the `#mlx` expression macro: +You can also create arrays from nested literals with the `#MLXArray` expression macro: ```swift import MLX -let a = #mlx([[1, 2], [3, 4]]) -let b = #mlx([[1, 2], [3, 4]], dtype: .int16) -let c = #mlx([[[0.1, 0.2], [0.3, 0.4]]], dtype: .float16) -let d = #mlx([[true, false], [false, true]]) -let e = #mlx([[0, 1], [1, 0]], dtype: .bool) +let a = #MLXArray([[1, 2], [3, 4]]) +let b = #MLXArray([[1, 2], [3, 4]], dtype: .int16) +let c = #MLXArray([[[0.1, 0.2], [0.3, 0.4]]], dtype: .float16) +let d = #MLXArray([[true, false], [false, true]]) +let e = #MLXArray([[0, 1], [1, 0]], dtype: .bool) ``` This is especially convenient for small constants in model code and tests. -The macro requires rectangular nested arrays and numeric literals. +The macro requires rectangular nested arrays with boolean, integer, or floating-point literals. +Mixing boolean and numeric literals in the same nested literal is not supported. Mixed numeric literals are promoted to floating-point behavior: if any element is a floating-point literal, the entire literal is treated as floating-point. @@ -166,9 +167,12 @@ For dynamic dtype expressions, or dtypes that do not map cleanly to a Swift lite (for example `.float16`, `.bfloat16`, `.complex64`), the macro emits a base array and applies `.asType(...)`. +If `dtype` is an integer type and the literal contains floating-point values, +the macro still allows expansion but emits a warning because conversion may truncate. + ```swift // promoted to floating-point because of 2.5 -let mixed = #mlx([[1, 2.5], [3, 4]]) +let mixed = #MLXArray([[1, 2.5], [3, 4]]) ``` Boolean literals are supported directly (`true` / `false`). For `dtype: .bool`, diff --git a/Source/MLX/MLXMacros.swift b/Source/MLX/MLXMacros.swift index be72cab9..3c879fb9 100644 --- a/Source/MLX/MLXMacros.swift +++ b/Source/MLX/MLXMacros.swift @@ -1,21 +1,21 @@ // Copyright © 2026 Apple Inc. -/// Construct an ``MLXArray`` from a nested numeric literal. +/// Construct an ``MLXArray`` from a nested literal. /// /// Examples: /// /// ```swift -/// let a = #mlx([[1, 2, 3], [4, 5, 6]]) -/// let b = #mlx([[1, 2, 3], [4, 5, 6]], dtype: .int16) -/// let c = #mlx([[0.1, 0.2], [0.3, 0.4]], dtype: .float16) +/// let a = #MLXArray([[1, 2, 3], [4, 5, 6]]) +/// let b = #MLXArray([[1, 2, 3], [4, 5, 6]], dtype: .int16) +/// let c = #MLXArray([[0.1, 0.2], [0.3, 0.4]], dtype: .float16) /// ``` @freestanding(expression) -public macro mlx(_ literal: Any) -> MLXArray = +public macro MLXArray(_ literal: Any) -> MLXArray = #externalMacro( module: "MLXMacrosPlugin", type: "MLXLiteralMacro") -/// Construct an ``MLXArray`` from a nested numeric literal and cast to `dtype`. +/// Construct an ``MLXArray`` from a nested literal and cast to `dtype`. @freestanding(expression) -public macro mlx(_ literal: Any, dtype: DType) -> MLXArray = +public macro MLXArray(_ literal: Any, dtype: DType) -> MLXArray = #externalMacro( module: "MLXMacrosPlugin", type: "MLXLiteralMacro") diff --git a/Source/MLXMacrosPlugin/MLXLiteralMacro.swift b/Source/MLXMacrosPlugin/MLXLiteralMacro.swift index a367ab26..23d6aec1 100644 --- a/Source/MLXMacrosPlugin/MLXLiteralMacro.swift +++ b/Source/MLXMacrosPlugin/MLXLiteralMacro.swift @@ -70,7 +70,7 @@ public struct MLXLiteralMacro: ExpressionMacro { ) throws -> ExprSyntax { let args = Array(node.arguments) guard let literalArg = args.first else { - diagnose("#mlx requires a nested numeric array literal.", at: Syntax(node), in: context) + diagnose("#MLXArray requires a nested numeric array literal.", at: Syntax(node), in: context) return "MLXArray([])" } @@ -80,14 +80,14 @@ public struct MLXLiteralMacro: ExpressionMacro { } else if args.count == 2 { guard args[1].label?.text == "dtype" else { diagnose( - "#mlx second argument must be labeled 'dtype:'.", + "#MLXArray second argument must be labeled 'dtype:'.", at: Syntax(args[1]), in: context) return "MLXArray([])" } dtypeExpr = args[1].expression } else { diagnose( - "#mlx accepts one literal argument and optional dtype:.", at: Syntax(node), + "#MLXArray accepts one literal argument and optional dtype:.", at: Syntax(node), in: context) return "MLXArray([])" } @@ -127,7 +127,7 @@ public struct MLXLiteralMacro: ExpressionMacro { guard let value = integerLiteralValue(element), value == 0 || value == 1 else { diagnose( - "#mlx dtype .bool only supports integer literals 0 or 1.", + "#MLXArray dtype .bool only supports integer literals 0 or 1.", at: Syntax(element), in: context) return "MLXArray([])" @@ -138,7 +138,7 @@ public struct MLXLiteralMacro: ExpressionMacro { return "MLXArray([\(raw: boolSource)], [\(raw: shapeSource)])" case .float: diagnose( - "#mlx dtype .bool only supports true/false literals or integer 0/1.", + "#MLXArray dtype .bool only supports true/false literals or integer 0/1.", at: Syntax(dtypeExpr), in: context) return "MLXArray([])" @@ -149,7 +149,7 @@ public struct MLXLiteralMacro: ExpressionMacro { // since runtime conversion may truncate. if isIntegerDType(knownDType), let floatExpr = parsed.flat.first(where: isFloat) { diagnose( - "#mlx integer dtype with floating-point literal(s) may truncate values during conversion.", + "#MLXArray integer dtype with floating-point literal(s) may truncate values during conversion.", at: Syntax(floatExpr), severity: .warning, in: context) @@ -254,7 +254,7 @@ public struct MLXLiteralMacro: ExpressionMacro { // MLXArray construction here assumes rectangular nested literals. // Ragged arrays are rejected early with a macro diagnostic. diagnose( - "#mlx does not support ragged nested arrays.", at: Syntax(expr), in: context) + "#MLXArray does not support ragged nested arrays.", at: Syntax(expr), in: context) throw MacroError() } @@ -268,7 +268,7 @@ public struct MLXLiteralMacro: ExpressionMacro { }) else { diagnose( - "#mlx does not support mixing boolean and numeric literals in the same array.", + "#MLXArray does not support mixing boolean and numeric literals in the same array.", at: Syntax(expr), in: context) throw MacroError() @@ -289,7 +289,7 @@ public struct MLXLiteralMacro: ExpressionMacro { } diagnose( - "#mlx only supports boolean, integer, and floating-point literals.", at: Syntax(expr), + "#MLXArray only supports boolean, integer, and floating-point literals.", at: Syntax(expr), in: context ) throw MacroError() diff --git a/Tests/MLXMacrosTests/MLXLiteralMacroTests.swift b/Tests/MLXMacrosTests/MLXLiteralMacroTests.swift index efb382a7..b587d521 100644 --- a/Tests/MLXMacrosTests/MLXLiteralMacroTests.swift +++ b/Tests/MLXMacrosTests/MLXLiteralMacroTests.swift @@ -7,13 +7,13 @@ import XCTest @testable import MLXMacrosPlugin private let testMacros: [String: Macro.Type] = [ - "mlx": MLXLiteralMacro.self + "MLXArray": MLXLiteralMacro.self, ] final class MLXLiteralMacroTests: XCTestCase { func testExpandsIntegerLiteral() { assertMacroExpansion( - "#mlx([[1, 2], [3, 4]])", + "#MLXArray([[1, 2], [3, 4]])", expandedSource: "MLXArray([1, 2, 3, 4], [2, 2])", macros: testMacros ) @@ -21,7 +21,7 @@ final class MLXLiteralMacroTests: XCTestCase { func testExpandsFloatLiteral() { assertMacroExpansion( - "#mlx([[0.1, 0.2], [0.3, 0.4]])", + "#MLXArray([[0.1, 0.2], [0.3, 0.4]])", expandedSource: "MLXArray(converting: [0.1, 0.2, 0.3, 0.4], [2, 2])", macros: testMacros ) @@ -29,7 +29,7 @@ final class MLXLiteralMacroTests: XCTestCase { func testExpandsBooleanLiteral() { assertMacroExpansion( - "#mlx([[true, false], [false, true]])", + "#MLXArray([[true, false], [false, true]])", expandedSource: "MLXArray([true, false, false, true], [2, 2])", macros: testMacros ) @@ -37,7 +37,7 @@ final class MLXLiteralMacroTests: XCTestCase { func testExpandsWithDtypeCast() { assertMacroExpansion( - "#mlx([[1, 2], [3, 4]], dtype: .int16)", + "#MLXArray([[1, 2], [3, 4]], dtype: .int16)", expandedSource: "MLXArray([Int16(1), Int16(2), Int16(3), Int16(4)], [2, 2])", macros: testMacros ) @@ -45,7 +45,7 @@ final class MLXLiteralMacroTests: XCTestCase { func testExpandsIntegerLiteralWithInt64Dtype() { assertMacroExpansion( - "#mlx([[1, 2], [3, 4]], dtype: .int64)", + "#MLXArray([[1, 2], [3, 4]], dtype: .int64)", expandedSource: "MLXArray([Int64(1), Int64(2), Int64(3), Int64(4)], [2, 2])", macros: testMacros ) @@ -53,7 +53,7 @@ final class MLXLiteralMacroTests: XCTestCase { func testExpandsIntegerLiteralWithUInt8Dtype() { assertMacroExpansion( - "#mlx([[1, 2], [3, 4]], dtype: .uint8)", + "#MLXArray([[1, 2], [3, 4]], dtype: .uint8)", expandedSource: "MLXArray([UInt8(1), UInt8(2), UInt8(3), UInt8(4)], [2, 2])", macros: testMacros ) @@ -61,7 +61,7 @@ final class MLXLiteralMacroTests: XCTestCase { func testExpandsIntegerLiteralWithFloat32Dtype() { assertMacroExpansion( - "#mlx([[1, 2], [3, 4]], dtype: .float32)", + "#MLXArray([[1, 2], [3, 4]], dtype: .float32)", expandedSource: "MLXArray([Float(1), Float(2), Float(3), Float(4)], [2, 2])", macros: testMacros ) @@ -69,7 +69,7 @@ final class MLXLiteralMacroTests: XCTestCase { func testFallsBackToAsTypeForFloat64Dtype() { assertMacroExpansion( - "#mlx([[1.0, 2.0], [3.0, 4.0]], dtype: .float64)", + "#MLXArray([[1.0, 2.0], [3.0, 4.0]], dtype: .float64)", expandedSource: "MLXArray(converting: [1.0, 2.0, 3.0, 4.0], [2, 2]).asType(.float64)", macros: testMacros ) @@ -77,7 +77,7 @@ final class MLXLiteralMacroTests: XCTestCase { func testExpandsBoolDtypeFromZeroOneLiterals() { assertMacroExpansion( - "#mlx([[0, 1], [1, 0]], dtype: .bool)", + "#MLXArray([[0, 1], [1, 0]], dtype: .bool)", expandedSource: "MLXArray([false, true, true, false], [2, 2])", macros: testMacros ) @@ -85,7 +85,7 @@ final class MLXLiteralMacroTests: XCTestCase { func testFallsBackToAsTypeForDynamicDtypeExpression() { assertMacroExpansion( - "#mlx([[1, 2], [3, 4]], dtype: dtypeValue)", + "#MLXArray([[1, 2], [3, 4]], dtype: dtypeValue)", expandedSource: "MLXArray([1, 2, 3, 4], [2, 2]).asType(dtypeValue)", macros: testMacros ) @@ -93,7 +93,7 @@ final class MLXLiteralMacroTests: XCTestCase { func testExpandsThreeDimensionalIntegerLiteral() { assertMacroExpansion( - "#mlx([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])", + "#MLXArray([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])", expandedSource: "MLXArray([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2])", macros: testMacros ) @@ -101,7 +101,7 @@ final class MLXLiteralMacroTests: XCTestCase { func testExpandsFourDimensionalFloatLiteral() { assertMacroExpansion( - "#mlx([[[[0.1, 0.2]], [[0.3, 0.4]]], [[[0.5, 0.6]], [[0.7, 0.8]]]])", + "#MLXArray([[[[0.1, 0.2]], [[0.3, 0.4]]], [[[0.5, 0.6]], [[0.7, 0.8]]]])", expandedSource: "MLXArray(converting: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], [2, 2, 1, 2])", macros: testMacros @@ -110,7 +110,7 @@ final class MLXLiteralMacroTests: XCTestCase { func testExpandsMixedIntegerFloatLiteralAsFloat() { assertMacroExpansion( - "#mlx([[1, 2.5], [3, 4.5]])", + "#MLXArray([[1, 2.5], [3, 4.5]])", expandedSource: "MLXArray(converting: [1, 2.5, 3, 4.5], [2, 2])", macros: testMacros ) @@ -118,7 +118,7 @@ final class MLXLiteralMacroTests: XCTestCase { func testExpandsSingleFloatElementAsFloatLiteral() { assertMacroExpansion( - "#mlx([[1, 2], [3, 4.5]])", + "#MLXArray([[1, 2], [3, 4.5]])", expandedSource: "MLXArray(converting: [1, 2, 3, 4.5], [2, 2])", macros: testMacros ) @@ -126,7 +126,7 @@ final class MLXLiteralMacroTests: XCTestCase { func testExpandsDeepLiteralWithFloat16Dtype() { assertMacroExpansion( - "#mlx([[[1, 2], [3, 4]]], dtype: .float16)", + "#MLXArray([[[1, 2], [3, 4]]], dtype: .float16)", expandedSource: "MLXArray([1, 2, 3, 4], [1, 2, 2]).asType(.float16)", macros: testMacros ) @@ -134,14 +134,14 @@ final class MLXLiteralMacroTests: XCTestCase { func testExpandsMixedLiteralWithInt8Dtype() { assertMacroExpansion( - "#mlx([[1.25, 2], [3.5, 4]], dtype: .int8)", + "#MLXArray([[1.25, 2], [3.5, 4]], dtype: .int8)", expandedSource: "MLXArray(converting: [1.25, 2, 3.5, 4], [2, 2]).asType(.int8)", diagnostics: [ DiagnosticSpec( message: - "#mlx integer dtype with floating-point literal(s) may truncate values during conversion.", + "#MLXArray integer dtype with floating-point literal(s) may truncate values during conversion.", line: 1, - column: 8, + column: 13, severity: .warning) ], macros: testMacros @@ -150,14 +150,14 @@ final class MLXLiteralMacroTests: XCTestCase { func testWarnsOnIntegerDtypeWithFloatLiteral() { assertMacroExpansion( - "#mlx([[1, 2.5], [3, 4]], dtype: .int16)", + "#MLXArray([[1, 2.5], [3, 4]], dtype: .int16)", expandedSource: "MLXArray(converting: [1, 2.5, 3, 4], [2, 2]).asType(.int16)", diagnostics: [ DiagnosticSpec( message: - "#mlx integer dtype with floating-point literal(s) may truncate values during conversion.", + "#MLXArray integer dtype with floating-point literal(s) may truncate values during conversion.", line: 1, - column: 11, + column: 16, severity: .warning) ], macros: testMacros @@ -166,13 +166,13 @@ final class MLXLiteralMacroTests: XCTestCase { func testRejectsBoolDtypeWithOutOfRangeIntegerLiterals() { assertMacroExpansion( - "#mlx([[0, 2], [1, 0]], dtype: .bool)", + "#MLXArray([[0, 2], [1, 0]], dtype: .bool)", expandedSource: "MLXArray([])", diagnostics: [ DiagnosticSpec( - message: "#mlx dtype .bool only supports integer literals 0 or 1.", + message: "#MLXArray dtype .bool only supports integer literals 0 or 1.", line: 1, - column: 11) + column: 16) ], macros: testMacros ) @@ -180,13 +180,13 @@ final class MLXLiteralMacroTests: XCTestCase { func testRejectsBoolDtypeWithFloatLiterals() { assertMacroExpansion( - "#mlx([[0.0, 1.0]], dtype: .bool)", + "#MLXArray([[0.0, 1.0]], dtype: .bool)", expandedSource: "MLXArray([])", diagnostics: [ DiagnosticSpec( - message: "#mlx dtype .bool only supports true/false literals or integer 0/1.", + message: "#MLXArray dtype .bool only supports true/false literals or integer 0/1.", line: 1, - column: 27) + column: 32) ], macros: testMacros ) @@ -194,11 +194,11 @@ final class MLXLiteralMacroTests: XCTestCase { func testRaggedLiteralDiagnostics() { assertMacroExpansion( - "#mlx([[1, 2], [3]])", + "#MLXArray([[1, 2], [3]])", expandedSource: "MLXArray([])", diagnostics: [ DiagnosticSpec( - message: "#mlx does not support ragged nested arrays.", line: 1, column: 6) + message: "#MLXArray does not support ragged nested arrays.", line: 1, column: 11) ], macros: testMacros ) From 9e64c68546936ce9ff8d2d4431708cef1e7d537f Mon Sep 17 00:00:00 2001 From: Robert Sale Date: Mon, 18 May 2026 09:13:41 -0700 Subject: [PATCH 6/7] Ran code formatter --- Source/MLXMacrosPlugin/MLXLiteralMacro.swift | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/Source/MLXMacrosPlugin/MLXLiteralMacro.swift b/Source/MLXMacrosPlugin/MLXLiteralMacro.swift index 23d6aec1..97ba801b 100644 --- a/Source/MLXMacrosPlugin/MLXLiteralMacro.swift +++ b/Source/MLXMacrosPlugin/MLXLiteralMacro.swift @@ -70,7 +70,8 @@ public struct MLXLiteralMacro: ExpressionMacro { ) throws -> ExprSyntax { let args = Array(node.arguments) guard let literalArg = args.first else { - diagnose("#MLXArray requires a nested numeric array literal.", at: Syntax(node), in: context) + diagnose( + "#MLXArray requires a nested numeric array literal.", at: Syntax(node), in: context) return "MLXArray([])" } @@ -254,7 +255,8 @@ public struct MLXLiteralMacro: ExpressionMacro { // MLXArray construction here assumes rectangular nested literals. // Ragged arrays are rejected early with a macro diagnostic. diagnose( - "#MLXArray does not support ragged nested arrays.", at: Syntax(expr), in: context) + "#MLXArray does not support ragged nested arrays.", at: Syntax(expr), + in: context) throw MacroError() } @@ -289,7 +291,8 @@ public struct MLXLiteralMacro: ExpressionMacro { } diagnose( - "#MLXArray only supports boolean, integer, and floating-point literals.", at: Syntax(expr), + "#MLXArray only supports boolean, integer, and floating-point literals.", + at: Syntax(expr), in: context ) throw MacroError() From 427beee57e31af95ad32efe8a2be04f546eac61a Mon Sep 17 00:00:00 2001 From: Robert Sale Date: Mon, 18 May 2026 09:17:44 -0700 Subject: [PATCH 7/7] Ran code formatter in Tests --- Tests/MLXMacrosTests/MLXLiteralMacroTests.swift | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/Tests/MLXMacrosTests/MLXLiteralMacroTests.swift b/Tests/MLXMacrosTests/MLXLiteralMacroTests.swift index b587d521..51c13080 100644 --- a/Tests/MLXMacrosTests/MLXLiteralMacroTests.swift +++ b/Tests/MLXMacrosTests/MLXLiteralMacroTests.swift @@ -7,7 +7,7 @@ import XCTest @testable import MLXMacrosPlugin private let testMacros: [String: Macro.Type] = [ - "MLXArray": MLXLiteralMacro.self, + "MLXArray": MLXLiteralMacro.self ] final class MLXLiteralMacroTests: XCTestCase { @@ -184,7 +184,8 @@ final class MLXLiteralMacroTests: XCTestCase { expandedSource: "MLXArray([])", diagnostics: [ DiagnosticSpec( - message: "#MLXArray dtype .bool only supports true/false literals or integer 0/1.", + message: + "#MLXArray dtype .bool only supports true/false literals or integer 0/1.", line: 1, column: 32) ], @@ -198,7 +199,8 @@ final class MLXLiteralMacroTests: XCTestCase { expandedSource: "MLXArray([])", diagnostics: [ DiagnosticSpec( - message: "#MLXArray does not support ragged nested arrays.", line: 1, column: 11) + message: "#MLXArray does not support ragged nested arrays.", line: 1, column: 11 + ) ], macros: testMacros )