Skip to content

Commit d5d0d25

Browse files
committed
1 parent b37fd12 commit d5d0d25

27 files changed

Lines changed: 45 additions & 45 deletions

File tree

latest/_sources/examples/rnn.rst.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,6 @@ descent can be leveraged to enhance optimization.
3434
.. note::
3535
At the time of writing, there seems to be an incompatibility between ``torch.vmap`` and
3636
``torch.nn.RNN`` when running on CUDA (see `this issue
37-
<https://github.com/TorchJD/torchjd/issues/220>`_ for more info), so we advise to set the
37+
<https://github.com/SimplexLab/TorchJD/issues/220>`_ for more info), so we advise to set the
3838
``parallel_chunk_size`` to ``1`` to avoid using ``torch.vmap``. To improve performance, you can
3939
check whether ``parallel_chunk_size=None`` (maximal parallelization) works on your side.

latest/_sources/index.rst.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ of the batch and per task, in the context of multi-task learning. We call this
5454
:doc:`Instance-Wise Risk Multi-Task Learning <examples/iwmtl>` (IWMTL).
5555

5656
TorchJD is open-source, under MIT License. The source code is available on
57-
`GitHub <https://github.com/TorchJD/torchjd>`_.
57+
`GitHub <https://github.com/SimplexLab/TorchJD>`_.
5858

5959
.. toctree::
6060
:caption: Getting Started

latest/docs/aggregation/aligned_mtl/index.html

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@
295295
<h1>Aligned-MTL<a class="headerlink" href="#aligned-mtl" title="Link to this heading"></a></h1>
296296
<dl class="py class">
297297
<dt class="sig sig-object py" id="torchjd.aggregation.AlignedMTL">
298-
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">AlignedMTL</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pref_vector</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">scale_mode</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">'min'</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/_aligned_mtl.py#L43-L76"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.AlignedMTL" title="Link to this definition"></a></dt>
298+
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">AlignedMTL</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pref_vector</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">scale_mode</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">'min'</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_aligned_mtl.py#L43-L76"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.AlignedMTL" title="Link to this definition"></a></dt>
299299
<dd><p><a class="reference internal" href="../#torchjd.aggregation.Aggregator" title="torchjd.aggregation._aggregator_bases.Aggregator"><code class="xref py py-class docutils literal notranslate"><span class="pre">Aggregator</span></code></a> as defined in Algorithm 1 of
300300
<a class="reference external" href="https://openaccess.thecvf.com/content/CVPR2023/papers/Senushkin_Independent_Component_Alignment_for_Multi-Task_Learning_CVPR_2023_paper.pdf">Independent Component Alignment for Multi-Task Learning</a>.</p>
301301
<dl class="field-list simple">
@@ -318,7 +318,7 @@ <h1>Aligned-MTL<a class="headerlink" href="#aligned-mtl" title="Link to this hea
318318

319319
<dl class="py class">
320320
<dt class="sig sig-object py" id="torchjd.aggregation.AlignedMTLWeighting">
321-
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">AlignedMTLWeighting</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pref_vector</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">scale_mode</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">'min'</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/_aligned_mtl.py#L79-L138"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.AlignedMTLWeighting" title="Link to this definition"></a></dt>
321+
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">AlignedMTLWeighting</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pref_vector</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">scale_mode</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">'min'</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_aligned_mtl.py#L79-L138"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.AlignedMTLWeighting" title="Link to this definition"></a></dt>
322322
<dd><p><a class="reference internal" href="../#torchjd.aggregation.Weighting" title="torchjd.aggregation._weighting_bases.Weighting"><code class="xref py py-class docutils literal notranslate"><span class="pre">Weighting</span></code></a> giving the weights of
323323
<a class="reference internal" href="#torchjd.aggregation.AlignedMTL" title="torchjd.aggregation.AlignedMTL"><code class="xref py py-class docutils literal notranslate"><span class="pre">AlignedMTL</span></code></a>.</p>
324324
<dl class="field-list simple">

latest/docs/aggregation/cagrad/index.html

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@
295295
<h1>CAGrad<a class="headerlink" href="#cagrad" title="Link to this heading"></a></h1>
296296
<dl class="py class">
297297
<dt class="sig sig-object py" id="torchjd.aggregation.CAGrad">
298-
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">CAGrad</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">c</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">norm_eps</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.0001</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/_cagrad.py#L21-L50"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.CAGrad" title="Link to this definition"></a></dt>
298+
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">CAGrad</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">c</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">norm_eps</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.0001</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_cagrad.py#L21-L50"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.CAGrad" title="Link to this definition"></a></dt>
299299
<dd><p><a class="reference internal" href="../#torchjd.aggregation.Aggregator" title="torchjd.aggregation._aggregator_bases.Aggregator"><code class="xref py py-class docutils literal notranslate"><span class="pre">Aggregator</span></code></a> as defined in Algorithm 1 of
300300
<a class="reference external" href="https://arxiv.org/pdf/2110.14048.pdf">Conflict-Averse Gradient Descent for Multi-task Learning</a>.</p>
301301
<dl class="field-list simple">
@@ -317,7 +317,7 @@ <h1>CAGrad<a class="headerlink" href="#cagrad" title="Link to this heading">¶</
317317

