Commit 9362e65

paper build
1 parent 8307a41 commit 9362e65

2 files changed

Lines changed: 112 additions & 79 deletions

paper/jats/paper.jats

Lines changed: 112 additions & 79 deletions
@@ -92,34 +92,31 @@ a Creative Commons Attribution 4.0 International License (CC BY
 <p>The <monospace>torchsparsegradutils</monospace> package provides
 differentiable sparse linear-algebra utilities for PyTorch
 (<xref alt="Paszke et al., 2019" rid="ref-pytorch" ref-type="bibr">Paszke
-et al., 2019</xref>) that preserve sparsity for returned gradients
-during backpropagation. While PyTorch directly supports sparse
-tensors, its default semantics treat sparse layouts as storage
-optimisations rather than a mathematical structure that results in
-optimising directly for that sparse subspace. Gradients resulting from
-PyTorch native functions are often dense and incompatible with
-end-to-end training of models that require fixed sparsity patterns
-(e.g., sparse covariance/precision structures).</p>
-<p>To address this limitation, we introduce
-<monospace>torchsparsegradutils</monospace>. Key features include: (1)
-memory-efficient sparse-dense matrix multiplication with sparse
-gradient preservation, (2) sparse triangular and generic linear system
-solvers, enabling sparse gradients during backpropagation, and
-multiple algorithmic backends (BICGSTAB, CG, LSMR, MINRES), (3)
-cross-platform sparse solver wrappers for CuPy
+et al., 2019</xref>) that preserve sparsity in returned gradients
+during backpropagation. While PyTorch supports sparse tensors, its
+default dense-equivalent backward semantics can densify gradients and
+make it difficult to optimise models with fixed sparsity patterns,
+such as sparse covariance or precision parameterisations.</p>
+<p>The package provides sparse-dense matrix multiplication with
+sparse-gradient preservation, sparse triangular and generic linear
+system solvers (including BICGSTAB, CG, LSMR, and MINRES backends),
+optional CuPy
 (<xref alt="Okuta et al., 2017" rid="ref-cupy" ref-type="bibr">Okuta
 et al., 2017</xref>) and JAX
 (<xref alt="Bradbury et al., 2018" rid="ref-jax" ref-type="bibr">Bradbury
-et al., 2018</xref>), (4) sparse multivariate normal distributions
-with <inline-formula><alternatives>
+et al., 2018</xref>) solver wrappers, sparse multivariate normal
+distributions with <inline-formula><alternatives>
 <tex-math><![CDATA[\boldsymbol{L}\boldsymbol{L}^T]]></tex-math>
 <mml:math display="inline" xmlns:mml="http://www.w3.org/1998/Math/MathML"><mml:mrow><mml:mi>𝐋</mml:mi><mml:msup><mml:mi>𝐋</mml:mi><mml:mi>T</mml:mi></mml:msup></mml:mrow></mml:math></alternatives></inline-formula>
 and <inline-formula><alternatives>
 <tex-math><![CDATA[\boldsymbol{L}\boldsymbol{D}\boldsymbol{L}^T]]></tex-math>
 <mml:math display="inline" xmlns:mml="http://www.w3.org/1998/Math/MathML"><mml:mrow><mml:mi>𝐋</mml:mi><mml:mi>𝐃</mml:mi><mml:msup><mml:mi>𝐋</mml:mi><mml:mi>T</mml:mi></mml:msup></mml:mrow></mml:math></alternatives></inline-formula>
-sparse covariance and precision matrix parameterisations with
-reparameterised sampling methods, and (5) specialised encoders for
-spatial neighbourhood relationships in N-dimensional data.</p>
+parameterisations, and specialised encoders for spatial neighbourhood
+relationships in N-dimensional data.</p>
+<p>The source code is available on GitHub at
+<ext-link ext-link-type="uri" xlink:href="https://github.com/cai4cai/torchsparsegradutils">https://github.com/cai4cai/torchsparsegradutils</ext-link>,
+with full documentation hosted at
+<ext-link ext-link-type="uri" xlink:href="https://torchsparsegradutils.readthedocs.io">https://torchsparsegradutils.readthedocs.io</ext-link>.</p>
 </sec>
 <sec id="statement-of-need">
 <title>Statement of need</title>
@@ -134,15 +131,14 @@ a Creative Commons Attribution 4.0 International License (CC BY
 requires backpropagation through sparse linear algebra (matrix
 products, triangular solves, and linear system solves). PyTorch’s
 default sparse semantics are not designed to preserve user-imposed
-sparsity structure during differentiation (PyTorch issue #87448),
-which can lead to memory blow-ups and prevent end-to-end optimisation
-of sparse probabilistic models.</p>
+sparsity structure during differentiation
+(<ext-link ext-link-type="uri" xlink:href="https://github.com/pytorch/pytorch/issues/87448">PyTorch
+issue #87448</ext-link>), which can lead to memory blow-ups and
+prevent end-to-end optimisation of sparse probabilistic models.</p>
 <p><monospace>torchsparsegradutils</monospace> addresses this gap by
 implementing custom autograd functions for key sparse operators that
 return gradients only for stored nonzeros, enabling practical
-optimisation of models that rely on fixed sparse structure, such as
-sparse multivariate normal distributions with sparse
-covariance/precision factors.</p>
+optimisation of models that rely on fixed sparse structure.</p>
 </sec>
 <sec id="state-of-the-field">
 <title>State of the field</title>
@@ -153,18 +149,20 @@ a Creative Commons Attribution 4.0 International License (CC BY
 PyTorch’s design goal is <italic>dense-equivalent semantics</italic>
 for sparse layouts: a guiding invariant is that applying an operation
 in sparse form should match applying it in dense form after
-conversion, including the backward function (PyTorch issue #87448).
-This makes it difficult to learn parameters that are intended to
-remain structurally sparse, because gradients may be produced for
-implicit zeros, or intermediate computations may densify.</p>
-<p>PyTorch also provides <monospace>MaskedTensor</monospace>,
-distringuishing specified and unspecified elements in tensors and is
-conceptually closer to the constrained-subspace interpretation of
-sparsity. However, <monospace>MaskedTensor</monospace> remains at
-prototype stage with incomplete operator coverage, and storing a full
-boolean mask incurs a significant memory overhead, partially negating
-the memory benefits of sparse index-based representations for
-large-scale problems.</p>
+conversion, including the backward function
+(<ext-link ext-link-type="uri" xlink:href="https://github.com/pytorch/pytorch/issues/87448">PyTorch
+issue #87448</ext-link>). This makes it difficult to learn parameters
+that are intended to remain structurally sparse, because gradients may
+be produced for implicit zeros, or intermediate computations may
+densify.</p>
+<p>PyTorch also provides <monospace>MaskedTensor</monospace>, which
+distinguishes specified and unspecified elements and is conceptually
+closer to the constrained-subspace interpretation of sparsity.
+However, <monospace>MaskedTensor</monospace> remains at prototype
+stage with incomplete operator coverage, and storing a full boolean
+mask incurs a significant memory overhead, partially negating the
+memory benefits of sparse index-based representations for large-scale
+problems.</p>
 <p>Other libraries provide efficient sparse kernels but do not
 directly solve “sparsity-preserving gradients in PyTorch”: SciPy
 (<xref alt="Virtanen et al., 2020" rid="ref-scipy" ref-type="bibr">Virtanen
@@ -191,53 +189,35 @@ a Creative Commons Attribution 4.0 International License (CC BY
 <p><monospace>torchsparsegradutils</monospace> is built around
 <monospace>torch.autograd.Function</monospace> operators that wrap
 PyTorch’s forward sparse kernels but override the backward pass to
-preserve sparsity for selected inputs. This design keeps the
-user-facing API close to standard PyTorch code while making sparsity
-preservation an explicit, opt-in choice.</p>
+preserve sparsity for selected inputs. This keeps the API close to
+standard PyTorch code while making sparsity preservation an explicit,
+opt-in choice.</p>
 <p>Two design trade-offs shaped the implementation. First, the package
 targets <italic>structure-preserving learning</italic> over maximal
-operator coverage, as only a focused set of operations (sparse matrix
-products, triangular solves, generic sparse solvers) are implemented,
-but these are sufficient to support sparse multivariate normal
-sampling and sparse solver-based models. Second, for broad
-device/backend compatibility, the package combines native PyTorch
-implementations (iterative Krylov solvers: CG, BiCGSTAB, LSMR, MINRES)
-with optional wrappers to external libraries (CuPy, JAX), allowing
-users to trade off portability versus performance.</p>
+operator coverage, focusing on sparse matrix products and sparse
+solves that support sparse multivariate normal sampling and related
+models. Second, it combines native PyTorch implementations (CG,
+BiCGSTAB, LSMR, MINRES) with optional CuPy and JAX wrappers so users
+can trade off portability and performance.</p>
 <p><bold>Build vs. contribute justification.</bold> PyTorch’s current
-semantics treat sparse layouts as performance optimisations and
-prioritise the dense-equivalence invariant (PyTorch issue #87448). In
-contrast, this package intentionally provides
-<italic>structure-preserving</italic> backward passes for specific
-operators to enable learning with fixed sparsity patterns (e.g.,
-sparse triangular factors for covariance/precision). This difference
-is semantic (not just implementation), so the functionality is better
-delivered as an opt-in external library rather than changing PyTorch’s
-default behaviour.</p>
+sparse semantics prioritise dense-equivalent behaviour
+(<ext-link ext-link-type="uri" xlink:href="https://github.com/pytorch/pytorch/issues/87448">PyTorch
+issue #87448</ext-link>). In contrast, this package intentionally
+provides structure-preserving backward passes for specific operators
+to enable learning with fixed sparsity patterns. Because that is a
+semantic choice rather than just an implementation detail, the
+functionality is better delivered as an opt-in external library than
+as a change to PyTorch defaults.</p>
 </sec>
 <sec id="research-impact-statement">
 <title>Research impact statement</title>
 <p>This software provides an opt-in path to sparsity-preserving
 gradients for sparse linear algebra in PyTorch, enabling research
 prototypes that would otherwise be limited by dense gradients or
-densification. The package is currently being used in active research
-projects for medical image segmentation, though publications resulting
-from this work are still in preparation.</p>
-<p>The codebase demonstrates community-readiness through comprehensive
-infrastructure: documentation with quickstart guides and API
-references, extensive test coverage across all modules, CI/CD
-pipelines for automated testing, and an open contribution process via
-GitHub issues and pull requests. The codebase has been developed
-openly over multiple years with public commit history, releases, and
-issue tracking. Benchmark suites comparing solver performance across
-problem sizes and sparsity patterns provide reproducible reference
-materials.</p>
-<p>Given the broad applicability of sparse structured
-Gaussians—spanning medical imaging, spatial statistics, geostatistics,
-and large-scale probabilistic modelling, we anticipate growing
-adoption as the research community increasingly requires
-memory-efficient optimisation of high-dimensional probabilistic
-models.</p>
+densification. The package is already being used in ongoing
+medical-image segmentation projects, and the public repository
+provides tests, documentation, benchmarks, and issue tracking to
+support reuse and extension.</p>
 </sec>
 <sec id="mathematics">
 <title>Mathematics</title>
@@ -374,6 +354,58 @@ a Creative Commons Attribution 4.0 International License (CC BY
 matrices by avoiding strict positive definiteness constraints.</p>
 </sec>
 </sec>
+<sec id="usage-examples">
+<title>Usage Examples</title>
+<p>Short examples are shown below; fuller worked examples are
+available in the ReadTheDocs quickstart.</p>
+<code language="python">import torch
+from torchsparsegradutils import sparse_mm, sparse_generic_solve
+from torchsparsegradutils.distributions import SparseMultivariateNormal
+from torchsparsegradutils.utils import (
+    linear_cg,
+    make_spd_sparse,
+    rand_sparse,
+    rand_sparse_tri,
+)
+
+n = 100
+A = rand_sparse((n, n), nnz=500).requires_grad_(True)
+sparse_mm(A, torch.randn(n, 8, requires_grad=True)).sum().backward()
+
+A_spd, _ = make_spd_sparse(n, torch.sparse_coo, torch.float32, torch.int64, &quot;cpu&quot;)
+sparse_generic_solve(
+    A_spd.requires_grad_(True),
+    torch.randn(n),
+    solve=linear_cg,
+).sum().backward()
+
+L = rand_sparse_tri(
+    (n, n), nnz=300, upper=False, strict=True
+).requires_grad_(True)
+SparseMultivariateNormal(
+    torch.zeros(n), diagonal=torch.rand(n), scale_tril=L
+).rsample((10,)).sum().backward()</code>
+</sec>
+<sec id="benchmarks">
+<title>Benchmarks</title>
+<p>On the SuiteSparse Rothberg/cfd2 matrix
+(<inline-formula><alternatives>
+<tex-math><![CDATA[123{,}440 \times 123{,}440]]></tex-math>
+<mml:math display="inline" xmlns:mml="http://www.w3.org/1998/Math/MathML"><mml:mrow><mml:mn>123</mml:mn><mml:mo>,</mml:mo><mml:mn>440</mml:mn><mml:mo>×</mml:mo><mml:mn>123</mml:mn><mml:mo>,</mml:mo><mml:mn>440</mml:mn></mml:mrow></mml:math></alternatives></inline-formula>,
+3.1M non-zeros), dense baselines and PyTorch’s native COO backward
+pass ran out of memory, whereas
+<monospace>torchsparsegradutils</monospace> completed sparse
+matrix-multiplication backward in about 75 ms using 5.1 GB on one
+tested RTX 4090 setup (results vary by hardware). On the same setup,
+native COO iterative solvers were up to about
+40<inline-formula><alternatives>
+<tex-math><![CDATA[\times]]></tex-math>
+<mml:math display="inline" xmlns:mml="http://www.w3.org/1998/Math/MathML"><mml:mi>×</mml:mi></mml:math></alternatives></inline-formula>
+faster than CuPy wrappers because they avoid sparse-format conversion
+overhead; full benchmark scripts and hardware-specific results are
+available in the repository and ReadTheDocs benchmark
+documentation.</p>
+</sec>
 <sec id="ai-usage-disclosure">
 <title>AI usage disclosure</title>
 <p>Generative AI tools were used during development of this software
@@ -391,7 +423,9 @@ a Creative Commons Attribution 4.0 International License (CC BY
 <p>We thank the PyTorch development team for foundational sparse
 tensor support. We also acknowledge upstream solver implementations
 and references used as starting points for iterative methods
-(pykrylov, cornellius-gp/linear_operator, pytorch-minimize)
+(<ext-link ext-link-type="uri" xlink:href="https://github.com/PythonOptimizers/pykrylov">pykrylov</ext-link>,
+<ext-link ext-link-type="uri" xlink:href="https://github.com/cornellius-gp/linear_operator">cornellius-gp/linear_operator</ext-link>,
+<ext-link ext-link-type="uri" xlink:href="https://github.com/rfeinman/pytorch-minimize">pytorch-minimize</ext-link>)
 (<xref alt="Saad, 2003" rid="ref-saad2003iterative" ref-type="bibr">Saad,
 2003</xref>). We thank Floris Laporte for his excellent tutorial on
 implementing sparse linear system solvers in PyTorch
@@ -531,7 +565,6 @@ a Creative Commons Attribution 4.0 International License (CC BY
 <year iso-8601-date="2018">2018</year>
 <volume>31</volume>
 <uri>https://arxiv.org/abs/1809.11165</uri>
-<pub-id pub-id-type="doi">10.5555/3327757.3327857</pub-id>
 </element-citation>
 </ref>
 <ref id="ref-flaport2020sparse">
@@ -540,8 +573,8 @@ a Creative Commons Attribution 4.0 International License (CC BY
 <name><surname>Laporte</surname><given-names>Floris</given-names></name>
 </person-group>
 <article-title>Solving sparse linear systems in PyTorch</article-title>
-<publisher-name>https://blog.flaport.net/solving-sparse-linear-systems-in-pytorch.html</publisher-name>
 <year iso-8601-date="2020">2020</year>
+<uri>https://blog.flaport.net/solving-sparse-linear-systems-in-pytorch.html</uri>
 </element-citation>
 </ref>
 </ref-list>

paper/paper.pdf

18.5 KB
Binary file not shown.
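
Note on the "Statement of need" text in the diff above: it says the package's custom autograd functions "return gradients only for stored nonzeros". The sketch below illustrates that idea for a sparse-dense product C = A B in plain Python. It is not the package's implementation; the helper names `sparse_mm_forward` and `sparse_mm_grad_A` and the COO triple layout `(rows, cols, vals)` are assumptions made for illustration only. The dense gradient with respect to A is G Bᵀ (with G the upstream gradient), and the sketch evaluates it only at the stored positions of A.

```python
def sparse_mm_forward(rows, cols, vals, B, n_rows):
    """Compute C = A @ B, with A given as COO triples (hypothetical helper)."""
    n_out = len(B[0])
    C = [[0.0] * n_out for _ in range(n_rows)]
    for i, j, a in zip(rows, cols, vals):
        for k in range(n_out):
            C[i][k] += a * B[j][k]
    return C


def sparse_mm_grad_A(rows, cols, G, B):
    """Gradient of a scalar loss w.r.t. the stored values of A only.

    The dense gradient would be G @ B^T. Here it is evaluated solely at the
    (row, col) positions in A's sparsity pattern, so the result holds one
    value per stored nonzero and never materialises a dense matrix.
    """
    return [sum(g * b for g, b in zip(G[i], B[j])) for i, j in zip(rows, cols)]


# A 3x3 matrix with three stored entries: A[0,0]=2, A[1,2]=-1, A[2,1]=4
rows, cols, vals = [0, 1, 2], [0, 2, 1], [2.0, -1.0, 4.0]
B = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

C = sparse_mm_forward(rows, cols, vals, B, n_rows=3)
G = [[1.0, 1.0] for _ in range(3)]  # upstream gradient for loss = sum of C
grad_vals = sparse_mm_grad_A(rows, cols, G, B)
```

The gradient has the same length as `vals`, so memory scales with the number of stored nonzeros rather than with the full matrix size, which is the property the paper text attributes to the package's backward passes.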
