|
237 | 237 | </label> |
238 | 238 | </div> |
239 | 239 | <article role="main"> |
240 | | - <section id="module-torchjd.aggregation.dualproj"> |
241 | | -<span id="dualproj"></span><h1>DualProj<a class="headerlink" href="#module-torchjd.aggregation.dualproj" title="Link to this heading">¶</a></h1> |
242 | | -<dl class="py class"> |
243 | | -<dt class="sig sig-object py" id="torchjd.aggregation.dualproj.DualProj"> |
244 | | -<em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">torchjd.aggregation.dualproj.</span></span><span class="sig-name descname"><span class="pre">DualProj</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">pref_vector</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">norm_eps</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.0001</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">reg_eps</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.0001</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">solver</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">'quadprog'</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/TorchJD/torchjd/blob/main/src/torchjd/aggregation/dualproj.py#L12-L67"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#torchjd.aggregation.dualproj.DualProj" title="Link to this definition">¶</a></dt> |
245 | | -<dd><p><a class="reference internal" href="../bases/#torchjd.aggregation.bases.Aggregator" title="torchjd.aggregation.bases.Aggregator"><code class="xref py py-class docutils literal notranslate"><span class="pre">Aggregator</span></code></a> that averages the rows of the input matrix, and |
246 | | -projects the result onto the dual cone of the rows of the matrix. This corresponds to the |
247 | | -solution to Equation 11 of <a class="reference external" href="https://proceedings.neurips.cc/paper/2017/file/f87522788a2be2d171666752f97ddebb-Paper.pdf">Gradient Episodic Memory for Continual Learning</a>.</p> |
248 | | -<dl class="field-list simple"> |
249 | | -<dt class="field-odd">Parameters<span class="colon">:</span></dt> |
250 | | -<dd class="field-odd"><ul class="simple"> |
251 | | -<li><p><strong>pref_vector</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://pytorch.org/docs/stable/tensors.html#torch.Tensor" title="(in PyTorch v2.7)"><code class="xref py py-class docutils literal notranslate"><span class="pre">Tensor</span></code></a> | <a class="reference external" href="https://docs.python.org/3/library/constants.html#None" title="(in Python v3.13)"><code class="xref py py-obj docutils literal notranslate"><span class="pre">None</span></code></a></span>) – The preference vector used to combine the rows. If not provided, defaults to |
252 | | -the simple averaging.</p></li> |
253 | | -<li><p><strong>norm_eps</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.python.org/3/library/functions.html#float" title="(in Python v3.13)"><code class="xref py py-class docutils literal notranslate"><span class="pre">float</span></code></a></span>) – A small value to avoid division by zero when normalizing.</p></li> |
254 | | -<li><p><strong>reg_eps</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.python.org/3/library/functions.html#float" title="(in Python v3.13)"><code class="xref py py-class docutils literal notranslate"><span class="pre">float</span></code></a></span>) – A small value to add to the diagonal of the gramian of the matrix. Due to |
255 | | -numerical errors when computing the gramian, it might not exactly be positive definite. |
256 | | -This issue can make the optimization fail. Adding <code class="docutils literal notranslate"><span class="pre">reg_eps</span></code> to the diagonal of the gramian |
257 | | -ensures that it is positive definite.</p></li> |
258 | | -<li><p><strong>solver</strong> (<span class="sphinx_autodoc_typehints-type"><a class="reference external" href="https://docs.python.org/3/library/typing.html#typing.Literal" title="(in Python v3.13)"><code class="xref py py-data docutils literal notranslate"><span class="pre">Literal</span></code></a>[<code class="docutils literal notranslate"><span class="pre">'quadprog'</span></code>]</span>) – The solver used to optimize the underlying optimization problem.</p></li> |
259 | | -</ul> |
260 | | -</dd> |
261 | | -</dl> |
262 | | -<div class="admonition-example admonition"> |
263 | | -<p class="admonition-title">Example</p> |
264 | | -<p>Use DualProj to aggregate a matrix.</p> |
265 | | -<div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">from</span><span class="w"> </span><span class="nn">torch</span><span class="w"> </span><span class="kn">import</span> <span class="n">tensor</span> |
266 | | -<span class="gp">>>> </span><span class="kn">from</span><span class="w"> </span><span class="nn">torchjd.aggregation</span><span class="w"> </span><span class="kn">import</span> <span class="n">DualProj</span> |
267 | | -<span class="gp">>>></span> |
268 | | -<span class="gp">>>> </span><span class="n">A</span> <span class="o">=</span> <span class="n">DualProj</span><span class="p">()</span> |
269 | | -<span class="gp">>>> </span><span class="n">J</span> <span class="o">=</span> <span class="n">tensor</span><span class="p">([[</span><span class="o">-</span><span class="mf">4.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">],</span> <span class="p">[</span><span class="mf">6.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">]])</span> |
270 | | -<span class="gp">>>></span> |
271 | | -<span class="gp">>>> </span><span class="n">A</span><span class="p">(</span><span class="n">J</span><span class="p">)</span> |
272 | | -<span class="go">tensor([0.5563, 1.1109, 1.1109])</span> |
273 | | -</pre></div> |
274 | | -</div> |
275 | | -</div> |
276 | | -</dd></dl> |
277 | | - |
| 240 | + <section id="dualproj"> |
| 241 | +<h1>DualProj<a class="headerlink" href="#dualproj" title="Link to this heading">¶</a></h1> |
278 | 242 | </section> |
279 | 243 |
|
280 | 244 | </article> |
|
0 commit comments