Current behavior
In dependency/graph.py#L276, pruning groups are always computed based on output channels when calling DependencyGraph.compute_all_groups, which seems like an arbitrary choice.
Issue
For many newer architectures (e.g., LLMs, MLP blocks), pruning is often expressed more naturally in terms of input channels, for example pruning the in_features of a down_proj matrix.
Currently, compute_all_groups does not provide a way to group dependencies based on input channels.
Why this hasn’t been a blocker
If `input_channels` need to be pruned, one can still compute pruning groups for the layers of interest "manually" using DependencyGraph.get_pruning_group.
- However, this requires providing the right pruning function.
In many cases, pruning out_channels implicitly prunes corresponding in_channels due to coupling across dependencies.
- Still, an option to build groups directly on in_channels would serve downstream uses of the group better, for example locating the weights of the module whose input channels are being pruned in order to compute new saliencies.
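To make the coupling concrete, here is a toy illustration in plain Python. It is not Torch-Pruning code: the weights are nested lists shaped [out][in], and the helper names (prune_out_channels, prune_in_channels, in_channel_saliency) are made up for this sketch. It shows that the same indices are removed from up_proj's outputs and down_proj's inputs, and that scoring from the in_channels side reads the down_proj weights directly.

```python
# Toy two-layer MLP: weights are nested lists shaped [out][in].
# Plain Python only; all names here are hypothetical, not library API.
hidden, inter = 2, 4
up_proj = [[float(10 * o + i) for i in range(hidden)] for o in range(inter)]
down_proj = [[float(100 * o + i) for i in range(inter)] for o in range(hidden)]


def prune_out_channels(weight, idxs):
    """Drop the output rows listed in idxs."""
    drop = set(idxs)
    return [row for o, row in enumerate(weight) if o not in drop]


def prune_in_channels(weight, idxs):
    """Drop the input columns listed in idxs."""
    drop = set(idxs)
    return [[w for i, w in enumerate(row) if i not in drop] for row in weight]


def in_channel_saliency(weight):
    """L1 norm of each input column, i.e. a score per input channel."""
    return [sum(abs(row[i]) for row in weight) for i in range(len(weight[0]))]


# Score the intermediate channels from down_proj's side (its inputs).
scores = in_channel_saliency(down_proj)
idxs = sorted(range(inter), key=scores.__getitem__)[:2]  # two least salient

# The coupled group: the same indices are removed on both sides of the link.
up_proj = prune_out_channels(up_proj, idxs)
down_proj = prune_in_channels(down_proj, idxs)
```

With groups rooted on out_channels, the group's first entry would be up_proj, so getting to the down_proj weights for the saliency step takes an extra lookup through the group's dependencies.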
Proposal
I’d like to add an option to compute_all_groups (e.g., a boolean flag) that lets users choose whether groups are computed with respect to out_channels (default, current behavior) or in_channels.
This would make the API more flexible for different pruning strategies.
Computing all groups only from output channels appears to be an arbitrary decision; the API should let users choose which type of group to compute.
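As a rough sketch of what I have in mind, the snippet below models the option on a toy chain of layers. Everything in it is hypothetical: Layer, toy_compute_all_groups, and the root parameter are illustration names, not existing Torch-Pruning API, and a string argument is used here only as one possible alternative to the boolean flag.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Layer:
    """Hypothetical stand-in for a module with input and output channels."""
    name: str
    in_channels: int
    out_channels: int


def toy_compute_all_groups(layers, root="out_channels"):
    """Yield (root_layer, root_dim, coupled) for each link in a layer chain.

    `root` selects which side of the link the group is anchored on; the
    coupled entry is the neighbor whose opposite dimension shrinks with it.
    """
    if root not in ("out_channels", "in_channels"):
        raise ValueError(f"unknown root: {root!r}")
    for producer, consumer in zip(layers, layers[1:]):
        if root == "out_channels":
            # Default: anchor on the producer's outputs.
            yield producer.name, "out_channels", (consumer.name, "in_channels")
        else:
            # Proposed option: anchor on the consumer's inputs.
            yield consumer.name, "in_channels", (producer.name, "out_channels")


mlp = [Layer("up_proj", 16, 64), Layer("down_proj", 64, 16)]
out_groups = list(toy_compute_all_groups(mlp))
in_groups = list(toy_compute_all_groups(mlp, root="in_channels"))
```

Both calls describe the same coupled channels; only the layer that leads the group changes, which is the part downstream code such as saliency computation cares about.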
I’m happy to open a PR for this, but wanted to get the maintainers’ feedback first.