diff --git a/README.md b/README.md index b4d9121..f9d9e1d 100644 --- a/README.md +++ b/README.md @@ -6,15 +6,16 @@
-# SoftTorch +# Soft differentiable programming in PyTorch [](https://pypi.org/project/softtorch/) [](https://pypi.org/project/softtorch/) [](https://github.com/a-paulus/softtorch/blob/main/LICENSE) +[](https://arxiv.org/abs/2603.08824) Looking for JAX? See [SoftJAX](https://github.com/a-paulus/softjax). -## In a nutshell +## What is SoftTorch? SoftTorch provides soft differentiable drop-in replacements for traditionally non-differentiable functions in [PyTorch](https://pytorch.org), including @@ -28,8 +29,7 @@ All operators offer multiple modes (controlling smoothness or boundedness of the All operators also support straight-through estimation, using the non-differentiable function in the forward pass and the soft relaxation in the backward pass. -SoftTorch functions are drop-in replacements for their non-differentiable PyTorch counterparts. -Special care is needed for functions operating on indices, as we relax discrete indices into distributions over indices, which modifies the shape of returned/accepted values. +*Note, while SoftTorch is designed to provide direct drop-in replacements for PyTorch's operators, soft axis-wise operators return a probability distribution over indices (instead of an index), effectively changing the shape of the function's output.* ## Installation @@ -44,351 +44,152 @@ pip install softtorch Available at https://a-paulus.github.io/softtorch/. -## Quick example -```python -import torch -import softtorch as st - -x = torch.tensor([-0.2, -1.0, 0.3, 1.0]) - -# Elementwise functions -print("\nTorch absolute:", torch.abs(x)) -print("SoftTorch absolute (hard mode):", st.abs(x, mode="hard")) -print("SoftTorch absolute (soft mode):", st.abs(x)) - -print("\nTorch clamp:", torch.clamp(x, -0.5, 0.5)) -print("SoftTorch clamp (hard mode):", st.clamp(x, -0.5, 0.5, mode="hard")) -print("SoftTorch clamp (soft mode):", st.clamp(x, -0.5, 0.5)) - -print("\nTorch heaviside:", torch.heaviside(x, torch.tensor(0.5))) -print("SoftTorch heaviside (hard mode):", st.heaviside(x, mode="hard")) -print("SoftTorch heaviside (soft mode):", st.heaviside(x)) - -print("\nTorch ReLU:", torch.nn.functional.relu(x)) -print("SoftTorch ReLU (hard mode):", st.relu(x, mode="hard")) -print("SoftTorch ReLU (soft mode):", st.relu(x)) - -print("\nTorch round:", torch.round(x)) -print("SoftTorch round (hard mode):", st.round(x, mode="hard")) -print("SoftTorch round (soft mode):", st.round(x)) - -print("\nTorch sign:", torch.sign(x)) -print("SoftTorch sign (hard mode):", st.sign(x, mode="hard")) -print("SoftTorch sign (soft mode):", st.sign(x)) -``` -``` -Torch absolute: tensor([0.2000, 1.0000, 0.3000, 1.0000]) -SoftTorch absolute (hard mode): tensor([0.2000, 1.0000, 0.3000, 1.0000]) -SoftTorch absolute (soft mode): tensor([0.1523, 0.9999, 0.2715, 0.9999]) - -Torch clamp: tensor([-0.2000, -0.5000, 0.3000, 0.5000]) -SoftTorch clamp (hard mode): tensor([-0.2000, -0.5000, 0.3000, 0.5000]) -SoftTorch clamp (soft mode): tensor([-0.1952, -0.4993, 0.2873, 0.4993]) - -Torch heaviside: tensor([0., 0., 1., 1.]) -SoftTorch heaviside (hard mode): tensor([0., 0., 1., 1.]) -SoftTorch heaviside (soft mode): tensor([0.1192, 0.0000, 0.9526, 1.0000]) - -Torch ReLU: tensor([0.0000, 0.0000, 0.3000, 1.0000]) -SoftTorch ReLU (hard mode): tensor([0.0000, 0.0000, 0.3000, 1.0000]) -SoftTorch ReLU (soft mode): tensor([0.0127, 0.0000, 0.3049, 1.0000]) - -Torch round: tensor([-0., -1., 0., 1.]) -SoftTorch round (hard mode): tensor([-0., -1., 0., 1.]) -SoftTorch round (soft mode): tensor([-0.0465, -1.0000, 0.1189, 1.0000]) - -Torch sign: tensor([-1., -1., 1., 1.]) -SoftTorch sign (hard mode): tensor([-1., -1., 1., 1.]) -SoftTorch sign (soft mode): tensor([-0.7616, -0.9999, 0.9051, 0.9999]) -``` +## Quick examples +**Robust median regression:** +Minimize the median absolute residual to be robust to outliers. ```python -# Tensor-valued operators -print("\nTorch max:", torch.max(x)) -print("SoftTorch max (hard mode):", st.max(x, mode="hard")) -print("SoftTorch max (soft mode):", st.max(x)) - -print("\nTorch min:", torch.min(x)) -print("SoftTorch min (hard mode):", st.min(x, mode="hard")) -print("SoftTorch min (soft mode):", st.min(x)) - -print("\nTorch sort:", torch.sort(x).values) -print("SoftTorch sort (hard mode):", st.sort(x, mode="hard").values) -print("SoftTorch sort (soft mode):", st.sort(x).values) - -print("\nTorch quantile:", torch.quantile(x, q=0.2)) -print("SoftTorch quantile (hard mode):", st.quantile(x, q=0.2, mode="hard")) -print("SoftTorch quantile (soft mode):", st.quantile(x, q=0.2)) - -print("\nTorch median:", torch.median(x)) -print("SoftTorch median (hard mode):", st.median(x, mode="hard")) -print("SoftTorch median (soft mode):", st.median(x)) - -print("\nTorch topk:", torch.topk(x, k=3).values) -print("SoftTorch topk (hard mode):", st.topk(x, k=3, mode="hard").values) -print("SoftTorch topk (soft mode):", st.topk(x, k=3).values) - -print("\nTorch rank:", torch.argsort(torch.argsort(x))) -print("SoftTorch rank (hard mode):", st.rank(x, mode="hard", descending=False)) -print("SoftTorch rank (soft mode):", st.rank(x, descending=False)) -``` -``` -Torch max: tensor(1.) -SoftTorch max (hard mode): tensor(1.) -SoftTorch max (soft mode): tensor(0.8874) - -Torch min: tensor(-1.) -SoftTorch min (hard mode): tensor(-1.) -SoftTorch min (soft mode): tensor(-0.8996) - -Torch sort: tensor([-1.0000, -0.2000, 0.3000, 1.0000]) -SoftTorch sort (hard mode): tensor([-1.0000, -0.2000, 0.3000, 1.0000]) -SoftTorch sort (soft mode): tensor([-0.8792, -0.1641, 0.2767, 0.8738]) +import torch, softtorch as st -Torch quantile: tensor(-0.5200) -SoftTorch quantile (hard mode): tensor(-0.5200) -SoftTorch quantile (soft mode): tensor(-0.4501) +torch.manual_seed(0) +X = torch.randn(20, 3) +w_true = torch.tensor([1.0, -2.0, 0.5]) +y = X @ w_true +y[0] = 1e6 # inject outlier -Torch median: tensor(-0.2000) -SoftTorch median (hard mode): tensor(-0.2000) -SoftTorch median (soft mode): tensor(-0.1641) +def median_regression_loss(w, X, y, mode="smooth"): + residuals = y - X @ w + return st.median(st.abs(residuals, mode=mode), mode=mode) -Torch topk: tensor([ 1.0000, 0.3000, -0.2000]) -SoftTorch topk (hard mode): tensor([ 1.0000, 0.3000, -0.2000]) -SoftTorch topk (soft mode): tensor([ 0.8738, 0.2767, -0.1641]) +w = torch.zeros(3, requires_grad=True) +hard_loss = median_regression_loss(w, X, y, mode="hard") +print("Hard grad:", torch.autograd.grad(hard_loss, w)[0]) +soft_loss = median_regression_loss(w, X, y, mode="smooth") +print("Soft grad:", torch.autograd.grad(soft_loss, w)[0]) -Torch rank: tensor([1, 0, 2, 3]) -SoftTorch rank (hard mode): tensor([2., 1., 3., 4.]) -SoftTorch rank (soft mode): tensor([1.9950, 1.0548, 3.0239, 3.9228]) -``` - -```python -# Sort: sweep over methods -print("\nTorch sort:", torch.sort(x).values) -print("SoftTorch sort (softsort):", st.sort(x, method="softsort", softness=0.1).values) -print("SoftTorch sort (neuralsort):", st.sort(x, method="neuralsort", softness=0.1).values) -print("SoftTorch sort (fast_soft_sort):", st.sort(x, method="fast_soft_sort", softness=2.0).values) -print("SoftTorch sort (ot):", st.sort(x, method="ot", softness=0.1).values) -print("SoftTorch sort (sorting_network):", st.sort(x, method="sorting_network", softness=0.1).values) - -# Sort: sweep over modes -print("\nTorch sort:", torch.sort(x).values) -for mode in ["hard", "smooth", "c0", "c1", "c2"]: - print(f"SoftTorch sort ({mode}):", st.sort(x, softness=0.5, mode=mode).values) +w = torch.zeros(3) +for _ in range(50): + w.requires_grad_(True) + loss = median_regression_loss(w, X, y) + g = torch.autograd.grad(loss, w)[0] + w = (w - 0.1 * g).detach() +print("Learned w:", w, " (true:", w_true, ")") ``` ``` -Torch sort: tensor([-1.0000, -0.2000, 0.3000, 1.0000]) -SoftTorch sort (softsort): tensor([-0.8996, -0.1705, 0.2847, 0.8874]) -SoftTorch sort (neuralsort): tensor([-0.8792, -0.1641, 0.2767, 0.8738]) -SoftTorch sort (fast_soft_sort): tensor([-0.7462, -0.1971, 0.2938, 0.8569]) -SoftTorch sort (ot): tensor([-0.7324, -0.2396, 0.3286, 0.7434]) -SoftTorch sort (sorting_network): tensor([-0.7999, -0.2672, 0.3847, 0.7863]) - -Torch sort: tensor([-1.0000, -0.2000, 0.3000, 1.0000]) -SoftTorch sort (hard): tensor([-1.0000, -0.2000, 0.3000, 1.0000]) -SoftTorch sort (smooth): tensor([-0.6057, -0.1997, 0.2729, 0.6281]) -SoftTorch sort (c0): tensor([-1.0000, -0.6313, 0.6525, 0.9824]) -SoftTorch sort (c1): tensor([-0.9982, -0.5432, 0.5814, 0.9837]) -SoftTorch sort (c2): tensor([-0.9978, -0.4905, 0.5425, 0.9903]) +Hard grad: tensor([ 0.2103, 0.1772, -0.8305]) +Soft grad: tensor([ 0.0731, 0.7100, -0.2970]) +Learned w: tensor([ 1.0000, -2.0000, 0.5000]) (true: tensor([ 1.0000, -2.0000, 0.5000]) ) ``` +**Top-k feature selection:** +Discover which features of a trained model are important. ```python -# Operators returning indices -print("\nTorch argmax:", torch.argmax(x)) -print("SoftTorch argmax (hard mode):", st.argmax(x, mode="hard")) -print("SoftTorch argmax (soft mode):", st.argmax(x)) - -print("\nTorch argmin:", torch.argmin(x)) -print("SoftTorch argmin (hard mode):", st.argmin(x, mode="hard")) -print("SoftTorch argmin (soft mode):", st.argmin(x)) - -print("\nTorch argquantile:", "Not implemented in standard PyTorch") -print("SoftTorch argquantile (hard mode):", st.argquantile(x, q=0.2, mode="hard")) -print("SoftTorch argquantile (soft mode):", st.argquantile(x, q=0.2)) - -print("\nTorch argmedian:", torch.median(x, dim=0).indices) -print("SoftTorch argmedian (hard mode):", st.median(x, mode="hard", dim=0).indices) -print("SoftTorch argmedian (soft mode):", st.median(x, dim=0).indices) - -print("\nTorch argsort:", torch.argsort(x)) -print("SoftTorch argsort (hard mode):", st.argsort(x, mode="hard")) -print("SoftTorch argsort (soft mode):", st.argsort(x)) - -print("\nTorch argtopk:", torch.topk(x, k=3).indices) -print("SoftTorch argtopk (hard mode):", st.topk(x, k=3, mode="hard").indices) -print("SoftTorch argtopk (soft mode):", st.topk(x, k=3).indices) -``` -``` -Torch argmax: tensor(3) -SoftTorch argmax (hard mode): tensor([0., 0., 0., 1.]) -SoftTorch argmax (soft mode): tensor([0.0215, 0.0022, 0.1176, 0.8586]) - -Torch argmin: tensor(1) -SoftTorch argmin (hard mode): tensor([0., 1., 0., 0.]) -SoftTorch argmin (soft mode): tensor([0.0922, 0.8885, 0.0169, 0.0023]) - -Torch argquantile: Not implemented in standard PyTorch -SoftTorch argquantile (hard mode): tensor([0.6000, 0.4000, 0.0000, 0.0000]) -SoftTorch argquantile (soft mode): tensor([0.5403, 0.3693, 0.0902, 0.0001]) - -Torch argmedian: tensor(0) -SoftTorch argmedian (hard mode): tensor([1., 0., 0., 0.]) -SoftTorch argmedian (soft mode): tensor([0.8009, 0.0491, 0.1498, 0.0002]) - -Torch argsort: tensor([1, 0, 2, 3]) -SoftTorch argsort (hard mode): tensor([[0., 1., 0., 0.], - [1., 0., 0., 0.], - [0., 0., 1., 0.], - [0., 0., 0., 1.]]) -SoftTorch argsort (soft mode): tensor([[0.1494, 0.8496, 0.0009, 0.0000], - [0.8009, 0.0491, 0.1498, 0.0002], - [0.1418, 0.0001, 0.7899, 0.0681], - [0.0011, 0.0000, 0.1784, 0.8205]]) - -Torch argtopk: tensor([3, 2, 0]) -SoftTorch argtopk (hard mode): tensor([[0., 0., 0., 1.], - [0., 0., 1., 0.], - [1., 0., 0., 0.]]) -SoftTorch argtopk (soft mode): tensor([[0.0011, 0.0000, 0.1784, 0.8205], - [0.1418, 0.0001, 0.7899, 0.0681], - [0.8009, 0.0491, 0.1498, 0.0002]]) -``` - +n_features, k = 10, 3 +torch.manual_seed(42) +X = torch.randn(100, n_features) +w_model = torch.tensor([0, 2.0, 0, -1.5, 0, 0, 0, 5.0, 0, 0]) +y = X @ w_model + 0.1 * torch.randn(100) + +def feature_selection_loss(g, X, y, w_model, mode="smooth"): + _, soft_idx = st.topk(g, k=k, mode=mode, gated_grad=False) + mask = soft_idx.sum(dim=0) + y_pred = (X * mask) @ w_model + return torch.mean(st.abs(y_pred - y)) + +g = torch.zeros(n_features, requires_grad=True) +hard_loss = feature_selection_loss(g, X, y, w_model, mode="hard") +print("Hard grad:", torch.autograd.grad(hard_loss, g)[0] if hard_loss.requires_grad else torch.zeros_like(g)) +soft_loss = feature_selection_loss(g, X, y, w_model, mode="smooth") +print("Soft grad:", torch.autograd.grad(soft_loss, g)[0]) + +g = torch.zeros(n_features) +for _ in range(5): + g.requires_grad_(True) + loss = feature_selection_loss(g, X, y, w_model) + g_grad = torch.autograd.grad(loss, g)[0] + g = (g - 0.001 * g_grad).detach() +print("Selected features:", torch.topk(g, k=k).indices) +``` +``` +Hard grad: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]) +Soft grad: tensor([ 2359.3386, 62.9980, 2359.3386, -890.2852, 2359.3386, + 2359.3386, 2359.3386, -15688.0829, 2359.3386, 2359.3386]) +Selected features: tensor([7, 3, 1]) +``` + +**Differentiable threshold filtering:** +Learn a threshold that gates inputs. ```python -y = torch.tensor([0.2, -0.5, 0.5, -1.0]) - -# Comparison operators -print("\nTorch greater:", torch.greater(x, y)) -print("SoftTorch greater (hard mode):", st.greater(x, y, mode="hard")) -print("SoftTorch greater (soft mode):", st.greater(x, y)) - -print("\nTorch greater equal:", torch.greater_equal(x, y)) -print("SoftTorch greater equal (hard mode):", st.greater_equal(x, y, mode="hard")) -print("SoftTorch greater equal (soft mode):", st.greater_equal(x, y)) - -print("\nTorch less:", torch.less(x, y)) -print("SoftTorch less (hard mode):", st.less(x, y, mode="hard")) -print("SoftTorch less (soft mode):", st.less(x, y)) +x = torch.tensor([0.2, 0.8, 0.5, 1.2, 0.1]) +target_sum = 2.0 # sum of values above threshold = 2.0 (i.e. 0.8 + 1.2) -print("\nTorch less equal:", torch.less_equal(x, y)) -print("SoftTorch less equal (hard mode):", st.less_equal(x, y, mode="hard")) -print("SoftTorch less equal (soft mode):", st.less_equal(x, y)) +def filter_loss(t, x, target, mode="smooth"): + mask = st.greater(x, t, mode=mode) + return (torch.sum(mask * x) - target) ** 2 -print("\nTorch eq:", torch.eq(x, y)) -print("SoftTorch eq (hard mode):", st.eq(x, y, mode="hard")) -print("SoftTorch eq (soft mode):", st.eq(x, y)) +t = torch.tensor(0.0, requires_grad=True) +hard_loss = filter_loss(t, x, target_sum, mode="hard") +print("Hard grad:", torch.autograd.grad(hard_loss, t)[0] if hard_loss.requires_grad else torch.zeros_like(t)) +soft_loss = filter_loss(t, x, target_sum, mode="smooth") +print("Soft grad:", torch.autograd.grad(soft_loss, t)[0]) -print("\nTorch not equal:", torch.not_equal(x, y)) -print("SoftTorch not equal (hard mode):", st.not_equal(x, y, mode="hard")) -print("SoftTorch not equal (soft mode):", st.not_equal(x, y)) - -print("\nTorch isclose:", torch.isclose(x, y)) -print("SoftTorch isclose (hard mode):", st.isclose(x, y, mode="hard")) -print("SoftTorch isclose (soft mode):", st.isclose(x, y)) +t = torch.tensor(0.0) +for _ in range(20): + t.requires_grad_(True) + loss = filter_loss(t, x, target_sum) + t_grad = torch.autograd.grad(loss, t)[0] + t = (t - 0.1 * t_grad).detach() +print("Learned threshold:", t) ``` ``` -Torch greater: tensor([False, False, False, True]) -SoftTorch greater (hard mode): tensor([0., 0., 0., 1.]) -SoftTorch greater (soft mode): tensor([0.0180, 0.0067, 0.1192, 1.0000]) - -Torch greater equal: tensor([False, False, False, True]) -SoftTorch greater equal (hard mode): tensor([0., 0., 0., 1.]) -SoftTorch greater equal (soft mode): tensor([0.0180, 0.0067, 0.1192, 1.0000]) - -Torch less: tensor([ True, True, True, False]) -SoftTorch less (hard mode): tensor([1., 1., 1., 0.]) -SoftTorch less (soft mode): tensor([0.9820, 0.9933, 0.8808, 0.0000]) - -Torch less equal: tensor([ True, True, True, False]) -SoftTorch less equal (hard mode): tensor([1., 1., 1., 0.]) -SoftTorch less equal (soft mode): tensor([0.9820, 0.9933, 0.8808, 0.0000]) - -Torch eq: tensor([False, False, False, False]) -SoftTorch eq (hard mode): tensor([0., 0., 0., 0.]) -SoftTorch eq (soft mode): tensor([0.0414, 0.0143, 0.3580, 0.0000]) - -Torch not equal: tensor([True, True, True, True]) -SoftTorch not equal (hard mode): tensor([1., 1., 1., 1.]) -SoftTorch not equal (soft mode): tensor([0.9586, 0.9857, 0.6420, 1.0000]) - -Torch isclose: tensor([False, False, False, False]) -SoftTorch isclose (hard mode): tensor([0., 0., 0., 0.]) -SoftTorch isclose (soft mode): tensor([0.0414, 0.0143, 0.3580, 0.0000]) +Hard grad: tensor(0.) +Soft grad: tensor(-0.6600) +Learned threshold: tensor(0.6211) ``` +**Rule-based classifier:** +Learn decision boundaries `[lo, hi]` for a rule using soft logic and straight-through estimation. The rule is true if any element of a feature is inside `[lo, hi]`. ```python -# Logical operators -fuzzy_a = torch.tensor([0.1, 0.2, 0.8, 1.0]) -fuzzy_b = torch.tensor([0.7, 0.3, 0.1, 0.9]) -bool_a = fuzzy_a >= 0.5 -bool_b = fuzzy_b >= 0.5 - -print("\nTorch AND:", torch.logical_and(bool_a, bool_b)) -print("SoftTorch AND:", st.logical_and(fuzzy_a, fuzzy_b)) - -print("\nTorch OR:", torch.logical_or(bool_a, bool_b)) -print("SoftTorch OR:", st.logical_or(fuzzy_a, fuzzy_b)) +x = torch.tensor([[0.2, 0.8], [0.5, 0.3], [0.9, 0.1], [0.4, 0.7], [0.1, 0.4], [0.2, 0.7], [0.4, 0.1], [0.4, 0.7], + [0.7, 0.29], [0.3, 0.3], [0.61, 0.25], [0.4, 0.6], [0.0, 0.1], [0.5, 0.3], [0.4, 0.9], [0.1, 0.57]]) +labels = torch.tensor([0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, + 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0]) -print("\nTorch NOT:", torch.logical_not(bool_a)) -print("SoftTorch NOT:", st.logical_not(fuzzy_a)) +@st.st +def rule_loss(params, x, labels, mode="smooth"): + lo, hi = params[0], params[1] + above = st.greater(x, lo, mode=mode) + below = st.less(x, hi, mode=mode) + in_range = st.logical_and(above, below) + preds = st.any(in_range, dim=-1) + return ((preds - labels) ** 2).sum() -print("\nTorch XOR:", torch.logical_xor(bool_a, bool_b)) -print("SoftTorch XOR:", st.logical_xor(fuzzy_a, fuzzy_b)) +params = torch.tensor([0.0, 1.0], requires_grad=True) +hard_loss = rule_loss(params, x, labels, mode="hard") +print("Hard grad:", torch.autograd.grad(hard_loss, params)[0] if hard_loss.requires_grad else torch.zeros_like(params)) +soft_loss = rule_loss(params, x, labels, mode="smooth") +print("Soft grad:", torch.autograd.grad(soft_loss, params)[0]) -print("\nTorch ALL:", torch.all(bool_a)) -print("SoftTorch ALL:", st.all(fuzzy_a)) - -print("\nTorch ANY:", torch.any(bool_a)) -print("SoftTorch ANY:", st.any(fuzzy_a)) - -# Selection operators -print("\nTorch Where:", torch.where(bool_a, x, y)) -print("SoftTorch Where:", st.where(fuzzy_a, x, y)) +params = torch.tensor([0.0, 1.0]) +for _ in range(20): + params.requires_grad_(True) + loss = rule_loss(params, x, labels) + p_grad = torch.autograd.grad(loss, params)[0] + params = (params - 0.01 * p_grad).detach() +print("Learned [lo, hi]:", params) ``` ``` -Torch AND: tensor([False, False, False, True]) -SoftTorch AND: tensor([0.0700, 0.0600, 0.0800, 0.9000]) - -Torch OR: tensor([ True, False, True, True]) -SoftTorch OR: tensor([0.7300, 0.4400, 0.8200, 1.0000]) - -Torch NOT: tensor([ True, True, False, False]) -SoftTorch NOT: tensor([0.9000, 0.8000, 0.2000, 0.0000]) - -Torch XOR: tensor([ True, False, True, False]) -SoftTorch XOR: tensor([0.6411, 0.3464, 0.7256, 0.1000]) - -Torch ALL: tensor(False) -SoftTorch ALL: tensor(0.0160) - -Torch ANY: tensor(True) -SoftTorch ANY: tensor(1.) - -Torch Where: tensor([ 0.2000, -0.5000, 0.3000, 1.0000]) -SoftTorch Where: tensor([ 0.1600, -0.6000, 0.3400, 1.0000]) +Hard grad: tensor([0., 0.]) +Soft grad: tensor([-4.2777, 1.4152]) +Learned [lo, hi]: tensor([0.2925, 0.5999]) ``` -```python -# Straight-through operators: Use hard function on forward and soft on backward -print("Straight-through ReLU:", st.relu_st(x)) -print("Straight-through sort:", st.sort_st(x).values) -print("Straight-through argtopk:", st.topk_st(x, k=3).indices) -print("Straight-through greater:", st.greater_st(x, y)) -# And many more... -``` -``` -Straight-through ReLU: tensor([0.0000, 0.0000, 0.3000, 1.0000]) -Straight-through sort: tensor([-1.0000, -0.2000, 0.3000, 1.0000]) -Straight-through argtopk: tensor([[0., 0., 0., 1.], - [0., 0., 1., 0.], - [1., 0., 0., 0.]]) -Straight-through greater: tensor([0., 0., 0., 1.]) -``` +
-# SoftTorch
+# Soft differentiable programming in PyTorch
-## In a nutshell
+Looking for JAX? See [SoftJAX](https://github.com/a-paulus/softjax).
+
+
+## What is SoftTorch?
SoftTorch provides soft differentiable drop-in replacements for traditionally non-differentiable functions in [PyTorch](https://pytorch.org), including
@@ -19,8 +22,7 @@ All operators offer multiple modes (controlling smoothness or boundedness of the
All operators also support straight-through estimation, using the non-differentiable function in the forward pass and the soft relaxation in the backward pass.
-SoftTorch functions are drop-in replacements for their non-differentiable PyTorch counterparts.
-Special care is needed for functions operating on indices, as we relax discrete indices into distributions over indices, which modifies the shape of returned/accepted values.
+*Note, while SoftTorch is designed to provide direct drop-in replacements for PyTorch's operators, soft axis-wise operators return a probability distribution over indices (instead of an index), effectively changing the shape of the function's output.*
## Installation
@@ -30,353 +32,163 @@ pip install softtorch
```
-## Quick example
-```python
-import torch
-import softtorch as st
-
-x = torch.tensor([-0.2, -1.0, 0.3, 1.0])
-
-# Elementwise functions
-print("\nTorch absolute:", torch.abs(x))
-print("SoftTorch absolute (hard mode):", st.abs(x, mode="hard"))
-print("SoftTorch absolute (soft mode):", st.abs(x))
-
-print("\nTorch clamp:", torch.clamp(x, -0.5, 0.5))
-print("SoftTorch clamp (hard mode):", st.clamp(x, -0.5, 0.5, mode="hard"))
-print("SoftTorch clamp (soft mode):", st.clamp(x, -0.5, 0.5))
-
-print("\nTorch heaviside:", torch.heaviside(x, torch.tensor(0.5)))
-print("SoftTorch heaviside (hard mode):", st.heaviside(x, mode="hard"))
-print("SoftTorch heaviside (soft mode):", st.heaviside(x))
-
-print("\nTorch ReLU:", torch.nn.functional.relu(x))
-print("SoftTorch ReLU (hard mode):", st.relu(x, mode="hard"))
-print("SoftTorch ReLU (soft mode):", st.relu(x))
-
-print("\nTorch round:", torch.round(x))
-print("SoftTorch round (hard mode):", st.round(x, mode="hard"))
-print("SoftTorch round (soft mode):", st.round(x))
-
-print("\nTorch sign:", torch.sign(x))
-print("SoftTorch sign (hard mode):", st.sign(x, mode="hard"))
-print("SoftTorch sign (soft mode):", st.sign(x))
-```
-```
-Torch absolute: tensor([0.2000, 1.0000, 0.3000, 1.0000])
-SoftTorch absolute (hard mode): tensor([0.2000, 1.0000, 0.3000, 1.0000])
-SoftTorch absolute (soft mode): tensor([0.1523, 0.9999, 0.2715, 0.9999])
-
-Torch clamp: tensor([-0.2000, -0.5000, 0.3000, 0.5000])
-SoftTorch clamp (hard mode): tensor([-0.2000, -0.5000, 0.3000, 0.5000])
-SoftTorch clamp (soft mode): tensor([-0.1952, -0.4993, 0.2873, 0.4993])
-
-Torch heaviside: tensor([0., 0., 1., 1.])
-SoftTorch heaviside (hard mode): tensor([0., 0., 1., 1.])
-SoftTorch heaviside (soft mode): tensor([0.1192, 0.0000, 0.9526, 1.0000])
-
-Torch ReLU: tensor([0.0000, 0.0000, 0.3000, 1.0000])
-SoftTorch ReLU (hard mode): tensor([0.0000, 0.0000, 0.3000, 1.0000])
-SoftTorch ReLU (soft mode): tensor([0.0127, 0.0000, 0.3049, 1.0000])
-
-Torch round: tensor([-0., -1., 0., 1.])
-SoftTorch round (hard mode): tensor([-0., -1., 0., 1.])
-SoftTorch round (soft mode): tensor([-0.0465, -1.0000, 0.1189, 1.0000])
-
-Torch sign: tensor([-1., -1., 1., 1.])
-SoftTorch sign (hard mode): tensor([-1., -1., 1., 1.])
-SoftTorch sign (soft mode): tensor([-0.7616, -0.9999, 0.9051, 0.9999])
-```
+## Quick examples
+**Robust median regression:**
+Minimize the median absolute residual to be robust to outliers.
```python
-# Tensor-valued operators
-print("\nTorch max:", torch.max(x))
-print("SoftTorch max (hard mode):", st.max(x, mode="hard"))
-print("SoftTorch max (soft mode):", st.max(x))
-
-print("\nTorch min:", torch.min(x))
-print("SoftTorch min (hard mode):", st.min(x, mode="hard"))
-print("SoftTorch min (soft mode):", st.min(x))
-
-print("\nTorch sort:", torch.sort(x).values)
-print("SoftTorch sort (hard mode):", st.sort(x, mode="hard").values)
-print("SoftTorch sort (soft mode):", st.sort(x).values)
-
-print("\nTorch quantile:", torch.quantile(x, q=0.2))
-print("SoftTorch quantile (hard mode):", st.quantile(x, q=0.2, mode="hard"))
-print("SoftTorch quantile (soft mode):", st.quantile(x, q=0.2))
-
-print("\nTorch median:", torch.median(x))
-print("SoftTorch median (hard mode):", st.median(x, mode="hard"))
-print("SoftTorch median (soft mode):", st.median(x))
-
-print("\nTorch topk:", torch.topk(x, k=3).values)
-print("SoftTorch topk (hard mode):", st.topk(x, k=3, mode="hard").values)
-print("SoftTorch topk (soft mode):", st.topk(x, k=3).values)
-
-print("\nTorch rank:", torch.argsort(torch.argsort(x)))
-print("SoftTorch rank (hard mode):", st.rank(x, mode="hard", descending=False))
-print("SoftTorch rank (soft mode):", st.rank(x, descending=False))
-```
-```
-Torch max: tensor(1.)
-SoftTorch max (hard mode): tensor(1.)
-SoftTorch max (soft mode): tensor(0.8874)
+import torch, softtorch as st
-Torch min: tensor(-1.)
-SoftTorch min (hard mode): tensor(-1.)
-SoftTorch min (soft mode): tensor(-0.8996)
+torch.manual_seed(0)
+X = torch.randn(20, 3)
+w_true = torch.tensor([1.0, -2.0, 0.5])
+y = X @ w_true
+y[0] = 1e6 # inject outlier
-Torch sort: tensor([-1.0000, -0.2000, 0.3000, 1.0000])
-SoftTorch sort (hard mode): tensor([-1.0000, -0.2000, 0.3000, 1.0000])
-SoftTorch sort (soft mode): tensor([-0.8792, -0.1641, 0.2767, 0.8738])
+def median_regression_loss(w, X, y, mode="smooth"):
+ residuals = y - X @ w
+ return st.median(st.abs(residuals, mode=mode), mode=mode)
-Torch quantile: tensor(-0.5200)
-SoftTorch quantile (hard mode): tensor(-0.5200)
-SoftTorch quantile (soft mode): tensor(-0.4501)
+w = torch.zeros(3, requires_grad=True)
+hard_loss = median_regression_loss(w, X, y, mode="hard")
+print("Hard grad:", torch.autograd.grad(hard_loss, w)[0])
+soft_loss = median_regression_loss(w, X, y, mode="smooth")
+print("Soft grad:", torch.autograd.grad(soft_loss, w)[0])
-Torch median: tensor(-0.2000)
-SoftTorch median (hard mode): tensor(-0.2000)
-SoftTorch median (soft mode): tensor(-0.1641)
-
-Torch topk: tensor([ 1.0000, 0.3000, -0.2000])
-SoftTorch topk (hard mode): tensor([ 1.0000, 0.3000, -0.2000])
-SoftTorch topk (soft mode): tensor([ 0.8738, 0.2767, -0.1641])
-
-Torch rank: tensor([1, 0, 2, 3])
-SoftTorch rank (hard mode): tensor([2., 1., 3., 4.])
-SoftTorch rank (soft mode): tensor([1.9950, 1.0548, 3.0239, 3.9228])
+w = torch.zeros(3)
+for _ in range(50):
+ w.requires_grad_(True)
+ loss = median_regression_loss(w, X, y)
+ g = torch.autograd.grad(loss, w)[0]
+ w = (w - 0.1 * g).detach()
+print("Learned w:", w, " (true:", w_true, ")")
```
-
-```python
-# Sort: sweep over methods
-print("\nTorch sort:", torch.sort(x).values)
-print("SoftTorch sort (softsort):", st.sort(x, method="softsort", softness=0.1).values)
-print("SoftTorch sort (neuralsort):", st.sort(x, method="neuralsort", softness=0.1).values)
-print("SoftTorch sort (fast_soft_sort):", st.sort(x, method="fast_soft_sort", softness=2.0).values)
-print("SoftTorch sort (ot):", st.sort(x, method="ot", softness=0.1).values)
-print("SoftTorch sort (sorting_network):", st.sort(x, method="sorting_network", softness=0.1).values)
-
-# Sort: sweep over modes
-print("\nTorch sort:", torch.sort(x).values)
-for mode in ["hard", "smooth", "c0", "c1", "c2"]:
- print(f"SoftTorch sort ({mode}):", st.sort(x, softness=0.5, mode=mode).values)
```
-```
-Torch sort: tensor([-1.0000, -0.2000, 0.3000, 1.0000])
-SoftTorch sort (softsort): tensor([-0.8996, -0.1705, 0.2847, 0.8874])
-SoftTorch sort (neuralsort): tensor([-0.8792, -0.1641, 0.2767, 0.8738])
-SoftTorch sort (fast_soft_sort): tensor([-0.7462, -0.1971, 0.2938, 0.8569])
-SoftTorch sort (ot): tensor([-0.7324, -0.2396, 0.3286, 0.7434])
-SoftTorch sort (sorting_network): tensor([-0.7999, -0.2672, 0.3847, 0.7863])
-
-Torch sort: tensor([-1.0000, -0.2000, 0.3000, 1.0000])
-SoftTorch sort (hard): tensor([-1.0000, -0.2000, 0.3000, 1.0000])
-SoftTorch sort (smooth): tensor([-0.6057, -0.1997, 0.2729, 0.6281])
-SoftTorch sort (c0): tensor([-1.0000, -0.6313, 0.6525, 0.9824])
-SoftTorch sort (c1): tensor([-0.9982, -0.5432, 0.5814, 0.9837])
-SoftTorch sort (c2): tensor([-0.9978, -0.4905, 0.5425, 0.9903])
+Hard grad: tensor([ 0.2103, 0.1772, -0.8305])
+Soft grad: tensor([ 0.0731, 0.7100, -0.2970])
+Learned w: tensor([ 1.0000, -2.0000, 0.5000]) (true: tensor([ 1.0000, -2.0000, 0.5000]) )
```
+**Top-k feature selection:**
+Discover which features of a trained model are important.
```python
-# Operators returning indices
-print("\nTorch argmax:", torch.argmax(x))
-print("SoftTorch argmax (hard mode):", st.argmax(x, mode="hard"))
-print("SoftTorch argmax (soft mode):", st.argmax(x))
-
-print("\nTorch argmin:", torch.argmin(x))
-print("SoftTorch argmin (hard mode):", st.argmin(x, mode="hard"))
-print("SoftTorch argmin (soft mode):", st.argmin(x))
-
-print("\nTorch argquantile:", "Not implemented in standard PyTorch")
-print("SoftTorch argquantile (hard mode):", st.argquantile(x, q=0.2, mode="hard"))
-print("SoftTorch argquantile (soft mode):", st.argquantile(x, q=0.2))
-
-print("\nTorch argmedian:", torch.median(x, dim=0).indices)
-print("SoftTorch argmedian (hard mode):", st.median(x, mode="hard", dim=0).indices)
-print("SoftTorch argmedian (soft mode):", st.median(x, dim=0).indices)
-
-print("\nTorch argsort:", torch.argsort(x))
-print("SoftTorch argsort (hard mode):", st.argsort(x, mode="hard"))
-print("SoftTorch argsort (soft mode):", st.argsort(x))
-
-print("\nTorch argtopk:", torch.topk(x, k=3).indices)
-print("SoftTorch argtopk (hard mode):", st.topk(x, k=3, mode="hard").indices)
-print("SoftTorch argtopk (soft mode):", st.topk(x, k=3).indices)
-```
-```
-Torch argmax: tensor(3)
-SoftTorch argmax (hard mode): tensor([0., 0., 0., 1.])
-SoftTorch argmax (soft mode): tensor([0.0215, 0.0022, 0.1176, 0.8586])
-
-Torch argmin: tensor(1)
-SoftTorch argmin (hard mode): tensor([0., 1., 0., 0.])
-SoftTorch argmin (soft mode): tensor([0.0922, 0.8885, 0.0169, 0.0023])
-
-Torch argquantile: Not implemented in standard PyTorch
-SoftTorch argquantile (hard mode): tensor([0.6000, 0.4000, 0.0000, 0.0000])
-SoftTorch argquantile (soft mode): tensor([0.5403, 0.3693, 0.0902, 0.0001])
-
-Torch argmedian: tensor(0)
-SoftTorch argmedian (hard mode): tensor([1., 0., 0., 0.])
-SoftTorch argmedian (soft mode): tensor([0.8009, 0.0491, 0.1498, 0.0002])
-
-Torch argsort: tensor([1, 0, 2, 3])
-SoftTorch argsort (hard mode): tensor([[0., 1., 0., 0.],
- [1., 0., 0., 0.],
- [0., 0., 1., 0.],
- [0., 0., 0., 1.]])
-SoftTorch argsort (soft mode): tensor([[0.1494, 0.8496, 0.0009, 0.0000],
- [0.8009, 0.0491, 0.1498, 0.0002],
- [0.1418, 0.0001, 0.7899, 0.0681],
- [0.0011, 0.0000, 0.1784, 0.8205]])
-
-Torch argtopk: tensor([3, 2, 0])
-SoftTorch argtopk (hard mode): tensor([[0., 0., 0., 1.],
- [0., 0., 1., 0.],
- [1., 0., 0., 0.]])
-SoftTorch argtopk (soft mode): tensor([[0.0011, 0.0000, 0.1784, 0.8205],
- [0.1418, 0.0001, 0.7899, 0.0681],
- [0.8009, 0.0491, 0.1498, 0.0002]])
-```
-
+n_features, k = 10, 3
+torch.manual_seed(42)
+X = torch.randn(100, n_features)
+w_model = torch.tensor([0, 2.0, 0, -1.5, 0, 0, 0, 5.0, 0, 0])
+y = X @ w_model + 0.1 * torch.randn(100)
+
+def feature_selection_loss(g, X, y, w_model, mode="smooth"):
+ _, soft_idx = st.topk(g, k=k, mode=mode, gated_grad=False)
+ mask = soft_idx.sum(dim=0)
+ y_pred = (X * mask) @ w_model
+ return torch.mean(st.abs(y_pred - y))
+
+g = torch.zeros(n_features, requires_grad=True)
+hard_loss = feature_selection_loss(g, X, y, w_model, mode="hard")
+print("Hard grad:", torch.autograd.grad(hard_loss, g)[0] if hard_loss.requires_grad else torch.zeros_like(g))
+soft_loss = feature_selection_loss(g, X, y, w_model, mode="smooth")
+print("Soft grad:", torch.autograd.grad(soft_loss, g)[0])
+
+g = torch.zeros(n_features)
+for _ in range(5):
+ g.requires_grad_(True)
+ loss = feature_selection_loss(g, X, y, w_model)
+ g_grad = torch.autograd.grad(loss, g)[0]
+ g = (g - 0.001 * g_grad).detach()
+print("Selected features:", torch.topk(g, k=k).indices)
+```
+```
+Hard grad: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
+Soft grad: tensor([ 2359.3386, 62.9980, 2359.3386, -890.2852, 2359.3386,
+ 2359.3386, 2359.3386, -15688.0829, 2359.3386, 2359.3386])
+Selected features: tensor([7, 3, 1])
+```
+
+**Differentiable threshold filtering:**
+Learn a threshold that gates inputs.
```python
-y = torch.tensor([0.2, -0.5, 0.5, -1.0])
-
-# Comparison operators
-print("\nTorch greater:", torch.greater(x, y))
-print("SoftTorch greater (hard mode):", st.greater(x, y, mode="hard"))
-print("SoftTorch greater (soft mode):", st.greater(x, y))
-
-print("\nTorch greater equal:", torch.greater_equal(x, y))
-print("SoftTorch greater equal (hard mode):", st.greater_equal(x, y, mode="hard"))
-print("SoftTorch greater equal (soft mode):", st.greater_equal(x, y))
+x = torch.tensor([0.2, 0.8, 0.5, 1.2, 0.1])
+target_sum = 2.0 # sum of values above threshold = 2.0 (i.e. 0.8 + 1.2)
-print("\nTorch less:", torch.less(x, y))
-print("SoftTorch less (hard mode):", st.less(x, y, mode="hard"))
-print("SoftTorch less (soft mode):", st.less(x, y))
+def filter_loss(t, x, target, mode="smooth"):
+ mask = st.greater(x, t, mode=mode)
+ return (torch.sum(mask * x) - target) ** 2
-print("\nTorch less equal:", torch.less_equal(x, y))
-print("SoftTorch less equal (hard mode):", st.less_equal(x, y, mode="hard"))
-print("SoftTorch less equal (soft mode):", st.less_equal(x, y))
+t = torch.tensor(0.0, requires_grad=True)
+hard_loss = filter_loss(t, x, target_sum, mode="hard")
+print("Hard grad:", torch.autograd.grad(hard_loss, t)[0] if hard_loss.requires_grad else torch.zeros_like(t))
+soft_loss = filter_loss(t, x, target_sum, mode="smooth")
+print("Soft grad:", torch.autograd.grad(soft_loss, t)[0])
-print("\nTorch eq:", torch.eq(x, y))
-print("SoftTorch eq (hard mode):", st.eq(x, y, mode="hard"))
-print("SoftTorch eq (soft mode):", st.eq(x, y))
-
-print("\nTorch not equal:", torch.not_equal(x, y))
-print("SoftTorch not equal (hard mode):", st.not_equal(x, y, mode="hard"))
-print("SoftTorch not equal (soft mode):", st.not_equal(x, y))
-
-print("\nTorch isclose:", torch.isclose(x, y))
-print("SoftTorch isclose (hard mode):", st.isclose(x, y, mode="hard"))
-print("SoftTorch isclose (soft mode):", st.isclose(x, y))
+t = torch.tensor(0.0)
+for _ in range(20):
+ t.requires_grad_(True)
+ loss = filter_loss(t, x, target_sum)
+ t_grad = torch.autograd.grad(loss, t)[0]
+ t = (t - 0.1 * t_grad).detach()
+print("Learned threshold:", t)
```
```
-Torch greater: tensor([False, False, False, True])
-SoftTorch greater (hard mode): tensor([0., 0., 0., 1.])
-SoftTorch greater (soft mode): tensor([0.0180, 0.0067, 0.1192, 1.0000])
-
-Torch greater equal: tensor([False, False, False, True])
-SoftTorch greater equal (hard mode): tensor([0., 0., 0., 1.])
-SoftTorch greater equal (soft mode): tensor([0.0180, 0.0067, 0.1192, 1.0000])
-
-Torch less: tensor([ True, True, True, False])
-SoftTorch less (hard mode): tensor([1., 1., 1., 0.])
-SoftTorch less (soft mode): tensor([0.9820, 0.9933, 0.8808, 0.0000])
-
-Torch less equal: tensor([ True, True, True, False])
-SoftTorch less equal (hard mode): tensor([1., 1., 1., 0.])
-SoftTorch less equal (soft mode): tensor([0.9820, 0.9933, 0.8808, 0.0000])
-
-Torch eq: tensor([False, False, False, False])
-SoftTorch eq (hard mode): tensor([0., 0., 0., 0.])
-SoftTorch eq (soft mode): tensor([0.0414, 0.0143, 0.3580, 0.0000])
-
-Torch not equal: tensor([True, True, True, True])
-SoftTorch not equal (hard mode): tensor([1., 1., 1., 1.])
-SoftTorch not equal (soft mode): tensor([0.9586, 0.9857, 0.6420, 1.0000])
-
-Torch isclose: tensor([False, False, False, False])
-SoftTorch isclose (hard mode): tensor([0., 0., 0., 0.])
-SoftTorch isclose (soft mode): tensor([0.0414, 0.0143, 0.3580, 0.0000])
+Hard grad: tensor(0.)
+Soft grad: tensor(-0.6600)
+Learned threshold: tensor(0.6211)
```
+**Rule-based classifier:**
+Learn decision boundaries `[lo, hi]` for a rule using soft logic and straight-through estimation. The rule is true if any element of a feature is inside `[lo, hi]`.
```python
-# Logical operators
-fuzzy_a = torch.tensor([0.1, 0.2, 0.8, 1.0])
-fuzzy_b = torch.tensor([0.7, 0.3, 0.1, 0.9])
-bool_a = fuzzy_a >= 0.5
-bool_b = fuzzy_b >= 0.5
-
-print("\nTorch AND:", torch.logical_and(bool_a, bool_b))
-print("SoftTorch AND:", st.logical_and(fuzzy_a, fuzzy_b))
-
-print("\nTorch OR:", torch.logical_or(bool_a, bool_b))
-print("SoftTorch OR:", st.logical_or(fuzzy_a, fuzzy_b))
-
-print("\nTorch NOT:", torch.logical_not(bool_a))
-print("SoftTorch NOT:", st.logical_not(fuzzy_a))
-
-print("\nTorch XOR:", torch.logical_xor(bool_a, bool_b))
-print("SoftTorch XOR:", st.logical_xor(fuzzy_a, fuzzy_b))
-
-print("\nTorch ALL:", torch.all(bool_a))
-print("SoftTorch ALL:", st.all(fuzzy_a))
-
-print("\nTorch ANY:", torch.any(bool_a))
-print("SoftTorch ANY:", st.any(fuzzy_a))
-
-# Selection operators
-print("\nTorch Where:", torch.where(bool_a, x, y))
-print("SoftTorch Where:", st.where(fuzzy_a, x, y))
-```
-```
-Torch AND: tensor([False, False, False, True])
-SoftTorch AND: tensor([0.0700, 0.0600, 0.0800, 0.9000])
-
-Torch OR: tensor([ True, False, True, True])
-SoftTorch OR: tensor([0.7300, 0.4400, 0.8200, 1.0000])
-
-Torch NOT: tensor([ True, True, False, False])
-SoftTorch NOT: tensor([0.9000, 0.8000, 0.2000, 0.0000])
+x = torch.tensor([[0.2, 0.8], [0.5, 0.3], [0.9, 0.1], [0.4, 0.7], [0.1, 0.4], [0.2, 0.7], [0.4, 0.1], [0.4, 0.7],
+ [0.7, 0.29], [0.3, 0.3], [0.61, 0.25], [0.4, 0.6], [0.0, 0.1], [0.5, 0.3], [0.4, 0.9], [0.1, 0.57]])
+labels = torch.tensor([0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0,
+ 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0])
+
+@st.st
+def rule_loss(params, x, labels, mode="smooth"):
+ lo, hi = params[0], params[1]
+ above = st.greater(x, lo, mode=mode)
+ below = st.less(x, hi, mode=mode)
+ in_range = st.logical_and(above, below)
+ preds = st.any(in_range, dim=-1)
+ return ((preds - labels) ** 2).sum()
+
+params = torch.tensor([0.0, 1.0], requires_grad=True)
+hard_loss = rule_loss(params, x, labels, mode="hard")
+print("Hard grad:", torch.autograd.grad(hard_loss, params)[0] if hard_loss.requires_grad else torch.zeros_like(params))
+soft_loss = rule_loss(params, x, labels, mode="smooth")
+print("Soft grad:", torch.autograd.grad(soft_loss, params)[0])
+
+params = torch.tensor([0.0, 1.0])
+for _ in range(20):
+ params.requires_grad_(True)
+ loss = rule_loss(params, x, labels)
+ p_grad = torch.autograd.grad(loss, params)[0]
+ params = (params - 0.01 * p_grad).detach()
+print("Learned [lo, hi]:", params)
+```
+```
+Hard grad: tensor([0., 0.])
+Soft grad: tensor([-4.2777, 1.4152])
+Learned [lo, hi]: tensor([0.2925, 0.5999])
+```
+
+