Model inspection and summary tools for MLX neural networks on Apple Silicon.
Inspired by, and similar to, torchsummary for PyTorch, but designed specifically for MLX's module system.
- Multiple output formats: Table, Tree, JSON, Markdown, Minimal
- Detailed inspection: Layer paths, parameter counts, shapes
- Filtering: Find layers by type, name pattern, or parameter count
- Statistics: Aggregate stats by layer type
- CLI support: Use from the command line or from Python
- Freeze-aware: Track trainable vs. frozen parameters

```bash
pip install mlxsummary
```

Or install from source:

```bash
git clone https://github.com/dhruvshr/mlxsummary.git
cd mlxsummary
pip install -e .
```

```python
import mlx.nn as nn
from mlxsummary import summary

# Create a model
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
)

# Print summary
summary(model)
```

Output:
```
=============================================================================================================================
Model Summary: Sequential
=============================================================================================================================
Layer       Type       Params     Details
-----------------------------------------------------------------------------------------------------------------------------
layers.0    Linear     200,960    (784 → 256)
layers.1    ReLU       0
layers.2    Dropout    0
layers.3    Linear     32,896     (256 → 128)
layers.4    ReLU       0
layers.5    Linear     1,290      (128 → 10)
-----------------------------------------------------------------------------------------------------------------------------
Total Parameters: 235,146
=============================================================================================================================
```
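
The parameter counts in the table above follow directly from layer shapes: an MLX `nn.Linear(in, out)` layer stores an `out × in` weight matrix and, by default, an `out`-element bias, so it holds `in * out + out` parameters, while `ReLU` and `Dropout` hold none. A quick pure-Python cross-check of the quick-start totals (the `linear_params` helper is hypothetical and not part of mlxsummary):

```python
def linear_params(in_features: int, out_features: int, bias: bool = True) -> int:
    """Parameter count of a dense layer: weight matrix plus optional bias.

    Hypothetical helper for illustration only; not part of mlxsummary.
    """
    weights = in_features * out_features
    return weights + (out_features if bias else 0)


# (in, out) pairs for the three Linear layers in the quick-start model.
# ReLU and Dropout have no parameters, so they add nothing to the total.
linear_shapes = [(784, 256), (256, 128), (128, 10)]
per_layer = [linear_params(i, o) for i, o in linear_shapes]
total = sum(per_layer)

print(per_layer)     # [200960, 32896, 1290]
print(f"{total:,}")  # 235,146
```
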
Table format (the default):

```python
summary(model, format="table")
```

Tree format:

```python
summary(model, format="tree")
```

```
📦 Sequential (235,146 params)
│ ├── 0: Linear (784 → 256) [200,960]
│ ├── 1: ReLU [0]
│ ├── 2: Dropout [0]
│ ├── 3: Linear (256 → 128) [32,896]
│ ├── 4: ReLU [0]
│ └── 5: Linear (128 → 10) [1,290]
```
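
As a rough illustration of how a tree layout like this can be built (not mlxsummary's actual renderer), the hypothetical `render_tree` below draws one level of children with `├──` / `└──` connectors from plain `(name, type, detail, params)` tuples. It omits the root icon and the extra `│` indentation shown above.

```python
from typing import List, Tuple

# Hypothetical input format: (child name, layer type, shape detail, parameter count)
Layer = Tuple[str, str, str, int]


def render_tree(root: str, layers: List[Layer]) -> str:
    """Render one level of a module tree using box-drawing connectors."""
    total = sum(layer[3] for layer in layers)
    lines = [f"{root} ({total:,} params)"]
    for i, (name, kind, detail, params) in enumerate(layers):
        # The last child gets a corner connector; all others get a tee.
        connector = "└──" if i == len(layers) - 1 else "├──"
        label = f"{name}: {kind} {detail}" if detail else f"{name}: {kind}"
        lines.append(f"{connector} {label} [{params:,}]")
    return "\n".join(lines)


layers = [
    ("0", "Linear", "(784 → 256)", 200_960),
    ("1", "ReLU", "", 0),
    ("2", "Dropout", "", 0),
    ("3", "Linear", "(256 → 128)", 32_896),
    ("4", "ReLU", "", 0),
    ("5", "Linear", "(128 → 10)", 1_290),
]
print(render_tree("Sequential", layers))
```
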
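JSON format (with `print_output=False`, the summary is returned rather than printed):

```python
data = summary(model, format="json", print_output=False)
```

Markdown format:

```python
summary(model, format="markdown")
```

Minimal format:

```python
summary(model, format="minimal")
# Output: Sequential: 235,146 params (6 layers)
```

For detailed programmatic access: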
```python
import mlx.nn as nn
from mlxsummary import inspect

inspector = inspect(model)

# Get all layers
layers = inspector.get_layers()
for layer in layers:
    print(f"{layer.path}: {layer.total_params:,} params")

# Get statistics
stats = inspector.get_stats()
print(f"Total: {stats.total_params:,}")
print(f"Layer types: {stats.layer_type_counts}")

# Find specific layers
linear_layers = inspector.find_layers(layer_type=nn.Linear)
attention = inspector.find_layers(name_pattern="attention")
large_layers = inspector.find_layers(min_params=10000)
```

Module-level helper functions cover common queries without creating an inspector:

```python
import mlx.nn as nn
from mlxsummary import count_params, get_layers, get_stats, to_dict

# Count parameters
total = count_params(model)
trainable = count_params(model, trainable_only=True)

# Get layers
layers = get_layers(model)
linear_layers = get_layers(model, layer_type=nn.Linear)

# Get stats
stats = get_stats(model)

# Export to dict
data = to_dict(model)
```

Available `summary()` options:

```python
summary(
    model,
    format="table",           # Output format
    show_shapes=True,         # Show layer dimensions
    show_trainable=True,      # Show trainable params column
    show_frozen=False,        # Show frozen params column
    max_depth=None,           # Limit layer depth
    max_rows=None,            # Limit output rows
    include_zero_param=True,  # Include zero-param layers
    width=100,                # Output width
    print_output=True,        # Print the summary (False: return only)
)
```

Command-line usage:

```bash
# Summarize a model from a file
mlxsummary model.py

# Different formats
mlxsummary model.py --format tree
mlxsummary model.py --format json -o model.json
mlxsummary model.py --format markdown

# Options
mlxsummary model.py --max-depth 2
mlxsummary model.py --hide-zero
mlxsummary model.py --no-shapes

# Demo mode
mlxsummary --demo
```

Your model file should define a `model` variable or a `get_model()` function:
```python
# model.py
import mlx.nn as nn

model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
)
```

| Class | Description |
|---|---|
| `MLXInspector` | Main inspector for detailed model analysis |
| `LayerInfo` | Information about a single layer |
| `ModelStats` | Aggregate statistics about a model |
| `FormatterOptions` | Options for output formatting |
| `OutputFormat` | Enum of available output formats |
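
To make the roles of `LayerInfo` and `ModelStats` concrete, here is a hypothetical plain-Python model of them. It uses only the attributes the inspector examples access (`path`, `total_params`, `layer_type_counts`) plus an assumed `layer_type` string field, and a `filter_layers` function that mirrors the three `find_layers` filters (type, name pattern, minimum parameter count). These are illustrative stand-ins, not mlxsummary's classes; the real matching rules (for example, whether `name_pattern` is a substring or a regex, or whether `min_params` is inclusive) may differ.

```python
import re
from collections import Counter
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class LayerInfo:  # hypothetical stand-in, not mlxsummary's class
    path: str          # e.g. "layers.0"
    layer_type: str    # assumed field: class name such as "Linear"
    total_params: int


@dataclass
class ModelStats:  # hypothetical stand-in, not mlxsummary's class
    total_params: int
    layer_type_counts: Dict[str, int] = field(default_factory=dict)


def aggregate_stats(layers: List[LayerInfo]) -> ModelStats:
    """Sum parameters and count layers per type (in first-seen order)."""
    counts = Counter(layer.layer_type for layer in layers)
    return ModelStats(
        total_params=sum(layer.total_params for layer in layers),
        layer_type_counts=dict(counts),
    )


def filter_layers(
    layers: List[LayerInfo],
    layer_type: Optional[str] = None,
    name_pattern: Optional[str] = None,
    min_params: Optional[int] = None,
) -> List[LayerInfo]:
    """Keep layers that pass every filter that is set (filters combine with AND)."""
    result = []
    for layer in layers:
        if layer_type is not None and layer.layer_type != layer_type:
            continue
        # Assumption: name_pattern is a regex searched anywhere in the path.
        if name_pattern is not None and not re.search(name_pattern, layer.path):
            continue
        # Assumption: min_params is an inclusive lower bound.
        if min_params is not None and layer.total_params < min_params:
            continue
        result.append(layer)
    return result


layers = [
    LayerInfo("layers.0", "Linear", 200_960),
    LayerInfo("layers.1", "ReLU", 0),
    LayerInfo("layers.2", "Dropout", 0),
    LayerInfo("layers.3", "Linear", 32_896),
    LayerInfo("layers.4", "ReLU", 0),
    LayerInfo("layers.5", "Linear", 1_290),
]
stats = aggregate_stats(layers)
print(stats.total_params)       # 235146
print(stats.layer_type_counts)  # {'Linear': 3, 'ReLU': 2, 'Dropout': 1}
print([layer.path for layer in filter_layers(layers, min_params=10000)])
# ['layers.0', 'layers.3']
```
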

| Function | Description |
|---|---|
| `summary(model, ...)` | Generate and print a model summary |
| `inspect(model)` | Create an inspector instance |
| `count_params(model)` | Count model parameters |
| `get_layers(model)` | Get a list of layer information |
| `get_stats(model)` | Get aggregate statistics |
| `to_dict(model)` | Export model info to a dictionary |
| `tree(model)` | Shortcut for tree format |
| `table(model)` | Shortcut for table format |
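
Freeze-aware counting, as in `count_params(model, trainable_only=True)` and the trainable/frozen columns, amounts to splitting the parameter total by whether each tensor is frozen. The sketch below models that in plain Python, assuming a flat dict of dotted parameter paths to shapes and a hypothetical list of frozen path prefixes. In MLX itself you would freeze parameters with `Module.freeze()` and read the rest with `Module.trainable_parameters()`; mlxsummary's own counting logic is not shown here.

```python
from math import prod
from typing import Dict, Iterable, Tuple

# Hypothetical flat view of the quick-start model's parameters: dotted path -> shape.
# MLX's Linear stores its weight as (out, in) and its bias as (out,).
PARAM_SHAPES: Dict[str, Tuple[int, ...]] = {
    "layers.0.weight": (256, 784),
    "layers.0.bias": (256,),
    "layers.3.weight": (128, 256),
    "layers.3.bias": (128,),
    "layers.5.weight": (10, 128),
    "layers.5.bias": (10,),
}


def count_split(
    shapes: Dict[str, Tuple[int, ...]], frozen_prefixes: Iterable[str]
) -> Tuple[int, int]:
    """Return (trainable, frozen) element counts.

    A parameter counts as frozen if its path equals a frozen prefix or
    lies underneath it (e.g. "layers.0" covers "layers.0.weight").
    """
    prefixes = tuple(frozen_prefixes)
    trainable = frozen = 0
    for path, shape in shapes.items():
        n = prod(shape)  # number of elements in this tensor
        if any(path == p or path.startswith(p + ".") for p in prefixes):
            frozen += n
        else:
            trainable += n
    return trainable, frozen


# Freeze the first Linear layer, e.g. when fine-tuning only the later layers.
trainable, frozen = count_split(PARAM_SHAPES, frozen_prefixes=["layers.0"])
print(f"trainable={trainable:,} frozen={frozen:,} total={trainable + frozen:,}")
# trainable=34,186 frozen=200,960 total=235,146
```
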

Requirements:

- macOS with Apple Silicon (M1/M2/M3/M4)
- Python 3.9+
- MLX 0.1.0+

MIT License - see the LICENSE file for details.

Contributions are welcome! Please open an issue or pull request.

See also:

- MLX Documentation
- MLX GitHub
- torchsummary: a similar tool for PyTorch