Skip to content

Commit c687fe4

Browse files
committed
refactor: MAke TP hierarchical cmaps work for non-uniform refinement
1 parent 3b4aadd commit c687fe4

5 files changed

Lines changed: 87 additions & 48 deletions

src/param/HierarchicalTPParametricAtlas.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,11 @@ ParentPoint HierarchicalTPParametricAtlas::parentPoint( const topology::Vertex&
4848
const std::array<topology::CellOrEndVertex, 3> fully_unflat_verts = fullyUnflattenVertex( dart_level_cmap, local_d );
4949
const topology::FullyUnflattenedDart fully_unflat_cell = unflattenFull( *mMap->refinementLevels().at( elem_level ), unrefined_d );
5050

51+
const size_t dart_component = dartComponentDirection( unflattenFull( dart_level_cmap, local_d ) );
5152
const size_t ratio = [&](){
5253
size_t out = 1;
5354
for( size_t level_ii = elem_level; level_ii < dart_level; level_ii++ )
54-
out *= mMap->refinementRatios().at( level_ii );
55+
out *= mMap->refinementRatios().at( level_ii ).at( dart_component );
5556
return out;
5657
}();
5758
const Vector6dMax unrefined_lengths = mRefinementLevels.at( elem_level )->parametricLengths( topology::Cell( unrefined_d, mMap->dim() ) );

src/topology/HierarchicalTPCombinatorialMap.cpp

Lines changed: 38 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,16 @@
33
#include <IndexOperations.hpp>
44
#include <ranges>
55
#include <numeric>
6+
#include <Logging.hpp>
67

78
using namespace topology;
89

9-
namespace topology
10-
{
11-
bool checkForNoAncestor( const FullyUnflattenedDart& unflat, const size_t n_darts_per_ancestor )
10+
bool checkForNoAncestor( const FullyUnflattenedDart& unflat, const SmallVector<size_t, 3>& n_darts_per_ancestor )
1211
{
1312
using TPDartPos = TPCombinatorialMap::TPDartPos;
1413

1514
const auto check_dart = [&]( const size_t idx, const bool add_one = false ) {
16-
return ( unflat.unflat_darts.at( idx ).id() + ( add_one ? 1 : 0 ) ) % n_darts_per_ancestor != 0;
15+
return ( unflat.unflat_darts.at( idx ).id() + ( add_one ? 1 : 0 ) ) % n_darts_per_ancestor.at( idx ) != 0;
1716
};
1817
constexpr bool add_one = true;
1918
if( unflat.unflat_darts.size() == 2 ) // 2D case
@@ -108,14 +107,6 @@ bool checkForNoAncestor( const FullyUnflattenedDart& unflat, const size_t n_dart
108107
throw std::invalid_argument( "Bad unflattened dart passed to checkForNoAncestor" );
109108
}
110109

111-
bool checkForNoAncestor( const TPCombinatorialMap& tp_map, const Dart& d, const size_t n_darts_per_ancestor )
112-
{
113-
if( n_darts_per_ancestor == 0 )
114-
return true;
115-
return checkForNoAncestor( unflattenFull( tp_map, d ), n_darts_per_ancestor );
116-
}
117-
}
118-
119110
HierarchicalTPCombinatorialMap::HierarchicalTPCombinatorialMap(
120111
const std::vector<std::shared_ptr<const TPCombinatorialMap>>& refinement_levels,
121112
const std::vector<std::vector<Cell>>& leaf_elements )
@@ -136,16 +127,14 @@ HierarchicalTPCombinatorialMap::HierarchicalTPCombinatorialMap(
136127
for( size_t level = 1; level < refinement_levels.size(); level++ )
137128
{
138129
const util::IndexVec higher_sizes = get_tp_lengths( *refinement_levels.at( level ) );
139-
mRefinementRatios.push_back( higher_sizes.front() / lower_sizes.front() );
140-
141-
if( higher_sizes.front() % lower_sizes.front() != 0 )
142-
throw std::invalid_argument( "Refinement levels must be uniform refinements of each other." );
143-
130+
SmallVector<size_t, 3>& ratios = mRefinementRatios.emplace_back();
144131
for( size_t i = 0; i < lower_sizes.size(); i++ )
145132
{
146-
if( higher_sizes.at( i ) / lower_sizes.at( i ) != mRefinementRatios.back() or higher_sizes.at( i ) % lower_sizes.at( i ) != 0 )
147-
throw std::invalid_argument( "Refinement levels must be uniform refinements of each other." );
133+
if( higher_sizes.at( i ) % lower_sizes.at( i ) != 0 )
134+
throw std::invalid_argument( "Refinement levels must be nested refinements of each other." );
135+
ratios.push_back( higher_sizes.at( i ) / lower_sizes.at( i ) );
148136
}
137+
149138
lower_sizes = higher_sizes;
150139
}
151140

@@ -446,56 +435,65 @@ bool HierarchicalTPCombinatorialMap::iterateDartLineage( const Dart& global_d,
446435

447436
if( dart_level == ancestor_or_descendant_level ) return callback( global_d );
448437

438+
const FullyUnflattenedDart unflat = unflattenFull( *mRefinementLevels.at( dart_level ), local_d );
449439
if( dart_level > ancestor_or_descendant_level )
450440
{
451441
// Call back on a single ancestor dart
452-
const size_t darts_per_ancestor_dart =
442+
const SmallVector<size_t, 3> darts_per_ancestor_dart =
453443
std::reduce( std::next( mRefinementRatios.begin(), ancestor_or_descendant_level ),
454444
std::next( mRefinementRatios.begin(), dart_level ),
455-
1, std::multiplies<>() );
456-
457-
const FullyUnflattenedDart unflat = unflattenFull( *mRefinementLevels.at( dart_level ), local_d );
445+
SmallVector<size_t, 3>( dim(), 1 ),
446+
[&]( const SmallVector<size_t, 3>& a, const SmallVector<size_t, 3>& b ) {
447+
SmallVector<size_t, 3> result( a.size() );
448+
std::transform( a.begin(), a.end(), b.begin(), result.begin(), [&]( const size_t x, const size_t y ) { return x * y; } );
449+
return result;
450+
} );
458451

459452
if( checkForNoAncestor( unflat, darts_per_ancestor_dart ) ) return true;
460453

461454
FullyUnflattenedDart ancestor_unflat( {}, unflat.dart_pos );
462-
std::transform(
463-
unflat.unflat_darts.begin(),
464-
unflat.unflat_darts.end(),
465-
std::back_inserter( ancestor_unflat.unflat_darts ),
466-
[&darts_per_ancestor_dart]( const Dart& d ) { return Dart( d.id() / darts_per_ancestor_dart ); } );
455+
std::transform( unflat.unflat_darts.begin(),
456+
unflat.unflat_darts.end(),
457+
darts_per_ancestor_dart.begin(),
458+
std::back_inserter( ancestor_unflat.unflat_darts ),
459+
[]( const Dart& d, const size_t ratio ) {
460+
return Dart( d.id() / ratio );
461+
} );
467462

468463
const Dart ancestor_dart = flattenFull( *mRefinementLevels.at( ancestor_or_descendant_level ), ancestor_unflat );
469464
return callback( mRanges.toGlobalDart( ancestor_or_descendant_level, ancestor_dart ) );
470465
}
471466
else
472467
{
473468
// Iterate several descendants.
474-
const size_t darts_per_ancestor_dart =
469+
const SmallVector<size_t, 3> darts_per_ancestor_dart =
475470
std::reduce( std::next( mRefinementRatios.begin(), dart_level ),
476471
std::next( mRefinementRatios.begin(), ancestor_or_descendant_level ),
477-
1, std::multiplies<>() );
478-
479-
const FullyUnflattenedDart unflat = unflattenFull( *mRefinementLevels.at( dart_level ), local_d );
472+
SmallVector<size_t, 3>( dim(), 1 ),
473+
[&]( const SmallVector<size_t, 3>& a, const SmallVector<size_t, 3>& b ) {
474+
SmallVector<size_t, 3> result( a.size() );
475+
std::transform( a.begin(), a.end(), b.begin(), result.begin(), [&]( const size_t x, const size_t y ) { return x * y; } );
476+
return result;
477+
} );
480478

481479
FullyUnflattenedDart start_dart( {}, unflat.dart_pos );
482480
std::transform(
483481
unflat.unflat_darts.begin(),
484482
unflat.unflat_darts.end(),
483+
darts_per_ancestor_dart.begin(),
485484
std::back_inserter( start_dart.unflat_darts ),
486-
[&darts_per_ancestor_dart]( const Dart& d ) { return Dart( d.id() * darts_per_ancestor_dart ); } );
485+
[]( const Dart& d, const size_t ratio ) { return Dart( d.id() * ratio ); } );
487486

488487
// Add a series of TP indices to the start_dart, flatten, and call back.
489-
const util::IndexVec lengths( dim(), darts_per_ancestor_dart );
490-
const SmallVector<std::variant<bool, size_t>, 3> direction = tpDirectionAlongTPDartPos( unflat.dart_pos, lengths );
488+
const SmallVector<std::variant<bool, size_t>, 3> direction = tpDirectionAlongTPDartPos( unflat.dart_pos, darts_per_ancestor_dart );
491489
const util::IndexVec order = [this]() {
492490
util::IndexVec order( dim() );
493491
std::iota( order.begin(), order.end(), 0 );
494492
return order;
495493
}();
496494

497495
bool continue_iter = true;
498-
util::iterateTensorProduct( lengths, order, direction, [&]( const util::IndexVec& iv ){
496+
util::iterateTensorProduct( darts_per_ancestor_dart, order, direction, [&]( const util::IndexVec& iv ){
499497
if( not continue_iter ) return;
500498
FullyUnflattenedDart descendant_unflat = start_dart;
501499
std::transform( descendant_unflat.unflat_darts.begin(),
@@ -554,22 +552,21 @@ bool HierarchicalTPCombinatorialMap::iterateChildren( const Cell& local_cell,
554552
if( mRefinementLevels.size() <= descendant_level ) return true;
555553

556554
// Iterate several descendants.
557-
const size_t darts_per_ancestor_dart = mRefinementRatios.at( cell_level );
555+
const auto darts_per_ancestor_dart = mRefinementRatios.at( cell_level );
558556

559557
const FullyUnflattenedDart unflat = unflattenFull( *mRefinementLevels.at( cell_level ), local_cell.dart() );
560558

561559
FullyUnflattenedDart start_dart( {}, unflat.dart_pos );
562560
std::transform(
563561
unflat.unflat_darts.begin(),
564562
unflat.unflat_darts.end(),
563+
darts_per_ancestor_dart.begin(),
565564
std::back_inserter( start_dart.unflat_darts ),
566-
[&darts_per_ancestor_dart]( const Dart& d ) { return Dart( d.id() * darts_per_ancestor_dart ); } );
565+
[]( const Dart& d, const size_t ratio ) { return Dart( d.id() * ratio ); } );
567566

568567
// Add a series of TP indices to the start_dart, flatten, and call back.
569-
const util::IndexVec lengths( dim(), darts_per_ancestor_dart );
570-
571568
bool continue_iter = true;
572-
util::iterateTensorProduct( lengths, [&]( const util::IndexVec& iv ){
569+
util::iterateTensorProduct( darts_per_ancestor_dart, [&]( const util::IndexVec& iv ){
573570
if( not continue_iter ) return;
574571
FullyUnflattenedDart descendant_unflat = start_dart;
575572
std::transform( descendant_unflat.unflat_darts.begin(),

src/topology/HierarchicalTPCombinatorialMap.hpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ namespace topology
4343

4444
bool iterateLeafDescendants( const Dart& global_d, const std::function<bool( const Dart& )>& callback ) const;
4545

46-
const std::vector<size_t>& refinementRatios() const { return mRefinementRatios; }
46+
const std::vector<SmallVector<size_t, 3>>& refinementRatios() const { return mRefinementRatios; }
4747

4848
bool isUnrefinedLeafDart( const Dart& d ) const
4949
{
@@ -65,7 +65,7 @@ namespace topology
6565
std::vector<bool> mUnrefinedDarts;
6666
std::map<Dart, Dart> mPhiOnes;
6767
std::map<Dart, Dart> mPhiMinusOnes;
68-
std::vector<size_t> mRefinementRatios;
68+
std::vector<SmallVector<size_t, 3>> mRefinementRatios;
6969
};
7070

7171
/// @brief A mutable version of HierarchicalTPCombinatorialMap.
@@ -95,8 +95,6 @@ namespace topology
9595
return mLeafDarts.at( d.id() );
9696
}
9797

98-
size_t refinementRatio( const size_t level ) const { return mRefinementRatios.at( level ); }
99-
10098
/// Expose a protected method from the base class for initialization purposes.
10199
bool iterateDartLineage( const Dart& global_d,
102100
const size_t ancestor_or_descendant_level,
@@ -124,7 +122,5 @@ namespace topology
124122
}
125123
};
126124

127-
bool checkForNoAncestor( const TPCombinatorialMap& tp_map, const Dart& d, const size_t n_darts_per_ancestor );
128-
129125
std::vector<std::vector<Cell>> leafElements( const HierarchicalTPCombinatorialMap& cmap );
130126
}; // namespace topology

src/topology/TPCombinatorialMap.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,49 @@ namespace topology
332332
return out;
333333
}
334334

335+
size_t dartComponentDirection( const FullyUnflattenedDart& unflat_d )
336+
{
337+
if( unflat_d.dart_pos.size() == 1 )
338+
{
339+
switch( unflat_d.dart_pos.at(0) )
340+
{
341+
case TPCombinatorialMap::TPDartPos::DartPos0:
342+
case TPCombinatorialMap::TPDartPos::DartPos2:
343+
return 0;
344+
case TPCombinatorialMap::TPDartPos::DartPos1:
345+
case TPCombinatorialMap::TPDartPos::DartPos3:
346+
return 1;
347+
default:
348+
throw std::runtime_error( "Invalid TP Dart Position for 2d" );
349+
}
350+
}
351+
else if( unflat_d.dart_pos.size() == 2 )
352+
{
353+
switch( unflat_d.dart_pos.at(1) )
354+
{
355+
case TPCombinatorialMap::TPDartPos::DartPos2:
356+
case TPCombinatorialMap::TPDartPos::DartPos4:
357+
return 2;
358+
default:
359+
switch( unflat_d.dart_pos.at(0) )
360+
{
361+
case TPCombinatorialMap::TPDartPos::DartPos0:
362+
case TPCombinatorialMap::TPDartPos::DartPos2:
363+
return 0;
364+
case TPCombinatorialMap::TPDartPos::DartPos1:
365+
case TPCombinatorialMap::TPDartPos::DartPos3:
366+
return 1;
367+
default:
368+
throw std::runtime_error( "Invalid TP Dart Position for 2d" );
369+
}
370+
}
371+
}
372+
else
373+
{
374+
throw std::runtime_error( "Invalid TP Dart Position size" );
375+
}
376+
}
377+
335378
TPCombinatorialMap tensorProductCMapFromComponents( const SmallVector<std::shared_ptr<const CombinatorialMap1d>, 3>& components )
336379
{
337380
if( components.size() == 2 )

src/topology/TPCombinatorialMap.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ namespace topology
104104
SmallVector<TPCombinatorialMap::TPDartPos, 2> dart_pos;
105105
};
106106

107+
size_t dartComponentDirection( const FullyUnflattenedDart& unflat_d );
108+
107109
FullyUnflattenedDart unflattenFull( const TPCombinatorialMap& cmap, const Dart& d );
108110
Dart flattenFull( const TPCombinatorialMap& cmap, const FullyUnflattenedDart& unflat_d );
109111

0 commit comments

Comments
 (0)