Skip to content

Commit 37c50b9

Browse files
Discrete adjoint multi-zone for python wrapper (#2787)
* Change nested classes to friend classes * Include discrete adjoint multizone driver in python wrapper * Re-define AdjointProduct and Identity as template classes * address comments --------- Co-authored-by: Pedro Gomes <pcarruscag@gmail.com>
1 parent 54ccbeb commit 37c50b9

4 files changed

Lines changed: 46 additions & 35 deletions

File tree

SU2_CFD/include/drivers/CDiscAdjMultizoneDriver.hpp

Lines changed: 8 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -28,45 +28,21 @@
2828
#pragma once
2929
#include "CMultizoneDriver.hpp"
3030
#include "../../../Common/include/toolboxes/CQuasiNewtonInvLeastSquares.hpp"
31-
#include "../../../Common/include/linear_algebra/CPreconditioner.hpp"
32-
#include "../../../Common/include/linear_algebra/CMatrixVectorProduct.hpp"
3331
#include "../../../Common/include/linear_algebra/CSysSolve.hpp"
3432

3533
/*!
3634
* \brief Block Gauss-Seidel driver for multizone / multiphysics discrete adjoint problems.
3735
* \ingroup DiscAdj
3836
*/
37+
3938
class CDiscAdjMultizoneDriver : public CMultizoneDriver {
4039

4140
protected:
42-
#ifdef CODI_FORWARD_TYPE
43-
using Scalar = su2double;
44-
#else
45-
using Scalar = passivedouble;
46-
#endif
47-
48-
class AdjointProduct : public CMatrixVectorProduct<Scalar> {
49-
public:
50-
CDiscAdjMultizoneDriver* const driver;
51-
const unsigned short iZone = 0;
52-
mutable unsigned long iInnerIter = 0;
53-
54-
AdjointProduct(CDiscAdjMultizoneDriver* d, unsigned short i) : driver(d), iZone(i) {}
55-
56-
inline void operator()(const CSysVector<Scalar> & u, CSysVector<Scalar> & v) const override {
57-
driver->SetAllSolutions(iZone, true, u);
58-
driver->Iterate(iZone, iInnerIter, true);
59-
driver->GetAllSolutions(iZone, true, v);
60-
v -= u;
61-
++iInnerIter;
62-
}
63-
};
64-
65-
class Identity : public CPreconditioner<Scalar> {
66-
public:
67-
inline bool IsIdentity() const override { return true; }
68-
inline void operator()(const CSysVector<Scalar> & u, CSysVector<Scalar> & v) const override { v = u; }
69-
};
41+
#ifdef CODI_FORWARD_TYPE
42+
using Scalar = su2double;
43+
#else
44+
using Scalar = passivedouble;
45+
#endif
7046

7147
/*!
7248
* \brief Kinds of recordings.
@@ -161,14 +137,14 @@ class CDiscAdjMultizoneDriver : public CMultizoneDriver {
161137
*/
162138
void Run() override;
163139

164-
protected:
165-
166140
/*!
167141
* \brief Run one inner iteration for a given zone.
168142
* \return The result of "monitor".
169143
*/
170144
bool Iterate(unsigned short iZone, unsigned long iInnerIter, bool KrylovMode = false);
171145

146+
protected:
147+
172148
/*!
173149
* \brief Run inner iterations using a Krylov method (GMRES atm).
174150
*/

SU2_CFD/include/drivers/CDriver.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,7 @@ class CDriver : public CDriverBase {
352352
*/
353353
void PrintDirectResidual(RECORDING kind_recording);
354354

355+
public:
355356
/*!
356357
* \brief Set the solution of all solvers (adjoint or primal) in a zone.
357358
* \param[in] iZone - Index of the zone.
@@ -364,7 +365,7 @@ class CDriver : public CDriverBase {
364365
const auto nPoint = geometry_container[iZone][INST_0][MESH_0]->GetnPoint();
365366
for (auto iSol = 0u, offset = 0u; iSol < MAX_SOLS; ++iSol) {
366367
auto solver = solver_container[iZone][INST_0][MESH_0][iSol];
367-
if (!(solver && (solver->GetAdjoint() == adjoint))) continue;
368+
if (!solver || solver->GetAdjoint() != adjoint) continue;
368369
for (auto iPoint = 0ul; iPoint < nPoint; ++iPoint)
369370
for (auto iVar = 0ul; iVar < solver->GetnVar(); ++iVar)
370371
if (!Old) {
@@ -395,7 +396,7 @@ class CDriver : public CDriverBase {
395396
const auto nPoint = geometry_container[iZone][INST_0][MESH_0]->GetnPoint();
396397
for (auto iSol = 0u, offset = 0u; iSol < MAX_SOLS; ++iSol) {
397398
auto solver = solver_container[iZone][INST_0][MESH_0][iSol];
398-
if (!(solver && (solver->GetAdjoint() == adjoint))) continue;
399+
if (!solver || solver->GetAdjoint() != adjoint) continue;
399400
const auto& sol = solver->GetNodes()->GetSolution();
400401
for (auto iPoint = 0ul; iPoint < nPoint; ++iPoint)
401402
for (auto iVar = 0ul; iVar < solver->GetnVar(); ++iVar)
@@ -419,7 +420,6 @@ class CDriver : public CDriverBase {
419420
return nVar;
420421
}
421422

422-
public:
423423
/*!
424424
* \brief Launch the computation for all zones and all physics.
425425
*/

SU2_CFD/src/drivers/CDiscAdjMultizoneDriver.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,39 @@
3030
#include "../../include/output/COutputFactory.hpp"
3131
#include "../../include/output/COutput.hpp"
3232
#include "../../include/iteration/CIterationFactory.hpp"
33+
#include "../../../Common/include/linear_algebra/CPreconditioner.hpp"
34+
#include "../../../Common/include/linear_algebra/CMatrixVectorProduct.hpp"
35+
36+
namespace {
37+
#ifdef CODI_FORWARD_TYPE
38+
using Scalar = su2double;
39+
#else
40+
using Scalar = passivedouble;
41+
#endif
42+
43+
class AdjointProduct : public CMatrixVectorProduct<Scalar> {
44+
public:
45+
CDiscAdjMultizoneDriver* const driver;
46+
const unsigned short iZone = 0;
47+
mutable unsigned long iInnerIter = 0;
48+
49+
AdjointProduct(CDiscAdjMultizoneDriver* d, unsigned short i) : driver(d), iZone(i) {}
50+
51+
inline void operator()(const CSysVector<Scalar>& u, CSysVector<Scalar>& v) const override {
52+
driver->SetAllSolutions(iZone, true, u);
53+
driver->Iterate(iZone, iInnerIter, true);
54+
driver->GetAllSolutions(iZone, true, v);
55+
v -= u;
56+
++iInnerIter;
57+
}
58+
};
59+
60+
class Identity : public CPreconditioner<Scalar> {
61+
public:
62+
inline bool IsIdentity() const override { return true; }
63+
inline void operator()(const CSysVector<Scalar>& u, CSysVector<Scalar>& v) const override { v = u; }
64+
};
65+
} // namespace
3366

3467
CDiscAdjMultizoneDriver::CDiscAdjMultizoneDriver(char* confFile,
3568
unsigned short val_nZone,

SU2_PY/pySU2/pySU2ad.i

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ threads="1"
3939
%{
4040
#include "../../Common/include/containers/CPyWrapperMatrixView.hpp"
4141
#include "../../SU2_CFD/include/drivers/CDiscAdjSinglezoneDriver.hpp"
42+
#include "../../SU2_CFD/include/drivers/CDiscAdjMultizoneDriver.hpp"
4243
#include "../../SU2_CFD/include/drivers/CDriver.hpp"
4344
#include "../../SU2_CFD/include/drivers/CDriverBase.hpp"
4445
#include "../../SU2_CFD/include/drivers/CMultizoneDriver.hpp"
@@ -98,4 +99,5 @@ const unsigned int ZONE_1 = 1; /*!< \brief Definition of the first grid domain.
9899
%include "../../SU2_CFD/include/drivers/CSinglezoneDriver.hpp"
99100
%include "../../SU2_CFD/include/drivers/CMultizoneDriver.hpp"
100101
%include "../../SU2_CFD/include/drivers/CDiscAdjSinglezoneDriver.hpp"
102+
%include "../../SU2_CFD/include/drivers/CDiscAdjMultizoneDriver.hpp"
101103
%include "../../SU2_DEF/include/drivers/CDiscAdjDeformationDriver.hpp"

0 commit comments

Comments
 (0)