diff --git a/README.md b/README.md
index b4d9121..f9d9e1d 100644
--- a/README.md
+++ b/README.md
@@ -6,15 +6,16 @@

-# SoftTorch +# Soft differentiable programming in PyTorch [![PyPI version](https://img.shields.io/pypi/v/softtorch)](https://pypi.org/project/softtorch/) [![Python version](https://img.shields.io/pypi/pyversions/softtorch)](https://pypi.org/project/softtorch/) [![License](https://img.shields.io/pypi/l/softtorch)](https://github.com/a-paulus/softtorch/blob/main/LICENSE) +[![arXiv paper](https://img.shields.io/badge/arXiv-paper-salmon)](https://arxiv.org/abs/2603.08824) Looking for JAX? See [SoftJAX](https://github.com/a-paulus/softjax). -## In a nutshell +## What is SoftTorch? SoftTorch provides soft differentiable drop-in replacements for traditionally non-differentiable functions in [PyTorch](https://pytorch.org), including @@ -28,8 +29,7 @@ All operators offer multiple modes (controlling smoothness or boundedness of the All operators also support straight-through estimation, using the non-differentiable function in the forward pass and the soft relaxation in the backward pass. -SoftTorch functions are drop-in replacements for their non-differentiable PyTorch counterparts. -Special care is needed for functions operating on indices, as we relax discrete indices into distributions over indices, which modifies the shape of returned/accepted values. +*Note, while SoftTorch is designed to provide direct drop-in replacements for PyTorch's operators, soft axis-wise operators return a probability distribution over indices (instead of an index), effectively changing the shape of the function's output.* ## Installation @@ -44,351 +44,152 @@ pip install softtorch Available at https://a-paulus.github.io/softtorch/. 
-## Quick example -```python -import torch -import softtorch as st - -x = torch.tensor([-0.2, -1.0, 0.3, 1.0]) - -# Elementwise functions -print("\nTorch absolute:", torch.abs(x)) -print("SoftTorch absolute (hard mode):", st.abs(x, mode="hard")) -print("SoftTorch absolute (soft mode):", st.abs(x)) - -print("\nTorch clamp:", torch.clamp(x, -0.5, 0.5)) -print("SoftTorch clamp (hard mode):", st.clamp(x, -0.5, 0.5, mode="hard")) -print("SoftTorch clamp (soft mode):", st.clamp(x, -0.5, 0.5)) - -print("\nTorch heaviside:", torch.heaviside(x, torch.tensor(0.5))) -print("SoftTorch heaviside (hard mode):", st.heaviside(x, mode="hard")) -print("SoftTorch heaviside (soft mode):", st.heaviside(x)) - -print("\nTorch ReLU:", torch.nn.functional.relu(x)) -print("SoftTorch ReLU (hard mode):", st.relu(x, mode="hard")) -print("SoftTorch ReLU (soft mode):", st.relu(x)) - -print("\nTorch round:", torch.round(x)) -print("SoftTorch round (hard mode):", st.round(x, mode="hard")) -print("SoftTorch round (soft mode):", st.round(x)) - -print("\nTorch sign:", torch.sign(x)) -print("SoftTorch sign (hard mode):", st.sign(x, mode="hard")) -print("SoftTorch sign (soft mode):", st.sign(x)) -``` -``` -Torch absolute: tensor([0.2000, 1.0000, 0.3000, 1.0000]) -SoftTorch absolute (hard mode): tensor([0.2000, 1.0000, 0.3000, 1.0000]) -SoftTorch absolute (soft mode): tensor([0.1523, 0.9999, 0.2715, 0.9999]) - -Torch clamp: tensor([-0.2000, -0.5000, 0.3000, 0.5000]) -SoftTorch clamp (hard mode): tensor([-0.2000, -0.5000, 0.3000, 0.5000]) -SoftTorch clamp (soft mode): tensor([-0.1952, -0.4993, 0.2873, 0.4993]) - -Torch heaviside: tensor([0., 0., 1., 1.]) -SoftTorch heaviside (hard mode): tensor([0., 0., 1., 1.]) -SoftTorch heaviside (soft mode): tensor([0.1192, 0.0000, 0.9526, 1.0000]) - -Torch ReLU: tensor([0.0000, 0.0000, 0.3000, 1.0000]) -SoftTorch ReLU (hard mode): tensor([0.0000, 0.0000, 0.3000, 1.0000]) -SoftTorch ReLU (soft mode): tensor([0.0127, 0.0000, 0.3049, 1.0000]) - -Torch round: 
tensor([-0., -1., 0., 1.]) -SoftTorch round (hard mode): tensor([-0., -1., 0., 1.]) -SoftTorch round (soft mode): tensor([-0.0465, -1.0000, 0.1189, 1.0000]) - -Torch sign: tensor([-1., -1., 1., 1.]) -SoftTorch sign (hard mode): tensor([-1., -1., 1., 1.]) -SoftTorch sign (soft mode): tensor([-0.7616, -0.9999, 0.9051, 0.9999]) -``` +## Quick examples +**Robust median regression:** +Minimize the median absolute residual to be robust to outliers. ```python -# Tensor-valued operators -print("\nTorch max:", torch.max(x)) -print("SoftTorch max (hard mode):", st.max(x, mode="hard")) -print("SoftTorch max (soft mode):", st.max(x)) - -print("\nTorch min:", torch.min(x)) -print("SoftTorch min (hard mode):", st.min(x, mode="hard")) -print("SoftTorch min (soft mode):", st.min(x)) - -print("\nTorch sort:", torch.sort(x).values) -print("SoftTorch sort (hard mode):", st.sort(x, mode="hard").values) -print("SoftTorch sort (soft mode):", st.sort(x).values) - -print("\nTorch quantile:", torch.quantile(x, q=0.2)) -print("SoftTorch quantile (hard mode):", st.quantile(x, q=0.2, mode="hard")) -print("SoftTorch quantile (soft mode):", st.quantile(x, q=0.2)) - -print("\nTorch median:", torch.median(x)) -print("SoftTorch median (hard mode):", st.median(x, mode="hard")) -print("SoftTorch median (soft mode):", st.median(x)) - -print("\nTorch topk:", torch.topk(x, k=3).values) -print("SoftTorch topk (hard mode):", st.topk(x, k=3, mode="hard").values) -print("SoftTorch topk (soft mode):", st.topk(x, k=3).values) - -print("\nTorch rank:", torch.argsort(torch.argsort(x))) -print("SoftTorch rank (hard mode):", st.rank(x, mode="hard", descending=False)) -print("SoftTorch rank (soft mode):", st.rank(x, descending=False)) -``` -``` -Torch max: tensor(1.) -SoftTorch max (hard mode): tensor(1.) -SoftTorch max (soft mode): tensor(0.8874) - -Torch min: tensor(-1.) -SoftTorch min (hard mode): tensor(-1.) 
-SoftTorch min (soft mode): tensor(-0.8996) - -Torch sort: tensor([-1.0000, -0.2000, 0.3000, 1.0000]) -SoftTorch sort (hard mode): tensor([-1.0000, -0.2000, 0.3000, 1.0000]) -SoftTorch sort (soft mode): tensor([-0.8792, -0.1641, 0.2767, 0.8738]) +import torch, softtorch as st -Torch quantile: tensor(-0.5200) -SoftTorch quantile (hard mode): tensor(-0.5200) -SoftTorch quantile (soft mode): tensor(-0.4501) +torch.manual_seed(0) +X = torch.randn(20, 3) +w_true = torch.tensor([1.0, -2.0, 0.5]) +y = X @ w_true +y[0] = 1e6 # inject outlier -Torch median: tensor(-0.2000) -SoftTorch median (hard mode): tensor(-0.2000) -SoftTorch median (soft mode): tensor(-0.1641) +def median_regression_loss(w, X, y, mode="smooth"): + residuals = y - X @ w + return st.median(st.abs(residuals, mode=mode), mode=mode) -Torch topk: tensor([ 1.0000, 0.3000, -0.2000]) -SoftTorch topk (hard mode): tensor([ 1.0000, 0.3000, -0.2000]) -SoftTorch topk (soft mode): tensor([ 0.8738, 0.2767, -0.1641]) +w = torch.zeros(3, requires_grad=True) +hard_loss = median_regression_loss(w, X, y, mode="hard") +print("Hard grad:", torch.autograd.grad(hard_loss, w)[0]) +soft_loss = median_regression_loss(w, X, y, mode="smooth") +print("Soft grad:", torch.autograd.grad(soft_loss, w)[0]) -Torch rank: tensor([1, 0, 2, 3]) -SoftTorch rank (hard mode): tensor([2., 1., 3., 4.]) -SoftTorch rank (soft mode): tensor([1.9950, 1.0548, 3.0239, 3.9228]) -``` - -```python -# Sort: sweep over methods -print("\nTorch sort:", torch.sort(x).values) -print("SoftTorch sort (softsort):", st.sort(x, method="softsort", softness=0.1).values) -print("SoftTorch sort (neuralsort):", st.sort(x, method="neuralsort", softness=0.1).values) -print("SoftTorch sort (fast_soft_sort):", st.sort(x, method="fast_soft_sort", softness=2.0).values) -print("SoftTorch sort (ot):", st.sort(x, method="ot", softness=0.1).values) -print("SoftTorch sort (sorting_network):", st.sort(x, method="sorting_network", softness=0.1).values) - -# Sort: sweep over modes 
-print("\nTorch sort:", torch.sort(x).values) -for mode in ["hard", "smooth", "c0", "c1", "c2"]: - print(f"SoftTorch sort ({mode}):", st.sort(x, softness=0.5, mode=mode).values) +w = torch.zeros(3) +for _ in range(50): + w.requires_grad_(True) + loss = median_regression_loss(w, X, y) + g = torch.autograd.grad(loss, w)[0] + w = (w - 0.1 * g).detach() +print("Learned w:", w, " (true:", w_true, ")") ``` ``` -Torch sort: tensor([-1.0000, -0.2000, 0.3000, 1.0000]) -SoftTorch sort (softsort): tensor([-0.8996, -0.1705, 0.2847, 0.8874]) -SoftTorch sort (neuralsort): tensor([-0.8792, -0.1641, 0.2767, 0.8738]) -SoftTorch sort (fast_soft_sort): tensor([-0.7462, -0.1971, 0.2938, 0.8569]) -SoftTorch sort (ot): tensor([-0.7324, -0.2396, 0.3286, 0.7434]) -SoftTorch sort (sorting_network): tensor([-0.7999, -0.2672, 0.3847, 0.7863]) - -Torch sort: tensor([-1.0000, -0.2000, 0.3000, 1.0000]) -SoftTorch sort (hard): tensor([-1.0000, -0.2000, 0.3000, 1.0000]) -SoftTorch sort (smooth): tensor([-0.6057, -0.1997, 0.2729, 0.6281]) -SoftTorch sort (c0): tensor([-1.0000, -0.6313, 0.6525, 0.9824]) -SoftTorch sort (c1): tensor([-0.9982, -0.5432, 0.5814, 0.9837]) -SoftTorch sort (c2): tensor([-0.9978, -0.4905, 0.5425, 0.9903]) +Hard grad: tensor([ 0.2103, 0.1772, -0.8305]) +Soft grad: tensor([ 0.0731, 0.7100, -0.2970]) +Learned w: tensor([ 1.0000, -2.0000, 0.5000]) (true: tensor([ 1.0000, -2.0000, 0.5000]) ) ``` +**Top-k feature selection:** +Discover which features of a trained model are important. 
```python -# Operators returning indices -print("\nTorch argmax:", torch.argmax(x)) -print("SoftTorch argmax (hard mode):", st.argmax(x, mode="hard")) -print("SoftTorch argmax (soft mode):", st.argmax(x)) - -print("\nTorch argmin:", torch.argmin(x)) -print("SoftTorch argmin (hard mode):", st.argmin(x, mode="hard")) -print("SoftTorch argmin (soft mode):", st.argmin(x)) - -print("\nTorch argquantile:", "Not implemented in standard PyTorch") -print("SoftTorch argquantile (hard mode):", st.argquantile(x, q=0.2, mode="hard")) -print("SoftTorch argquantile (soft mode):", st.argquantile(x, q=0.2)) - -print("\nTorch argmedian:", torch.median(x, dim=0).indices) -print("SoftTorch argmedian (hard mode):", st.median(x, mode="hard", dim=0).indices) -print("SoftTorch argmedian (soft mode):", st.median(x, dim=0).indices) - -print("\nTorch argsort:", torch.argsort(x)) -print("SoftTorch argsort (hard mode):", st.argsort(x, mode="hard")) -print("SoftTorch argsort (soft mode):", st.argsort(x)) - -print("\nTorch argtopk:", torch.topk(x, k=3).indices) -print("SoftTorch argtopk (hard mode):", st.topk(x, k=3, mode="hard").indices) -print("SoftTorch argtopk (soft mode):", st.topk(x, k=3).indices) -``` -``` -Torch argmax: tensor(3) -SoftTorch argmax (hard mode): tensor([0., 0., 0., 1.]) -SoftTorch argmax (soft mode): tensor([0.0215, 0.0022, 0.1176, 0.8586]) - -Torch argmin: tensor(1) -SoftTorch argmin (hard mode): tensor([0., 1., 0., 0.]) -SoftTorch argmin (soft mode): tensor([0.0922, 0.8885, 0.0169, 0.0023]) - -Torch argquantile: Not implemented in standard PyTorch -SoftTorch argquantile (hard mode): tensor([0.6000, 0.4000, 0.0000, 0.0000]) -SoftTorch argquantile (soft mode): tensor([0.5403, 0.3693, 0.0902, 0.0001]) - -Torch argmedian: tensor(0) -SoftTorch argmedian (hard mode): tensor([1., 0., 0., 0.]) -SoftTorch argmedian (soft mode): tensor([0.8009, 0.0491, 0.1498, 0.0002]) - -Torch argsort: tensor([1, 0, 2, 3]) -SoftTorch argsort (hard mode): tensor([[0., 1., 0., 0.], - [1., 0., 0., 
0.], - [0., 0., 1., 0.], - [0., 0., 0., 1.]]) -SoftTorch argsort (soft mode): tensor([[0.1494, 0.8496, 0.0009, 0.0000], - [0.8009, 0.0491, 0.1498, 0.0002], - [0.1418, 0.0001, 0.7899, 0.0681], - [0.0011, 0.0000, 0.1784, 0.8205]]) - -Torch argtopk: tensor([3, 2, 0]) -SoftTorch argtopk (hard mode): tensor([[0., 0., 0., 1.], - [0., 0., 1., 0.], - [1., 0., 0., 0.]]) -SoftTorch argtopk (soft mode): tensor([[0.0011, 0.0000, 0.1784, 0.8205], - [0.1418, 0.0001, 0.7899, 0.0681], - [0.8009, 0.0491, 0.1498, 0.0002]]) -``` - +n_features, k = 10, 3 +torch.manual_seed(42) +X = torch.randn(100, n_features) +w_model = torch.tensor([0, 2.0, 0, -1.5, 0, 0, 0, 5.0, 0, 0]) +y = X @ w_model + 0.1 * torch.randn(100) + +def feature_selection_loss(g, X, y, w_model, mode="smooth"): + _, soft_idx = st.topk(g, k=k, mode=mode, gated_grad=False) + mask = soft_idx.sum(dim=0) + y_pred = (X * mask) @ w_model + return torch.mean(st.abs(y_pred - y)) + +g = torch.zeros(n_features, requires_grad=True) +hard_loss = feature_selection_loss(g, X, y, w_model, mode="hard") +print("Hard grad:", torch.autograd.grad(hard_loss, g)[0] if hard_loss.requires_grad else torch.zeros_like(g)) +soft_loss = feature_selection_loss(g, X, y, w_model, mode="smooth") +print("Soft grad:", torch.autograd.grad(soft_loss, g)[0]) + +g = torch.zeros(n_features) +for _ in range(5): + g.requires_grad_(True) + loss = feature_selection_loss(g, X, y, w_model) + g_grad = torch.autograd.grad(loss, g)[0] + g = (g - 0.001 * g_grad).detach() +print("Selected features:", torch.topk(g, k=k).indices) +``` +``` +Hard grad: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]) +Soft grad: tensor([ 2359.3386, 62.9980, 2359.3386, -890.2852, 2359.3386, + 2359.3386, 2359.3386, -15688.0829, 2359.3386, 2359.3386]) +Selected features: tensor([7, 3, 1]) +``` + +**Differentiable threshold filtering:** +Learn a threshold that gates inputs. 
```python -y = torch.tensor([0.2, -0.5, 0.5, -1.0]) - -# Comparison operators -print("\nTorch greater:", torch.greater(x, y)) -print("SoftTorch greater (hard mode):", st.greater(x, y, mode="hard")) -print("SoftTorch greater (soft mode):", st.greater(x, y)) - -print("\nTorch greater equal:", torch.greater_equal(x, y)) -print("SoftTorch greater equal (hard mode):", st.greater_equal(x, y, mode="hard")) -print("SoftTorch greater equal (soft mode):", st.greater_equal(x, y)) - -print("\nTorch less:", torch.less(x, y)) -print("SoftTorch less (hard mode):", st.less(x, y, mode="hard")) -print("SoftTorch less (soft mode):", st.less(x, y)) +x = torch.tensor([0.2, 0.8, 0.5, 1.2, 0.1]) +target_sum = 2.0 # sum of values above threshold = 2.0 (i.e. 0.8 + 1.2) -print("\nTorch less equal:", torch.less_equal(x, y)) -print("SoftTorch less equal (hard mode):", st.less_equal(x, y, mode="hard")) -print("SoftTorch less equal (soft mode):", st.less_equal(x, y)) +def filter_loss(t, x, target, mode="smooth"): + mask = st.greater(x, t, mode=mode) + return (torch.sum(mask * x) - target) ** 2 -print("\nTorch eq:", torch.eq(x, y)) -print("SoftTorch eq (hard mode):", st.eq(x, y, mode="hard")) -print("SoftTorch eq (soft mode):", st.eq(x, y)) +t = torch.tensor(0.0, requires_grad=True) +hard_loss = filter_loss(t, x, target_sum, mode="hard") +print("Hard grad:", torch.autograd.grad(hard_loss, t)[0] if hard_loss.requires_grad else torch.zeros_like(t)) +soft_loss = filter_loss(t, x, target_sum, mode="smooth") +print("Soft grad:", torch.autograd.grad(soft_loss, t)[0]) -print("\nTorch not equal:", torch.not_equal(x, y)) -print("SoftTorch not equal (hard mode):", st.not_equal(x, y, mode="hard")) -print("SoftTorch not equal (soft mode):", st.not_equal(x, y)) - -print("\nTorch isclose:", torch.isclose(x, y)) -print("SoftTorch isclose (hard mode):", st.isclose(x, y, mode="hard")) -print("SoftTorch isclose (soft mode):", st.isclose(x, y)) +t = torch.tensor(0.0) +for _ in range(20): + t.requires_grad_(True) + 
loss = filter_loss(t, x, target_sum) + t_grad = torch.autograd.grad(loss, t)[0] + t = (t - 0.1 * t_grad).detach() +print("Learned threshold:", t) ``` ``` -Torch greater: tensor([False, False, False, True]) -SoftTorch greater (hard mode): tensor([0., 0., 0., 1.]) -SoftTorch greater (soft mode): tensor([0.0180, 0.0067, 0.1192, 1.0000]) - -Torch greater equal: tensor([False, False, False, True]) -SoftTorch greater equal (hard mode): tensor([0., 0., 0., 1.]) -SoftTorch greater equal (soft mode): tensor([0.0180, 0.0067, 0.1192, 1.0000]) - -Torch less: tensor([ True, True, True, False]) -SoftTorch less (hard mode): tensor([1., 1., 1., 0.]) -SoftTorch less (soft mode): tensor([0.9820, 0.9933, 0.8808, 0.0000]) - -Torch less equal: tensor([ True, True, True, False]) -SoftTorch less equal (hard mode): tensor([1., 1., 1., 0.]) -SoftTorch less equal (soft mode): tensor([0.9820, 0.9933, 0.8808, 0.0000]) - -Torch eq: tensor([False, False, False, False]) -SoftTorch eq (hard mode): tensor([0., 0., 0., 0.]) -SoftTorch eq (soft mode): tensor([0.0414, 0.0143, 0.3580, 0.0000]) - -Torch not equal: tensor([True, True, True, True]) -SoftTorch not equal (hard mode): tensor([1., 1., 1., 1.]) -SoftTorch not equal (soft mode): tensor([0.9586, 0.9857, 0.6420, 1.0000]) - -Torch isclose: tensor([False, False, False, False]) -SoftTorch isclose (hard mode): tensor([0., 0., 0., 0.]) -SoftTorch isclose (soft mode): tensor([0.0414, 0.0143, 0.3580, 0.0000]) +Hard grad: tensor(0.) +Soft grad: tensor(-0.6600) +Learned threshold: tensor(0.6211) ``` +**Rule-based classifier:** +Learn decision boundaries `[lo, hi]` for a rule using soft logic and straight-through estimation. The rule is true if any element of a feature is inside `[lo, hi]`. 
```python -# Logical operators -fuzzy_a = torch.tensor([0.1, 0.2, 0.8, 1.0]) -fuzzy_b = torch.tensor([0.7, 0.3, 0.1, 0.9]) -bool_a = fuzzy_a >= 0.5 -bool_b = fuzzy_b >= 0.5 - -print("\nTorch AND:", torch.logical_and(bool_a, bool_b)) -print("SoftTorch AND:", st.logical_and(fuzzy_a, fuzzy_b)) - -print("\nTorch OR:", torch.logical_or(bool_a, bool_b)) -print("SoftTorch OR:", st.logical_or(fuzzy_a, fuzzy_b)) +x = torch.tensor([[0.2, 0.8], [0.5, 0.3], [0.9, 0.1], [0.4, 0.7], [0.1, 0.4], [0.2, 0.7], [0.4, 0.1], [0.4, 0.7], + [0.7, 0.29], [0.3, 0.3], [0.61, 0.25], [0.4, 0.6], [0.0, 0.1], [0.5, 0.3], [0.4, 0.9], [0.1, 0.57]]) +labels = torch.tensor([0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, + 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0]) -print("\nTorch NOT:", torch.logical_not(bool_a)) -print("SoftTorch NOT:", st.logical_not(fuzzy_a)) +@st.st +def rule_loss(params, x, labels, mode="smooth"): + lo, hi = params[0], params[1] + above = st.greater(x, lo, mode=mode) + below = st.less(x, hi, mode=mode) + in_range = st.logical_and(above, below) + preds = st.any(in_range, dim=-1) + return ((preds - labels) ** 2).sum() -print("\nTorch XOR:", torch.logical_xor(bool_a, bool_b)) -print("SoftTorch XOR:", st.logical_xor(fuzzy_a, fuzzy_b)) +params = torch.tensor([0.0, 1.0], requires_grad=True) +hard_loss = rule_loss(params, x, labels, mode="hard") +print("Hard grad:", torch.autograd.grad(hard_loss, params)[0] if hard_loss.requires_grad else torch.zeros_like(params)) +soft_loss = rule_loss(params, x, labels, mode="smooth") +print("Soft grad:", torch.autograd.grad(soft_loss, params)[0]) -print("\nTorch ALL:", torch.all(bool_a)) -print("SoftTorch ALL:", st.all(fuzzy_a)) - -print("\nTorch ANY:", torch.any(bool_a)) -print("SoftTorch ANY:", st.any(fuzzy_a)) - -# Selection operators -print("\nTorch Where:", torch.where(bool_a, x, y)) -print("SoftTorch Where:", st.where(fuzzy_a, x, y)) +params = torch.tensor([0.0, 1.0]) +for _ in range(20): + params.requires_grad_(True) + loss = rule_loss(params, x, 
labels) + p_grad = torch.autograd.grad(loss, params)[0] + params = (params - 0.01 * p_grad).detach() +print("Learned [lo, hi]:", params) ``` ``` -Torch AND: tensor([False, False, False, True]) -SoftTorch AND: tensor([0.0700, 0.0600, 0.0800, 0.9000]) - -Torch OR: tensor([ True, False, True, True]) -SoftTorch OR: tensor([0.7300, 0.4400, 0.8200, 1.0000]) - -Torch NOT: tensor([ True, True, False, False]) -SoftTorch NOT: tensor([0.9000, 0.8000, 0.2000, 0.0000]) - -Torch XOR: tensor([ True, False, True, False]) -SoftTorch XOR: tensor([0.6411, 0.3464, 0.7256, 0.1000]) - -Torch ALL: tensor(False) -SoftTorch ALL: tensor(0.0160) - -Torch ANY: tensor(True) -SoftTorch ANY: tensor(1.) - -Torch Where: tensor([ 0.2000, -0.5000, 0.3000, 1.0000]) -SoftTorch Where: tensor([ 0.1600, -0.6000, 0.3400, 1.0000]) +Hard grad: tensor([0., 0.]) +Soft grad: tensor([-4.2777, 1.4152]) +Learned [lo, hi]: tensor([0.2925, 0.5999]) ``` -```python -# Straight-through operators: Use hard function on forward and soft on backward -print("Straight-through ReLU:", st.relu_st(x)) -print("Straight-through sort:", st.sort_st(x).values) -print("Straight-through argtopk:", st.topk_st(x, k=3).indices) -print("Straight-through greater:", st.greater_st(x, y)) -# And many more... -``` -``` -Straight-through ReLU: tensor([0.0000, 0.0000, 0.3000, 1.0000]) -Straight-through sort: tensor([-1.0000, -0.2000, 0.3000, 1.0000]) -Straight-through argtopk: tensor([[0., 0., 0., 1.], - [0., 0., 1., 0.], - [1., 0., 0., 0.]]) -Straight-through greater: tensor([0., 0., 0., 1.]) -``` +Optimization trajectories ## Citation -If this library helped your academic work, please consider citing: +If this library helped your academic work, please consider citing: ([arXiv link](https://arxiv.org/abs/2603.08824)) ```bibtex @article{paulus2026softjax, @@ -400,14 +201,14 @@ If this library helped your academic work, please consider citing: } ``` -Also consider starring the project [on GitHub](https://github.com/a-paulus/softtorch)! 
+(Also consider starring the project [on GitHub](https://github.com/a-paulus/softtorch)) Special thanks and credit go to [Patrick Kidger](https://kidger.site) for the awesome [JAX repositories](https://github.com/patrick-kidger) that served as the basis for the documentation of this project. ## Feedback -This project is still relatively young, if you have any suggestions for improvement or other feedback, please [reach out](mailto:paulus.anselm@gmail.com) or raise a GitHub issue! +If you have any suggestions for improvement or other feedback, please [reach out](mailto:paulus.anselm@gmail.com) or raise a GitHub issue! ## See also diff --git a/docs/.citation.md b/docs/.citation.md deleted file mode 100644 index b8ab08a..0000000 --- a/docs/.citation.md +++ /dev/null @@ -1,13 +0,0 @@ -If this library helped your academic work, please consider citing: - -```bibtex -@article{paulus2026softjax, - title={{SoftJAX} \& {SoftTorch}: Empowering Automatic Differentiation Libraries with Informative Gradients}, - author={Paulus, Anselm and Geist, A.\ Ren\'e and Musil, V\'it and Hoffmann, Sebastian and Beker, Onur and Martius, Georg}, - journal={arXiv preprint}, - year={2026}, - eprint={2603.08824} -} -``` - -Also consider starring the project [on GitHub](https://github.com/a-paulus/softtorch)! 
diff --git a/docs/quick_example.py b/docs/examples/long_example.py similarity index 100% rename from docs/quick_example.py rename to docs/examples/long_example.py diff --git a/docs/manifold_points.ipynb b/docs/examples/manifold_points.ipynb similarity index 100% rename from docs/manifold_points.ipynb rename to docs/examples/manifold_points.ipynb diff --git a/docs/paper_examples.py b/docs/examples/paper_examples.py similarity index 100% rename from docs/paper_examples.py rename to docs/examples/paper_examples.py diff --git a/docs/examples/quick_example.py b/docs/examples/quick_example.py new file mode 100644 index 0000000..9572c18 --- /dev/null +++ b/docs/examples/quick_example.py @@ -0,0 +1,200 @@ +import matplotlib.pyplot as plt +import torch +import softtorch as st + +torch.set_printoptions(precision=4, sci_mode=False) +torch.set_default_dtype(torch.float64) + + +# 1. Median regression +# Minimize the median absolute residual to be robust to outliers. + +torch.manual_seed(0) +X = torch.randn(20, 3) +w_true = torch.tensor([1.0, -2.0, 0.5]) +y = X @ w_true +y[0] = 1e6 # inject outlier + + +def median_regression_loss(w, X, y, mode="smooth"): + residuals = y - X @ w + return st.median(st.abs(residuals, mode=mode), mode=mode) + + +w = torch.zeros(3, requires_grad=True) +hard_loss = median_regression_loss(w, X, y, mode="hard") +print("=== 1. Robust median regression ===") +print("Hard grad:", torch.autograd.grad(hard_loss, w)[0]) +soft_loss = median_regression_loss(w, X, y, mode="smooth") +print("Soft grad:", torch.autograd.grad(soft_loss, w)[0]) + +ws = [] +w = torch.zeros(3) +for _ in range(50): + ws.append(w.tolist()) + w.requires_grad_(True) + loss = median_regression_loss(w, X, y) + g = torch.autograd.grad(loss, w)[0] + w = (w - 0.1 * g).detach() +print("Learned w:", w, " (true:", w_true, ")") + + +# 2. Top-k feature selection +# Discover which features of a trained model are important. +# 10 features total, only 3 informative — learn gating scores to find them. 
+ +n_features, k = 10, 3 +torch.manual_seed(42) +X = torch.randn(100, n_features) +w_model = torch.tensor([0, 2.0, 0, -1.5, 0, 0, 0, 5.0, 0, 0]) +y = X @ w_model + 0.1 * torch.randn(100) + + +def feature_selection_loss(g, X, y, w_model, mode="smooth"): + _, soft_idx = st.topk(g, k=k, mode=mode, gated_grad=False) + mask = soft_idx.sum(dim=0) + y_pred = (X * mask) @ w_model + return torch.mean(st.abs(y_pred - y)) + + +g = torch.zeros(n_features, requires_grad=True) +print("\n=== 2. Top-k feature selection ===") +hard_loss = feature_selection_loss(g, X, y, w_model, mode="hard") +print("Hard grad:", torch.autograd.grad(hard_loss, g)[0] if hard_loss.requires_grad else torch.zeros_like(g)) +soft_loss = feature_selection_loss(g, X, y, w_model, mode="smooth") +print("Soft grad:", torch.autograd.grad(soft_loss, g)[0]) + +gs = [] +g = torch.zeros(n_features) +for _ in range(5): + gs.append(g.tolist()) + g.requires_grad_(True) + loss = feature_selection_loss(g, X, y, w_model) + g_grad = torch.autograd.grad(loss, g)[0] + g = (g - 0.001 * g_grad).detach() +print("Selected features:", torch.topk(g, k=k).indices) + + +# 3. Differentiable filter +# Learn a threshold that gates inputs. + +x_filt = torch.tensor([0.2, 0.8, 0.5, 1.2, 0.1]) +target_sum = 2.0 # sum of values above threshold should equal 2.0 (= 0.8 + 1.2) + + +def filter_loss(t, x, target, mode="smooth"): + mask = st.greater(x, t, mode=mode) + return (torch.sum(mask * x) - target) ** 2 + + +t = torch.tensor(0.0, requires_grad=True) +print("\n=== 3. 
Differentiable threshold filtering ===") +hard_loss = filter_loss(t, x_filt, target_sum, mode="hard") +print("Hard grad:", torch.autograd.grad(hard_loss, t)[0] if hard_loss.requires_grad else torch.zeros_like(t)) +soft_loss = filter_loss(t, x_filt, target_sum, mode="smooth") +print("Soft grad:", torch.autograd.grad(soft_loss, t)[0]) + +ts = [] +t = torch.tensor(0.0) +for _ in range(20): + ts.append(float(t)) + t.requires_grad_(True) + loss = filter_loss(t, x_filt, target_sum) + t_grad = torch.autograd.grad(loss, t)[0] + t = (t - 0.1 * t_grad).detach() +print("Learned threshold:", t) + + +# 4. Differentiable rule-based classifier +# Learn decision boundaries: classify positive if ANY feature is in [lo, hi]. +# The rule is true if any element of a feature is inside `[lo, hi]`. +x_rules = torch.tensor([[0.2, 0.8], [0.5, 0.3], [0.9, 0.1], [0.4, 0.7], + [0.1, 0.4], [0.2, 0.7], [0.4, 0.1], [0.4, 0.7], + [0.7, 0.29], [0.3, 0.3], [0.61, 0.25], [0.4, 0.6], + [0.0, 0.1], [0.5, 0.3], [0.4, 0.9], [0.1, 0.57]]) +labels = torch.tensor([0.0, 1.0, 0.0, 1.0, + 1.0, 0.0, 1.0, 1.0, + 0.0, 1.0, 0.0, 1.0, + 0.0, 1.0, 1.0, 1.0]) + + +@st.st +def rule_loss(params, x, labels, mode="smooth"): + lo, hi = params[0], params[1] + above = st.greater(x, lo, mode=mode) + below = st.less(x, hi, mode=mode) + in_range = st.logical_and(above, below) + preds = st.any(in_range, dim=-1) + return ((preds - labels) ** 2).sum() + + +params = torch.tensor([0.0, 1.0], requires_grad=True) +print("\n=== 4. 
Differentiable rule-based classifier ===") +hard_loss = rule_loss(params, x_rules, labels, mode="hard") +print("Hard grad:", torch.autograd.grad(hard_loss, params)[0] if hard_loss.requires_grad else torch.zeros_like(params)) +soft_loss = rule_loss(params, x_rules, labels, mode="smooth") +print("Soft grad:", torch.autograd.grad(soft_loss, params)[0]) + +params_hist = [] +params = torch.tensor([0.0, 1.0]) +for _ in range(20): + params_hist.append(params.tolist()) + params.requires_grad_(True) + loss = rule_loss(params, x_rules, labels) + p_grad = torch.autograd.grad(loss, params)[0] + params = (params - 0.01 * p_grad).detach() +print("Learned [lo, hi]:", params) + + +# ── Plot ───────────────────────────────────────────────────────────────────── +palette = ["#00bfff", "#e7a1e5", "#6dd1ac", "#e1be6a", "#368f80", "#889fd9", "#f4836d", "#cecece"] +informative = {i for i, v in enumerate(w_model) if v != 0} + +fig, axes = plt.subplots(1, 4, figsize=(8, 2.5)) + +for ax in axes: + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + ax.tick_params(labelsize=7) + ax.set_xlabel("Iteration", fontsize=7) + ax.yaxis.set_major_locator(plt.MaxNLocator(3)) + ax.margins(x=0) + +ws = torch.tensor(ws) +for i in range(ws.shape[1]): + axes[0].plot(ws[:, i], color=palette[i], label=f"w[{i}]") + axes[0].axhline(w_true[i], color=palette[i], ls="--", alpha=0.3) +axes[0].set_title("Median regression", fontsize=8) +axes[0].legend(fontsize=6) + +gs = torch.tensor(gs) +for i in range(gs.shape[1]): + if i in informative: + if i == 1: + kw = {"lw": 1.5, "color": "#6dd1ac", "label": "Informative"} + else: + kw = {"lw": 1.5, "color": "#6dd1ac", "label": None} + else: + if i == 4: + kw = {"alpha": 0.2, "color": "#889fd9", "label": "Uninformative"} + else: + kw = {"alpha": 0.2, "color": "#889fd9", "label": None} + axes[1].plot(gs[:, i], **kw) +axes[1].set_title("Top-k feature selection", fontsize=8) +axes[1].legend(fontsize=6, title="Feature scores", title_fontsize=6) + 
+axes[2].plot(ts, color=palette[0])
+for xi in x_filt:
+    axes[2].axhline(xi, ls="--", color=palette[-1], alpha=0.5)
+axes[2].set_title("Threshold filtering", fontsize=8)
+
+params_hist = torch.tensor(params_hist)
+axes[3].plot(params_hist[:, 1], color=palette[0], label="higher bound")
+axes[3].plot(params_hist[:, 0], color=palette[2], label="lower bound")
+axes[3].axhline(0.3, ls="--", color=palette[2], alpha=0.5)
+axes[3].axhline(0.6, ls="--", color=palette[0], alpha=0.5)
+axes[3].set_title("Rule classifier", fontsize=8)
+axes[3].legend(fontsize=6)
+
+fig.tight_layout()
+fig.savefig("docs/examples/quick_example_optimization.svg", bbox_inches="tight", transparent=True)
diff --git a/docs/examples/quick_example_optimization.svg b/docs/examples/quick_example_optimization.svg
new file mode 100644
index 0000000..7c0f8ea
--- /dev/null
+++ b/docs/examples/quick_example_optimization.svg
@@ -0,0 +1,2096 @@
[SVG markup omitted: 2096 added lines of Matplotlib-generated figure source. Surviving metadata: date 2026-04-07T13:18:58.074655, format image/svg+xml, creator Matplotlib v3.10.8, https://matplotlib.org/]
diff --git a/docs/index.md b/docs/index.md
index d5eb303..8bef4dc 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -3,9 +3,12 @@

-# SoftTorch +# Soft differentiable programming in PyTorch -## In a nutshell +Looking for JAX? See [SoftJAX](https://github.com/a-paulus/softjax). + + +## What is SoftTorch? SoftTorch provides soft differentiable drop-in replacements for traditionally non-differentiable functions in [PyTorch](https://pytorch.org), including @@ -19,8 +22,7 @@ All operators offer multiple modes (controlling smoothness or boundedness of the All operators also support straight-through estimation, using the non-differentiable function in the forward pass and the soft relaxation in the backward pass. -SoftTorch functions are drop-in replacements for their non-differentiable PyTorch counterparts. -Special care is needed for functions operating on indices, as we relax discrete indices into distributions over indices, which modifies the shape of returned/accepted values. +*Note, while SoftTorch is designed to provide direct drop-in replacements for PyTorch's operators, soft axis-wise operators return a probability distribution over indices (instead of an index), effectively changing the shape of the function's output.* ## Installation @@ -30,353 +32,163 @@ pip install softtorch ``` -## Quick example -```python -import torch -import softtorch as st - -x = torch.tensor([-0.2, -1.0, 0.3, 1.0]) - -# Elementwise functions -print("\nTorch absolute:", torch.abs(x)) -print("SoftTorch absolute (hard mode):", st.abs(x, mode="hard")) -print("SoftTorch absolute (soft mode):", st.abs(x)) - -print("\nTorch clamp:", torch.clamp(x, -0.5, 0.5)) -print("SoftTorch clamp (hard mode):", st.clamp(x, -0.5, 0.5, mode="hard")) -print("SoftTorch clamp (soft mode):", st.clamp(x, -0.5, 0.5)) - -print("\nTorch heaviside:", torch.heaviside(x, torch.tensor(0.5))) -print("SoftTorch heaviside (hard mode):", st.heaviside(x, mode="hard")) -print("SoftTorch heaviside (soft mode):", st.heaviside(x)) - -print("\nTorch ReLU:", torch.nn.functional.relu(x)) -print("SoftTorch ReLU (hard mode):", st.relu(x, mode="hard")) 
-print("SoftTorch ReLU (soft mode):", st.relu(x)) - -print("\nTorch round:", torch.round(x)) -print("SoftTorch round (hard mode):", st.round(x, mode="hard")) -print("SoftTorch round (soft mode):", st.round(x)) - -print("\nTorch sign:", torch.sign(x)) -print("SoftTorch sign (hard mode):", st.sign(x, mode="hard")) -print("SoftTorch sign (soft mode):", st.sign(x)) -``` -``` -Torch absolute: tensor([0.2000, 1.0000, 0.3000, 1.0000]) -SoftTorch absolute (hard mode): tensor([0.2000, 1.0000, 0.3000, 1.0000]) -SoftTorch absolute (soft mode): tensor([0.1523, 0.9999, 0.2715, 0.9999]) - -Torch clamp: tensor([-0.2000, -0.5000, 0.3000, 0.5000]) -SoftTorch clamp (hard mode): tensor([-0.2000, -0.5000, 0.3000, 0.5000]) -SoftTorch clamp (soft mode): tensor([-0.1952, -0.4993, 0.2873, 0.4993]) - -Torch heaviside: tensor([0., 0., 1., 1.]) -SoftTorch heaviside (hard mode): tensor([0., 0., 1., 1.]) -SoftTorch heaviside (soft mode): tensor([0.1192, 0.0000, 0.9526, 1.0000]) - -Torch ReLU: tensor([0.0000, 0.0000, 0.3000, 1.0000]) -SoftTorch ReLU (hard mode): tensor([0.0000, 0.0000, 0.3000, 1.0000]) -SoftTorch ReLU (soft mode): tensor([0.0127, 0.0000, 0.3049, 1.0000]) - -Torch round: tensor([-0., -1., 0., 1.]) -SoftTorch round (hard mode): tensor([-0., -1., 0., 1.]) -SoftTorch round (soft mode): tensor([-0.0465, -1.0000, 0.1189, 1.0000]) - -Torch sign: tensor([-1., -1., 1., 1.]) -SoftTorch sign (hard mode): tensor([-1., -1., 1., 1.]) -SoftTorch sign (soft mode): tensor([-0.7616, -0.9999, 0.9051, 0.9999]) -``` +## Quick examples +**Robust median regression:** +Minimize the median absolute residual to be robust to outliers. 
```python -# Tensor-valued operators -print("\nTorch max:", torch.max(x)) -print("SoftTorch max (hard mode):", st.max(x, mode="hard")) -print("SoftTorch max (soft mode):", st.max(x)) - -print("\nTorch min:", torch.min(x)) -print("SoftTorch min (hard mode):", st.min(x, mode="hard")) -print("SoftTorch min (soft mode):", st.min(x)) - -print("\nTorch sort:", torch.sort(x).values) -print("SoftTorch sort (hard mode):", st.sort(x, mode="hard").values) -print("SoftTorch sort (soft mode):", st.sort(x).values) - -print("\nTorch quantile:", torch.quantile(x, q=0.2)) -print("SoftTorch quantile (hard mode):", st.quantile(x, q=0.2, mode="hard")) -print("SoftTorch quantile (soft mode):", st.quantile(x, q=0.2)) - -print("\nTorch median:", torch.median(x)) -print("SoftTorch median (hard mode):", st.median(x, mode="hard")) -print("SoftTorch median (soft mode):", st.median(x)) - -print("\nTorch topk:", torch.topk(x, k=3).values) -print("SoftTorch topk (hard mode):", st.topk(x, k=3, mode="hard").values) -print("SoftTorch topk (soft mode):", st.topk(x, k=3).values) - -print("\nTorch rank:", torch.argsort(torch.argsort(x))) -print("SoftTorch rank (hard mode):", st.rank(x, mode="hard", descending=False)) -print("SoftTorch rank (soft mode):", st.rank(x, descending=False)) -``` -``` -Torch max: tensor(1.) -SoftTorch max (hard mode): tensor(1.) -SoftTorch max (soft mode): tensor(0.8874) +import torch, softtorch as st -Torch min: tensor(-1.) -SoftTorch min (hard mode): tensor(-1.) 
-SoftTorch min (soft mode): tensor(-0.8996) +torch.manual_seed(0) +X = torch.randn(20, 3) +w_true = torch.tensor([1.0, -2.0, 0.5]) +y = X @ w_true +y[0] = 1e6 # inject outlier -Torch sort: tensor([-1.0000, -0.2000, 0.3000, 1.0000]) -SoftTorch sort (hard mode): tensor([-1.0000, -0.2000, 0.3000, 1.0000]) -SoftTorch sort (soft mode): tensor([-0.8792, -0.1641, 0.2767, 0.8738]) +def median_regression_loss(w, X, y, mode="smooth"): + residuals = y - X @ w + return st.median(st.abs(residuals, mode=mode), mode=mode) -Torch quantile: tensor(-0.5200) -SoftTorch quantile (hard mode): tensor(-0.5200) -SoftTorch quantile (soft mode): tensor(-0.4501) +w = torch.zeros(3, requires_grad=True) +hard_loss = median_regression_loss(w, X, y, mode="hard") +print("Hard grad:", torch.autograd.grad(hard_loss, w)[0]) +soft_loss = median_regression_loss(w, X, y, mode="smooth") +print("Soft grad:", torch.autograd.grad(soft_loss, w)[0]) -Torch median: tensor(-0.2000) -SoftTorch median (hard mode): tensor(-0.2000) -SoftTorch median (soft mode): tensor(-0.1641) - -Torch topk: tensor([ 1.0000, 0.3000, -0.2000]) -SoftTorch topk (hard mode): tensor([ 1.0000, 0.3000, -0.2000]) -SoftTorch topk (soft mode): tensor([ 0.8738, 0.2767, -0.1641]) - -Torch rank: tensor([1, 0, 2, 3]) -SoftTorch rank (hard mode): tensor([2., 1., 3., 4.]) -SoftTorch rank (soft mode): tensor([1.9950, 1.0548, 3.0239, 3.9228]) +w = torch.zeros(3) +for _ in range(50): + w.requires_grad_(True) + loss = median_regression_loss(w, X, y) + g = torch.autograd.grad(loss, w)[0] + w = (w - 0.1 * g).detach() +print("Learned w:", w, " (true:", w_true, ")") ``` - -```python -# Sort: sweep over methods -print("\nTorch sort:", torch.sort(x).values) -print("SoftTorch sort (softsort):", st.sort(x, method="softsort", softness=0.1).values) -print("SoftTorch sort (neuralsort):", st.sort(x, method="neuralsort", softness=0.1).values) -print("SoftTorch sort (fast_soft_sort):", st.sort(x, method="fast_soft_sort", softness=2.0).values) -print("SoftTorch 
sort (ot):", st.sort(x, method="ot", softness=0.1).values) -print("SoftTorch sort (sorting_network):", st.sort(x, method="sorting_network", softness=0.1).values) - -# Sort: sweep over modes -print("\nTorch sort:", torch.sort(x).values) -for mode in ["hard", "smooth", "c0", "c1", "c2"]: - print(f"SoftTorch sort ({mode}):", st.sort(x, softness=0.5, mode=mode).values) ``` -``` -Torch sort: tensor([-1.0000, -0.2000, 0.3000, 1.0000]) -SoftTorch sort (softsort): tensor([-0.8996, -0.1705, 0.2847, 0.8874]) -SoftTorch sort (neuralsort): tensor([-0.8792, -0.1641, 0.2767, 0.8738]) -SoftTorch sort (fast_soft_sort): tensor([-0.7462, -0.1971, 0.2938, 0.8569]) -SoftTorch sort (ot): tensor([-0.7324, -0.2396, 0.3286, 0.7434]) -SoftTorch sort (sorting_network): tensor([-0.7999, -0.2672, 0.3847, 0.7863]) - -Torch sort: tensor([-1.0000, -0.2000, 0.3000, 1.0000]) -SoftTorch sort (hard): tensor([-1.0000, -0.2000, 0.3000, 1.0000]) -SoftTorch sort (smooth): tensor([-0.6057, -0.1997, 0.2729, 0.6281]) -SoftTorch sort (c0): tensor([-1.0000, -0.6313, 0.6525, 0.9824]) -SoftTorch sort (c1): tensor([-0.9982, -0.5432, 0.5814, 0.9837]) -SoftTorch sort (c2): tensor([-0.9978, -0.4905, 0.5425, 0.9903]) +Hard grad: tensor([ 0.2103, 0.1772, -0.8305]) +Soft grad: tensor([ 0.0731, 0.7100, -0.2970]) +Learned w: tensor([ 1.0000, -2.0000, 0.5000]) (true: tensor([ 1.0000, -2.0000, 0.5000]) ) ``` +**Top-k feature selection:** +Discover which features of a trained model are important. 
```python -# Operators returning indices -print("\nTorch argmax:", torch.argmax(x)) -print("SoftTorch argmax (hard mode):", st.argmax(x, mode="hard")) -print("SoftTorch argmax (soft mode):", st.argmax(x)) - -print("\nTorch argmin:", torch.argmin(x)) -print("SoftTorch argmin (hard mode):", st.argmin(x, mode="hard")) -print("SoftTorch argmin (soft mode):", st.argmin(x)) - -print("\nTorch argquantile:", "Not implemented in standard PyTorch") -print("SoftTorch argquantile (hard mode):", st.argquantile(x, q=0.2, mode="hard")) -print("SoftTorch argquantile (soft mode):", st.argquantile(x, q=0.2)) - -print("\nTorch argmedian:", torch.median(x, dim=0).indices) -print("SoftTorch argmedian (hard mode):", st.median(x, mode="hard", dim=0).indices) -print("SoftTorch argmedian (soft mode):", st.median(x, dim=0).indices) - -print("\nTorch argsort:", torch.argsort(x)) -print("SoftTorch argsort (hard mode):", st.argsort(x, mode="hard")) -print("SoftTorch argsort (soft mode):", st.argsort(x)) - -print("\nTorch argtopk:", torch.topk(x, k=3).indices) -print("SoftTorch argtopk (hard mode):", st.topk(x, k=3, mode="hard").indices) -print("SoftTorch argtopk (soft mode):", st.topk(x, k=3).indices) -``` -``` -Torch argmax: tensor(3) -SoftTorch argmax (hard mode): tensor([0., 0., 0., 1.]) -SoftTorch argmax (soft mode): tensor([0.0215, 0.0022, 0.1176, 0.8586]) - -Torch argmin: tensor(1) -SoftTorch argmin (hard mode): tensor([0., 1., 0., 0.]) -SoftTorch argmin (soft mode): tensor([0.0922, 0.8885, 0.0169, 0.0023]) - -Torch argquantile: Not implemented in standard PyTorch -SoftTorch argquantile (hard mode): tensor([0.6000, 0.4000, 0.0000, 0.0000]) -SoftTorch argquantile (soft mode): tensor([0.5403, 0.3693, 0.0902, 0.0001]) - -Torch argmedian: tensor(0) -SoftTorch argmedian (hard mode): tensor([1., 0., 0., 0.]) -SoftTorch argmedian (soft mode): tensor([0.8009, 0.0491, 0.1498, 0.0002]) - -Torch argsort: tensor([1, 0, 2, 3]) -SoftTorch argsort (hard mode): tensor([[0., 1., 0., 0.], - [1., 0., 0., 
0.], - [0., 0., 1., 0.], - [0., 0., 0., 1.]]) -SoftTorch argsort (soft mode): tensor([[0.1494, 0.8496, 0.0009, 0.0000], - [0.8009, 0.0491, 0.1498, 0.0002], - [0.1418, 0.0001, 0.7899, 0.0681], - [0.0011, 0.0000, 0.1784, 0.8205]]) - -Torch argtopk: tensor([3, 2, 0]) -SoftTorch argtopk (hard mode): tensor([[0., 0., 0., 1.], - [0., 0., 1., 0.], - [1., 0., 0., 0.]]) -SoftTorch argtopk (soft mode): tensor([[0.0011, 0.0000, 0.1784, 0.8205], - [0.1418, 0.0001, 0.7899, 0.0681], - [0.8009, 0.0491, 0.1498, 0.0002]]) -``` - +n_features, k = 10, 3 +torch.manual_seed(42) +X = torch.randn(100, n_features) +w_model = torch.tensor([0, 2.0, 0, -1.5, 0, 0, 0, 5.0, 0, 0]) +y = X @ w_model + 0.1 * torch.randn(100) + +def feature_selection_loss(g, X, y, w_model, mode="smooth"): + _, soft_idx = st.topk(g, k=k, mode=mode, gated_grad=False) + mask = soft_idx.sum(dim=0) + y_pred = (X * mask) @ w_model + return torch.mean(st.abs(y_pred - y)) + +g = torch.zeros(n_features, requires_grad=True) +hard_loss = feature_selection_loss(g, X, y, w_model, mode="hard") +print("Hard grad:", torch.autograd.grad(hard_loss, g)[0] if hard_loss.requires_grad else torch.zeros_like(g)) +soft_loss = feature_selection_loss(g, X, y, w_model, mode="smooth") +print("Soft grad:", torch.autograd.grad(soft_loss, g)[0]) + +g = torch.zeros(n_features) +for _ in range(5): + g.requires_grad_(True) + loss = feature_selection_loss(g, X, y, w_model) + g_grad = torch.autograd.grad(loss, g)[0] + g = (g - 0.001 * g_grad).detach() +print("Selected features:", torch.topk(g, k=k).indices) +``` +``` +Hard grad: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]) +Soft grad: tensor([ 2359.3386, 62.9980, 2359.3386, -890.2852, 2359.3386, + 2359.3386, 2359.3386, -15688.0829, 2359.3386, 2359.3386]) +Selected features: tensor([7, 3, 1]) +``` + +**Differentiable threshold filtering:** +Learn a threshold that gates inputs. 
```python -y = torch.tensor([0.2, -0.5, 0.5, -1.0]) - -# Comparison operators -print("\nTorch greater:", torch.greater(x, y)) -print("SoftTorch greater (hard mode):", st.greater(x, y, mode="hard")) -print("SoftTorch greater (soft mode):", st.greater(x, y)) - -print("\nTorch greater equal:", torch.greater_equal(x, y)) -print("SoftTorch greater equal (hard mode):", st.greater_equal(x, y, mode="hard")) -print("SoftTorch greater equal (soft mode):", st.greater_equal(x, y)) +x = torch.tensor([0.2, 0.8, 0.5, 1.2, 0.1]) +target_sum = 2.0 # sum of values above threshold = 2.0 (i.e. 0.8 + 1.2) -print("\nTorch less:", torch.less(x, y)) -print("SoftTorch less (hard mode):", st.less(x, y, mode="hard")) -print("SoftTorch less (soft mode):", st.less(x, y)) +def filter_loss(t, x, target, mode="smooth"): + mask = st.greater(x, t, mode=mode) + return (torch.sum(mask * x) - target) ** 2 -print("\nTorch less equal:", torch.less_equal(x, y)) -print("SoftTorch less equal (hard mode):", st.less_equal(x, y, mode="hard")) -print("SoftTorch less equal (soft mode):", st.less_equal(x, y)) +t = torch.tensor(0.0, requires_grad=True) +hard_loss = filter_loss(t, x, target_sum, mode="hard") +print("Hard grad:", torch.autograd.grad(hard_loss, t)[0] if hard_loss.requires_grad else torch.zeros_like(t)) +soft_loss = filter_loss(t, x, target_sum, mode="smooth") +print("Soft grad:", torch.autograd.grad(soft_loss, t)[0]) -print("\nTorch eq:", torch.eq(x, y)) -print("SoftTorch eq (hard mode):", st.eq(x, y, mode="hard")) -print("SoftTorch eq (soft mode):", st.eq(x, y)) - -print("\nTorch not equal:", torch.not_equal(x, y)) -print("SoftTorch not equal (hard mode):", st.not_equal(x, y, mode="hard")) -print("SoftTorch not equal (soft mode):", st.not_equal(x, y)) - -print("\nTorch isclose:", torch.isclose(x, y)) -print("SoftTorch isclose (hard mode):", st.isclose(x, y, mode="hard")) -print("SoftTorch isclose (soft mode):", st.isclose(x, y)) +t = torch.tensor(0.0) +for _ in range(20): + t.requires_grad_(True) + 
loss = filter_loss(t, x, target_sum) + t_grad = torch.autograd.grad(loss, t)[0] + t = (t - 0.1 * t_grad).detach() +print("Learned threshold:", t) ``` ``` -Torch greater: tensor([False, False, False, True]) -SoftTorch greater (hard mode): tensor([0., 0., 0., 1.]) -SoftTorch greater (soft mode): tensor([0.0180, 0.0067, 0.1192, 1.0000]) - -Torch greater equal: tensor([False, False, False, True]) -SoftTorch greater equal (hard mode): tensor([0., 0., 0., 1.]) -SoftTorch greater equal (soft mode): tensor([0.0180, 0.0067, 0.1192, 1.0000]) - -Torch less: tensor([ True, True, True, False]) -SoftTorch less (hard mode): tensor([1., 1., 1., 0.]) -SoftTorch less (soft mode): tensor([0.9820, 0.9933, 0.8808, 0.0000]) - -Torch less equal: tensor([ True, True, True, False]) -SoftTorch less equal (hard mode): tensor([1., 1., 1., 0.]) -SoftTorch less equal (soft mode): tensor([0.9820, 0.9933, 0.8808, 0.0000]) - -Torch eq: tensor([False, False, False, False]) -SoftTorch eq (hard mode): tensor([0., 0., 0., 0.]) -SoftTorch eq (soft mode): tensor([0.0414, 0.0143, 0.3580, 0.0000]) - -Torch not equal: tensor([True, True, True, True]) -SoftTorch not equal (hard mode): tensor([1., 1., 1., 1.]) -SoftTorch not equal (soft mode): tensor([0.9586, 0.9857, 0.6420, 1.0000]) - -Torch isclose: tensor([False, False, False, False]) -SoftTorch isclose (hard mode): tensor([0., 0., 0., 0.]) -SoftTorch isclose (soft mode): tensor([0.0414, 0.0143, 0.3580, 0.0000]) +Hard grad: tensor(0.) +Soft grad: tensor(-0.6600) +Learned threshold: tensor(0.6211) ``` +**Rule-based classifier:** +Learn decision boundaries `[lo, hi]` for a rule using soft logic and straight-through estimation. The rule is true if any element of a feature is inside `[lo, hi]`. 
```python -# Logical operators -fuzzy_a = torch.tensor([0.1, 0.2, 0.8, 1.0]) -fuzzy_b = torch.tensor([0.7, 0.3, 0.1, 0.9]) -bool_a = fuzzy_a >= 0.5 -bool_b = fuzzy_b >= 0.5 - -print("\nTorch AND:", torch.logical_and(bool_a, bool_b)) -print("SoftTorch AND:", st.logical_and(fuzzy_a, fuzzy_b)) - -print("\nTorch OR:", torch.logical_or(bool_a, bool_b)) -print("SoftTorch OR:", st.logical_or(fuzzy_a, fuzzy_b)) - -print("\nTorch NOT:", torch.logical_not(bool_a)) -print("SoftTorch NOT:", st.logical_not(fuzzy_a)) - -print("\nTorch XOR:", torch.logical_xor(bool_a, bool_b)) -print("SoftTorch XOR:", st.logical_xor(fuzzy_a, fuzzy_b)) - -print("\nTorch ALL:", torch.all(bool_a)) -print("SoftTorch ALL:", st.all(fuzzy_a)) - -print("\nTorch ANY:", torch.any(bool_a)) -print("SoftTorch ANY:", st.any(fuzzy_a)) - -# Selection operators -print("\nTorch Where:", torch.where(bool_a, x, y)) -print("SoftTorch Where:", st.where(fuzzy_a, x, y)) -``` -``` -Torch AND: tensor([False, False, False, True]) -SoftTorch AND: tensor([0.0700, 0.0600, 0.0800, 0.9000]) - -Torch OR: tensor([ True, False, True, True]) -SoftTorch OR: tensor([0.7300, 0.4400, 0.8200, 1.0000]) - -Torch NOT: tensor([ True, True, False, False]) -SoftTorch NOT: tensor([0.9000, 0.8000, 0.2000, 0.0000]) +x = torch.tensor([[0.2, 0.8], [0.5, 0.3], [0.9, 0.1], [0.4, 0.7], [0.1, 0.4], [0.2, 0.7], [0.4, 0.1], [0.4, 0.7], + [0.7, 0.29], [0.3, 0.3], [0.61, 0.25], [0.4, 0.6], [0.0, 0.1], [0.5, 0.3], [0.4, 0.9], [0.1, 0.57]]) +labels = torch.tensor([0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, + 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0]) + +@st.st +def rule_loss(params, x, labels, mode="smooth"): + lo, hi = params[0], params[1] + above = st.greater(x, lo, mode=mode) + below = st.less(x, hi, mode=mode) + in_range = st.logical_and(above, below) + preds = st.any(in_range, dim=-1) + return ((preds - labels) ** 2).sum() + +params = torch.tensor([0.0, 1.0], requires_grad=True) +hard_loss = rule_loss(params, x, labels, mode="hard") +print("Hard grad:", 
torch.autograd.grad(hard_loss, params)[0] if hard_loss.requires_grad else torch.zeros_like(params)) +soft_loss = rule_loss(params, x, labels, mode="smooth") +print("Soft grad:", torch.autograd.grad(soft_loss, params)[0]) + +params = torch.tensor([0.0, 1.0]) +for _ in range(20): + params.requires_grad_(True) + loss = rule_loss(params, x, labels) + p_grad = torch.autograd.grad(loss, params)[0] + params = (params - 0.01 * p_grad).detach() +print("Learned [lo, hi]:", params) +``` +``` +Hard grad: tensor([0., 0.]) +Soft grad: tensor([-4.2777, 1.4152]) +Learned [lo, hi]: tensor([0.2925, 0.5999]) +``` + +Optimization trajectories -Torch XOR: tensor([ True, False, True, False]) -SoftTorch XOR: tensor([0.6411, 0.3464, 0.7256, 0.1000]) - -Torch ALL: tensor(False) -SoftTorch ALL: tensor(0.0160) +## Citation -Torch ANY: tensor(True) -SoftTorch ANY: tensor(1.) +If this library helped your academic work, please consider citing: ([arXiv link](https://arxiv.org/abs/2603.08824)) -Torch Where: tensor([ 0.2000, -0.5000, 0.3000, 1.0000]) -SoftTorch Where: tensor([ 0.1600, -0.6000, 0.3400, 1.0000]) +```bibtex +@article{paulus2026softjax, + title={{SoftJAX} \& {SoftTorch}: Empowering Automatic Differentiation Libraries with Informative Gradients}, + author={Paulus, Anselm and Geist, A.\ Ren\'e and Musil, V\'it and Hoffmann, Sebastian and Beker, Onur and Martius, Georg}, + journal={arXiv preprint}, + year={2026}, + eprint={2603.08824} +} ``` -```python -# Straight-through operators: Use hard function on forward and soft on backward -print("Straight-through ReLU:", st.relu_st(x)) -print("Straight-through sort:", st.sort_st(x).values) -print("Straight-through argtopk:", st.topk_st(x, k=3).indices) -print("Straight-through greater:", st.greater_st(x, y)) -# And many more... 
-``` -``` -Straight-through ReLU: tensor([0.0000, 0.0000, 0.3000, 1.0000]) -Straight-through sort: tensor([-1.0000, -0.2000, 0.3000, 1.0000]) -Straight-through argtopk: tensor([[0., 0., 0., 1.], - [0., 0., 1., 0.], - [1., 0., 0., 0.]]) -Straight-through greater: tensor([0., 0., 0., 1.]) -``` - -The outputs were generated with `docs/quick_example.py`. - - -## Citation - ---8<-- ".citation.md" +(Also consider starring the project [on GitHub](https://github.com/a-paulus/softtorch)) Special thanks and credit go to [Patrick Kidger](https://kidger.site) for the awesome [JAX repositories](https://github.com/patrick-kidger) that served as the basis for the documentation of this project. @@ -388,7 +200,7 @@ Have a look at the [All of SoftTorch](./all-of-softtorch.ipynb) page. ## Feedback -This project is still relatively young, if you have any suggestions for improvement or other feedback, please [reach out](mailto:paulus.anselm@gmail.com) or raise a GitHub issue! +If you have any suggestions for improvement or other feedback, please [reach out](mailto:paulus.anselm@gmail.com) or raise a GitHub issue! 
## See also diff --git a/mkdocs.yml b/mkdocs.yml index cd8dca4..2de3530 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -77,7 +77,6 @@ plugins: - "_overrides" - "_static/README.md" - "examples/.ipynb_checkpoints" - - "examples/" - mkdocs-jupyter: include_requirejs: false custom_mathjax_url: "https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/latest.js?config=TeX-AMS_CHTML-full,Safe" @@ -130,7 +129,7 @@ nav: - All of Softtorch: 'all-of-softtorch.ipynb' - Examples: - 'plots.ipynb' - - 'manifold_points.ipynb' + - 'examples/manifold_points.ipynb' - API: - 'api/softtorch_operators.md' - 'api/straight_through.md' diff --git a/src/softtorch/functions.py b/src/softtorch/functions.py index 110a1ec..71981a1 100644 --- a/src/softtorch/functions.py +++ b/src/softtorch/functions.py @@ -1440,8 +1440,10 @@ def topk( ot_kwargs=ot_kwargs, ) # (..., k, ..., [n]) if not gated_grad: - soft_index = soft_index.detach() - values = take_along_dim(x, soft_index, dim=dim) # (..., k, ...) + soft_index_tmp = soft_index.detach() + else: + soft_index_tmp = soft_index + values = take_along_dim(x, soft_index_tmp, dim=dim) # (..., k, ...) return torch.return_types.topk((values, soft_index))
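
Note on the `functions.py` hunk: with `gated_grad=False`, the old code overwrote `soft_index` with its detached copy, so the returned index distribution carried no gradient. The patched code detaches only a temporary copy used to compute `values`. The sketch below illustrates that pattern in plain PyTorch. It is a minimal illustration under stated assumptions, not SoftTorch's implementation: `toy_soft_topk`, its softmax relaxation, and the `softness` default are hypothetical stand-ins chosen only to make the gradient flow visible.

```python
import torch


def toy_soft_topk(x, k, softness=0.1, gated_grad=True):
    """Hypothetical stand-in for a soft top-k (NOT SoftTorch's code)."""
    # Assumed relaxation: one softmax row per hard top-k value, peaked
    # at the entries of x closest to that value.
    hard_values = torch.topk(x, k).values.detach()             # (k,)
    logits = -(hard_values[:, None] - x[None, :]).abs() / softness
    soft_index = torch.softmax(logits, dim=-1)                  # (k, n)

    # Pattern from the patch: detach only the copy used for `values`,
    # and return the original, still-differentiable `soft_index`.
    soft_index_tmp = soft_index if gated_grad else soft_index.detach()
    values = (soft_index_tmp * x[None, :]).sum(dim=-1)          # (k,)
    return values, soft_index


x = torch.tensor([0.3, -1.0, 0.8, 0.1], requires_grad=True)
values, soft_index = toy_soft_topk(x, k=2, gated_grad=False)

# A mask built from the returned distribution, as in the README's
# feature-selection example, still back-propagates to x.
mask = soft_index.sum(dim=0)
weights = torch.tensor([1.0, 2.0, 3.0, 4.0])  # arbitrary illustrative weights
loss = (mask * weights).sum()
(grad,) = torch.autograd.grad(loss, x)
print("soft_index.requires_grad:", soft_index.requires_grad)
print("grad w.r.t. x:", grad)
```

If `soft_index` itself were detached before being returned, as in the pre-patch code, `loss` in this sketch would not depend on `x` through autograd and the `torch.autograd.grad` call would raise an error, which is consistent with the README's top-k example relying on `gated_grad=False` together with a differentiable `soft_idx`.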