This document outlines the key differences between PyTorch and MLX Swift that are critical for porting neural network models. These differences span tensor formats, weight layouts, APIs, and architectural patterns.
| Framework | Format | Description |
|---|---|---|
| PyTorch | (B, C, T) | Batch, Channels, Time |
| MLX Swift | (B, T, C) | Batch, Time, Channels |
Impact: Requires transposes at module boundaries.
```swift
// PyTorch format to MLX Swift format
let mlxInput = pytorchInput.transposed(0, 2, 1)   // (B, C, T) → (B, T, C)

// MLX Swift format to PyTorch format
let pytorchOutput = mlxOutput.transposed(0, 2, 1) // (B, T, C) → (B, C, T)
```

The same convention applies to 2D (image) data:

| Framework | Format |
|---|---|
| PyTorch | (B, C, H, W) |
| MLX Swift | (B, H, W, C) |
Conv1d weight layout:

| Framework | Shape |
|---|---|
| PyTorch | (Out_Channels, In_Channels, Kernel_Size) |
| MLX Swift | (Out_Channels, Kernel_Size, In_Channels) |
Conversion:
```python
# Python conversion script
mlx_weight = pytorch_weight.transpose(0, 2, 1)
```

```swift
// Swift conversion (if needed at runtime)
let mlxWeight = pytorchWeight.transposed(axes: [0, 2, 1])
```

ConvTranspose1d weight layout:

| Framework | Shape |
|---|---|
| PyTorch | (In_Channels, Out_Channels, Kernel_Size) |
| MLX Swift | (Out_Channels, Kernel_Size, In_Channels) |
Conversion:
```python
mlx_weight = pytorch_weight.transpose(1, 2, 0)
```

Conv2d weight layout:

| Framework | Shape |
|---|---|
| PyTorch | (Out_C, In_C, H, W) |
| MLX Swift | (Out_C, H, W, In_C) |
Conversion:
```python
mlx_weight = pytorch_weight.transpose(0, 2, 3, 1)
```

Both frameworks use the same shapes for these layers, so no conversion is needed:

- Linear: `(Out_Features, In_Features)`
- Embedding: `(Num_Embeddings, Embedding_Dim)`
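Putting the layout rules above together, here is a minimal load-time sketch in Swift. It assumes the checkpoint has already been read into a flat `[String: MLXArray]` dictionary, and that every 3-D/4-D `.weight` entry is a regular Conv1d/Conv2d kernel in PyTorch layout; `convertConvWeights` is a hypothetical helper name, and ConvTranspose1d kernels (which are also 3-D) would need to be identified by key name, since shape alone cannot distinguish (In, Out, K) from (Out, In, K).

```swift
import MLX

// Hypothetical helper: transpose PyTorch conv kernels into MLX Swift layout.
func convertConvWeights(_ weights: [String: MLXArray]) -> [String: MLXArray] {
    var converted: [String: MLXArray] = [:]
    for (key, value) in weights {
        if key.hasSuffix(".weight") && value.ndim == 3 {
            // Conv1d: (Out, In, K) → (Out, K, In)
            converted[key] = value.transposed(0, 2, 1)
        } else if key.hasSuffix(".weight") && value.ndim == 4 {
            // Conv2d: (Out, In, H, W) → (Out, H, W, In)
            converted[key] = value.transposed(0, 2, 3, 1)
        } else {
            // Linear, Embedding, biases, norm parameters: same layout in both frameworks
            converted[key] = value
        }
    }
    return converted
}
```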
Weight normalization:

```python
# Built-in support
torch.nn.utils.weight_norm(layer)
# Stores as weight_g and weight_v parameters
```

```swift
// No built-in weight_norm - must fuse during conversion
// weight = weight_g * (weight_v / ||weight_v||)
if let weightG = params["weight_g"], let weightV = params["weight_v"] {
    let vSqr = weightV * weightV
    let vNorm = sqrt(vSqr.sum(axes: [1, 2], keepDims: true) + 1e-12)
    let fusedWeight = weightG * (weightV / vNorm)
    params["weight"] = fusedWeight
}
```

Module definition pattern:

```python
class MyModule(nn.Module):
    def __init__(self, in_ch, out_ch, kernel):
        super().__init__()
        self.conv = nn.Conv1d(in_ch, out_ch, kernel)

    def forward(self, x):
        return self.conv(x)
```

```swift
class MyModule: Module {
    let conv: MLXNN.Conv1d

    init(inCh: Int, outCh: Int, kernel: Int) {
        self.conv = MLXNN.Conv1d(inputChannels: inCh, outputChannels: outCh, kernelSize: kernel)
        super.init() // MUST call after property initialization
    }

    func callAsFunction(_ x: MLXArray) -> MLXArray {
        return conv(x)
    }
}
```

Key Differences:
- Swift uses `let` properties (immutable after init)
- `super.init()` is called AFTER property assignment in Swift
- Swift uses `callAsFunction` instead of `forward`
Parameter access and loading:

```python
# Save
torch.save(model.state_dict(), "model.pth")

# Load
state_dict = torch.load("model.pth")
model.load_state_dict(state_dict, strict=False)

# Access parameters
for name, param in model.named_parameters():
    print(name, param.shape)
```

```swift
// Load from safetensors
let weights = try MLX.loadArrays(url: url)

// Update model with weights
model.update(parameters: ModuleParameters.unflattened(weights))

// Set eval mode
model.train(false)

// Access parameters
let params = model.parameters()
```
Indexing and slicing:

```python
# PyTorch
torch.flip(x, dims=[-1])  # Reverse last dimension (tensors do not support ::-1 slicing)
x[:, :, a:b]              # Range slice
```

```swift
// MLX Swift
x[0..., 0..., .stride(by: -1)]  // Reverse last dimension
x[0..., 0..., a..<b]            // Range slice
```
Padding:

```python
# PyTorch
F.pad(x, (left, right))  # 1D padding
F.pad(x, (l, r, t, b))   # 2D padding
```

```swift
// MLX Swift - must specify all dimensions
MLX.padded(x, widths: [IntOrPair((0, 0)), IntOrPair((left, right)), IntOrPair((0, 0))])
```
Concatenation:

```python
# PyTorch
torch.cat([x0, x1], dim=2)
```

```swift
// MLX Swift
MLX.concatenated([x0, x1], axis: 2)
```

Activation functions:

```python
F.leaky_relu(x, negative_slope=0.1)
F.gelu(x)
torch.sigmoid(x)
torch.tanh(x)
```

```swift
leakyRelu(x, negativeSlope: 0.1)
gelu(x, approximate: .none)  // Note: .none for exact match
sigmoid(x)
tanh(x)
```
BatchNorm parameter names:

```python
# Running statistics
bn.running_mean
bn.running_var
bn.num_batches_tracked
```

```swift
// Camel case naming
bn.runningMean
bn.runningVar
// num_batches_tracked is often skipped for inference
```

Key Mapping in Conversion:

```swift
key = key.replacingOccurrences(of: ".running_mean", with: ".runningMean")
key = key.replacingOccurrences(of: ".running_var", with: ".runningVar")
```
LayerNorm parameter names:

```python
# Some models use gamma/beta
ln.gamma  # Scale
ln.beta   # Bias
```

```swift
ln.weight  // Scale
ln.bias    // Bias
```

Key Mapping:

```swift
key = key.replacingOccurrences(of: ".gamma", with: ".weight")
key = key.replacingOccurrences(of: ".beta", with: ".bias")
```
ModuleList vs named properties:

```python
# ModuleList automatically registers submodules
self.flows = nn.ModuleList([
    ResidualCouplingLayer(...) for _ in range(n_flows)
])
# Access: self.flows[0], self.flows[1], ...
# Weight keys: flow.flows.0.*, flow.flows.1.*, ...
```

```swift
// Must use named properties for weight loading
let flow_0: ResidualCouplingLayer
let flow_1: ResidualCouplingLayer
let flow_2: ResidualCouplingLayer
let flow_3: ResidualCouplingLayer

// Arrays don't work for weight loading!
// var flows: [ResidualCouplingLayer] = [] // WRONG
```
Device management:

```python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
x = x.to(device)
```

```swift
// Automatic - no explicit device management needed
// MLX uses unified memory on Apple Silicon
// Optionally set default device:
MLX.Device.setDefault(device: Device.gpu)
```
Inference without gradients:

```python
with torch.no_grad():
    output = model(x)
```

```swift
// No gradients by default in inference mode
// Just call the model directly
let output = model(x)

// Ensure computation is complete
MLX.eval(output)
```

Random number generation:

```python
torch.randn(shape)
torch.normal(mean, std, size)
```

```swift
MLXRandom.normal(shape)
MLXRandom.normal(shape, loc: 0.0, scale: 1.0)  // loc/scale play the role of mean/std
```
Summary of key differences:

| Feature | PyTorch | MLX Swift |
|---|---|---|
| Conv1d data | (B, C, T) | (B, T, C) |
| Conv1d weight | (Out, In, K) | (Out, K, In) |
| Super init | Before properties | After properties |
| Forward method | `forward(self, x)` | `callAsFunction(_ x:)` |
| Param access | `state_dict()` | `parameters()` |
| Weight loading | `load_state_dict()` | `update(parameters:)` |
| Softmax axis | `dim=-1` | `axis: -1` |
| Device | Explicit `.to(device)` | Automatic |
| No gradients | `torch.no_grad()` | Default behavior |
| ModuleList | Supported | Use named properties |
| Weight norm | Built-in | Fuse manually |
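The softmax row is the only entry above not illustrated by an earlier snippet. As a small illustrative sketch (the array `x` is just a placeholder), the axis-based call in MLX Swift looks like this:

```swift
import MLX
import MLXRandom

// PyTorch: F.softmax(x, dim=-1); MLX Swift takes the axis as `axis:`
let x = MLXRandom.normal([2, 4])
let probs = softmax(x, axis: -1)  // normalize over the last axis
```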
Common pitfalls:

- Forgetting weight transposition - Conv weights have different layouts
- Using arrays instead of named properties - Weights won't load (see the key-remapping sketch after this list)
- Wrong super.init() order - Swift requires properties initialized first
- Missing gamma/beta → weight/bias mapping - LayerNorm will use defaults
- Not fusing weight normalization - Model will have wrong weights
- Assuming same padding API - MLX requires explicit pad_width for all dims
- Mixing tensor formats - Keep track of (B,C,T) vs (B,T,C) throughout
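For the second pitfall, the indexed PyTorch keys (`flow.flows.0.*`, `flow.flows.1.*`, ...) still have to be mapped onto the named Swift properties. A minimal sketch, assuming four flows and the `flow_0`...`flow_3` property names used earlier; `remapFlowKeys` is a hypothetical helper, not part of MLX:

```swift
import Foundation
import MLX

// Hypothetical helper: rename indexed ModuleList keys to named-property keys.
func remapFlowKeys(_ weights: [String: MLXArray]) -> [String: MLXArray] {
    var remapped: [String: MLXArray] = [:]
    for (key, value) in weights {
        var newKey = key
        for i in 0..<4 {
            // e.g. "flow.flows.0.weight" → "flow.flow_0.weight"
            newKey = newKey.replacingOccurrences(of: "flows.\(i).", with: "flow_\(i).")
        }
        remapped[newKey] = value
    }
    return remapped
}
```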
Conversion checklist:

- Transpose all Conv1d weights: `(Out, In, K)` → `(Out, K, In)`
- Transpose all ConvTranspose1d weights: `(In, Out, K)` → `(Out, K, In)`
- Fuse weight_g/weight_v into single weight tensors
- Map running_mean → runningMean, running_var → runningVar
- Map gamma → weight, beta → bias for LayerNorm
- Convert ModuleList to named properties
- Add tensor format transposes at module boundaries
- Set model to eval mode: `model.train(false)`
- Call `MLX.eval()` after operations to ensure completion
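Taken together, a minimal load-routine sketch that applies this checklist might look like the following. `remapKeys` and `convertWeights` are hypothetical helpers standing in for the key renames, layout transposes, and weight-norm fusion described above; the exact steps depend on the model being ported.

```swift
import Foundation
import MLX
import MLXNN

// Load a converted safetensors checkpoint into an MLXNN module for inference.
func load(model: Module, from url: URL,
          remapKeys: ([String: MLXArray]) -> [String: MLXArray],
          convertWeights: ([String: MLXArray]) -> [String: MLXArray]) throws {
    var weights = try MLX.loadArrays(url: url)   // read the safetensors file
    weights = remapKeys(weights)                 // running_mean → runningMean, gamma → weight, flows.0 → flow_0, ...
    weights = convertWeights(weights)            // (Out, In, K) → (Out, K, In), weight-norm fusion, ...
    model.update(parameters: ModuleParameters.unflattened(weights))
    model.train(false)                           // inference mode
}
```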