Skip to content

Commit 8b4ec0b

Browse files
Lutz Grossclaude
andcommitted
Fix copyWithMask for scalar mask on non-scalar data (issue #47)
Three-part fix in escriptcore/src/Data.cpp maskWorker(): 1. Added early-exit path for Expanded/Constant data with scalar mask, iterating per data-point instead of per-component. 2. Extended shape validation in the Tagged handler to allow scalar masks with matching self/other shapes. 3. Added new else-if branch in the Tagged handler to correctly apply a scalar mask across all tagged values and the default value. Added comprehensive test suite in finley/test/python/run_copyWithMaskOnFinley.py with 91 tests covering ranks 0-4, all combinations of Constant/Tagged/Expanded data storage, and both full-shape and scalar masks. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent e14ce55 commit 8b4ec0b

2 files changed

Lines changed: 1336 additions & 7 deletions

File tree

escriptcore/src/Data.cpp

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -963,19 +963,34 @@ Data::maskWorker(Data& other2, Data& mask2, S sentinel)
963963
unsigned int otherrank=other2.getDataPointRank();
964964
unsigned int maskrank=mask2.getDataPointRank();
965965

966-
if ((selfrank>0) && (otherrank>0) &&(maskrank==0))
967-
{
968-
if (mvec[0]>0) // copy whole object if scalar is >0
966+
if ((selfrank>0) && (otherrank>0) && (maskrank==0) && !isTagged())
967+
{
968+
// Scalar mask applied per data point (Expanded or Constant).
969+
// mvec has one entry per data point; self/ovec have psize entries per data point.
970+
size_t psize = getDataPointSize();
971+
size_t num_dp = getNumDataPoints();
972+
#pragma omp parallel for schedule(static)
973+
for (size_t pt = 0; pt < num_dp; ++pt)
969974
{
970-
copy(other2);
975+
if (mvec[pt] > 0)
976+
{
977+
size_t offset = pt * psize;
978+
for (size_t comp = 0; comp < psize; ++comp)
979+
self[offset + comp] = ovec[offset + comp];
980+
}
971981
}
972982
return;
973983
}
974984
if (isTagged()) // so all objects involved will also be tagged
975985
{
976986
// note the !
977-
if (!((getDataPointShape()==mask2.getDataPointShape()) &&
978-
((other2.getDataPointShape()==mask2.getDataPointShape()) || (otherrank==0))))
987+
// Valid combinations:
988+
// 1. all same shape
989+
// 2. mask same shape as self, other is scalar
990+
// 3. mask is scalar, other same shape as self (new)
991+
if (!((getDataPointShape()==mask2.getDataPointShape() &&
992+
(other2.getDataPointShape()==mask2.getDataPointShape() || otherrank==0)) ||
993+
(maskrank==0 && getDataPointShape()==other2.getDataPointShape())))
979994
{
980995
throw DataException("copyWithMask, shape mismatch.");
981996
}
@@ -1003,7 +1018,7 @@ Data::maskWorker(Data& other2, Data& mask2, S sentinel)
10031018
// now we know that *this has all the required tags but they aren't guaranteed to be in
10041019
// the same order
10051020

1006-
// There are two possibilities: 1. all objects have the same rank. 2. other is a scalar
1021+
// Three cases: 1. all same rank. 2. scalar mask, vector self & other. 3. scalar other.
10071022
if ((selfrank==otherrank) && (otherrank==maskrank))
10081023
{
10091024
for (i=tlookup.begin();i!=tlookup.end();i++)
@@ -1030,6 +1045,30 @@ Data::maskWorker(Data& other2, Data& mask2, S sentinel)
10301045
}
10311046
}
10321047
}
1048+
else if ((selfrank==otherrank) && (maskrank==0))
1049+
{
1050+
// Case 2: scalar mask, vector self and other.
1051+
// mvec has one scalar value per tag block; copy all psize components when mask > 0.
1052+
for (i=tlookup.begin();i!=tlookup.end();i++)
1053+
{
1054+
DataTypes::RealVectorType::size_type toff=tptr->getOffsetForTag(i->first);
1055+
DataTypes::RealVectorType::size_type moff=mptr->getOffsetForTag(i->first);
1056+
DataTypes::RealVectorType::size_type ooff=optr->getOffsetForTag(i->first);
1057+
if (mvec[moff] > 0)
1058+
{
1059+
for (int j=0; j<getDataPointSize(); ++j)
1060+
self[j+toff] = ovec[j+ooff];
1061+
}
1062+
}
1063+
// default value
1064+
if (mvec[mptr->getDefaultOffset()] > 0)
1065+
{
1066+
DataTypes::RealVectorType::size_type tdef=tptr->getDefaultOffset();
1067+
DataTypes::RealVectorType::size_type odef=optr->getDefaultOffset();
1068+
for (int j=0; j<getDataPointSize(); ++j)
1069+
self[j+tdef] = ovec[j+odef];
1070+
}
1071+
}
10331072
else // other is a scalar
10341073
{
10351074
for (i=tlookup.begin();i!=tlookup.end();i++)
@@ -1407,6 +1446,7 @@ Data::toListOfTuples(bool scalarastuple)
14071446
{
14081447
for (count=0;count<npoints;++count)
14091448
{
1449+
14101450
res[count]=vec[count];
14111451
}
14121452
}

0 commit comments

Comments
 (0)