318318
<dl class="py class">
319319
<dt class="sig sig-object py" id="torchjd.aggregation.CAGradWeighting">
320-
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">CAGradWeighting</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">c</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">norm_eps</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.0001</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/_cagrad.py#L53-L106"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.CAGradWeighting" title="Link to this definition"></a></dt>
320+
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">CAGradWeighting</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">c</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">norm_eps</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.0001</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_cagrad.py#L53-L106"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.CAGradWeighting" title="Link to this definition"></a></dt>
321321
<dd><p><a class="reference internal" href="../#torchjd.aggregation.Weighting" title="torchjd.aggregation._weighting_bases.Weighting"><code class="xref py py-class docutils literal notranslate"><span class="pre">Weighting</span></code></a> giving the weights of
322322
<a class="reference internal" href="#torchjd.aggregation.CAGrad" title="torchjd.aggregation.CAGrad"><code class="xref py py-class docutils literal notranslate"><span class="pre">CAGrad</span></code></a>.</p>
323323
<dl class="field-list simple">

latest/docs/aggregation/config/index.html

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@
295295
<h1>ConFIG<a class="headerlink" href="#config" title="Link to this heading"></a></h1>
296296
<dl class="py class">
297297
<dt class="sig sig-object py" id="torchjd.aggregation.ConFIG">
298-
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">ConFIG</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pref_vector</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/_config.py#L39-L76"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.ConFIG" title="Link to this definition"></a></dt>
298+
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">ConFIG</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pref_vector</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_config.py#L39-L76"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.ConFIG" title="Link to this definition"></a></dt>
299299
<dd><p><a class="reference internal" href="../#torchjd.aggregation.Aggregator" title="torchjd.aggregation._aggregator_bases.Aggregator"><code class="xref py py-class docutils literal notranslate"><span class="pre">Aggregator</span></code></a> as defined in Equation 2 of <a class="reference external" href="https://arxiv.org/pdf/2408.11104">ConFIG:
300300
Towards Conflict-free Training of Physics Informed Neural Networks</a>.</p>
301301
<dl class="field-list simple">

latest/docs/aggregation/constant/index.html

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@
295295
<h1>Constant<a class="headerlink" href="#constant" title="Link to this heading"></a></h1>
296296
<dl class="py class">
297297
<dt class="sig sig-object py" id="torchjd.aggregation.Constant">
298-
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">Constant</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">weights</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/_constant.py#L10-L27"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.Constant" title="Link to this definition"></a></dt>
298+
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">Constant</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">weights</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_constant.py#L10-L27"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.Constant" title="Link to this definition"></a></dt>
299299
<dd><p><a class="reference internal" href="../#torchjd.aggregation.Aggregator" title="torchjd.aggregation._aggregator_bases.Aggregator"><code class="xref py py-class docutils literal notranslate"><span class="pre">Aggregator</span></code></a> that makes a linear combination of
300300
the rows of the provided matrix, with constant, pre-determined weights.</p>
301301
<dl class="field-list simple">
@@ -307,7 +307,7 @@ <h1>Constant<a class="headerlink" href="#constant" title="Link to this heading">
307307

308308
<dl class="py class">
309309
<dt class="sig sig-object py" id="torchjd.aggregation.ConstantWeighting">
310-
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">ConstantWeighting</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">weights</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/_constant.py#L30-L57"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.ConstantWeighting" title="Link to this definition"></a></dt>
310+
<span class="property"><span class="k"><span class="pre">class</span></span><span class="w"> </span></span><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.</span></span><span class="sig-name descname"><span class="pre">ConstantWeighting</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">weights</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/SimplexLab/TorchJD/blob/main/src/torchjd/aggregation/_constant.py#L30-L57"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.ConstantWeighting" title="Link to this definition"></a></dt>
311311
<dd><p><a class="reference internal" href="../#torchjd.aggregation.Weighting" title="torchjd.aggregation._weighting_bases.Weighting"><code class="xref py py-class docutils literal notranslate"><span class="pre">Weighting</span></code></a> that returns constant, pre-determined
312312
weights.</p>
313313
<dl class="field-list simple">

0 commit comments

Comments
 (0)