diff --git a/include/openPMD/Datatype.hpp b/include/openPMD/Datatype.hpp index a11d3db75f..17cf6b67f4 100644 --- a/include/openPMD/Datatype.hpp +++ b/include/openPMD/Datatype.hpp @@ -294,7 +294,8 @@ template inline constexpr Datatype determineDatatype(T &&val) { (void)val; // don't need this, it only has a name for Doxygen - using T_stripped = std::remove_cv_t>; + using T_stripped = + std::remove_extent_t>>; if constexpr (auxiliary::IsPointer_v) { return determineDatatype>(); @@ -419,6 +420,8 @@ inline size_t toBits(Datatype d) return toBytes(d) * CHAR_BIT; } +constexpr bool isSigned(Datatype d); + /** Compare if a Datatype is a vector type * * @param d Datatype to test @@ -595,14 +598,19 @@ inline std::tuple isInteger() */ template inline bool isSameFloatingPoint(Datatype d) +{ + return isSameFloatingPoint(d, determineDatatype()); +} + +inline bool isSameFloatingPoint(Datatype d1, Datatype d2) { // template - bool tt_is_fp = isFloatingPoint(); + bool tt_is_fp = isFloatingPoint(d1); // Datatype - bool dt_is_fp = isFloatingPoint(d); + bool dt_is_fp = isFloatingPoint(d2); - if (tt_is_fp && dt_is_fp && toBits(d) == toBits(determineDatatype())) + if (tt_is_fp && dt_is_fp && toBits(d1) == toBits(d2)) return true; else return false; @@ -617,15 +625,19 @@ inline bool isSameFloatingPoint(Datatype d) */ template inline bool isSameComplexFloatingPoint(Datatype d) +{ + return isSameComplexFloatingPoint(d, determineDatatype()); +} + +inline bool isSameComplexFloatingPoint(Datatype d1, Datatype d2) { // template - bool tt_is_cfp = isComplexFloatingPoint(); + bool tt_is_cfp = isComplexFloatingPoint(d1); // Datatype - bool dt_is_cfp = isComplexFloatingPoint(d); + bool dt_is_cfp = isComplexFloatingPoint(d2); - if (tt_is_cfp && dt_is_cfp && - toBits(d) == toBits(determineDatatype())) + if (tt_is_cfp && dt_is_cfp && toBits(d1) == toBits(d2)) return true; else return false; @@ -640,17 +652,22 @@ inline bool isSameComplexFloatingPoint(Datatype d) */ template inline bool isSameInteger(Datatype d) +{ + return isSameInteger(d, determineDatatype()); +} + +inline bool isSameInteger(Datatype d1, Datatype d2) { // template bool tt_is_int, tt_is_sig; - std::tie(tt_is_int, tt_is_sig) = isInteger(); + std::tie(tt_is_int, tt_is_sig) = isInteger(d1); // Datatype bool dt_is_int, dt_is_sig; - std::tie(dt_is_int, dt_is_sig) = isInteger(d); + std::tie(dt_is_int, dt_is_sig) = isInteger(d2); if (tt_is_int && dt_is_int && tt_is_sig == dt_is_sig && - toBits(d) == toBits(determineDatatype())) + toBits(d1) == toBits(d2)) return true; else return false; @@ -691,46 +708,15 @@ constexpr bool isChar(Datatype d) template constexpr bool isSameChar(Datatype d); +constexpr bool isSameChar(Datatype d1, Datatype d2); + /** Comparison for two Datatypes * * Besides returning true for the same types, identical implementations on * some platforms, e.g. if long and long long are the same or double and * long double will also return true. */ -inline bool isSame(openPMD::Datatype const d, openPMD::Datatype const e) -{ - // exact same type - if (static_cast(d) == static_cast(e)) - return true; - - bool d_is_vec = isVector(d); - bool e_is_vec = isVector(e); - - // same int - bool d_is_int, d_is_sig; - std::tie(d_is_int, d_is_sig) = isInteger(d); - bool e_is_int, e_is_sig; - std::tie(e_is_int, e_is_sig) = isInteger(e); - if (d_is_int && e_is_int && d_is_vec == e_is_vec && d_is_sig == e_is_sig && - toBits(d) == toBits(e)) - return true; - - // same float - bool d_is_fp = isFloatingPoint(d); - bool e_is_fp = isFloatingPoint(e); - - if (d_is_fp && e_is_fp && d_is_vec == e_is_vec && toBits(d) == toBits(e)) - return true; - - // same complex floating point - bool d_is_cfp = isComplexFloatingPoint(d); - bool e_is_cfp = isComplexFloatingPoint(e); - - if (d_is_cfp && e_is_cfp && d_is_vec == e_is_vec && toBits(d) == toBits(e)) - return true; - - return false; -} +constexpr bool isSame(openPMD::Datatype d, openPMD::Datatype e); /** * @brief basicDatatype Strip openPMD Datatype of std::vector, std::array et. diff --git a/include/openPMD/Datatype.tpp b/include/openPMD/Datatype.tpp index 6685c62f73..e35f2e26b6 100644 --- a/include/openPMD/Datatype.tpp +++ b/include/openPMD/Datatype.tpp @@ -25,6 +25,7 @@ // comment to prevent clang-format from moving this #include up // datatype macros may be included and un-included in other headers #include "openPMD/DatatypeMacros.hpp" +#include "openPMD/auxiliary/TypeTraits.hpp" #include #include // std::void_t @@ -253,6 +254,56 @@ constexpr inline bool isSameChar(Datatype d) { return switchType>(d); } + +namespace detail +{ + struct IsSigned + { + template + static constexpr bool call() + { + if constexpr (auxiliary::IsVector_v || auxiliary::IsArray_v) + { + return call(); + } + else if constexpr (std::is_same_v) + { + return call(); + } + else + { + return std::is_signed_v; + } + } + + static constexpr char const *errorMsg = "IsSigned"; + }; +} // namespace detail + +constexpr inline bool isSigned(Datatype d) +{ + return switchType(d); +} + +constexpr inline bool isSameChar(Datatype d, Datatype e) +{ + return isChar(d) && isChar(e) && isSigned(d) == isSigned(e); +} + +constexpr bool isSame(openPMD::Datatype const d, openPMD::Datatype const e) +{ + return + // exact same type + static_cast(d) == static_cast(e) + // same int + || isSameInteger(d, e) + // same float + || isSameFloatingPoint(d, e) + // same complex floating point + || isSameComplexFloatingPoint(d, e) + // same char + || isSameChar(d, e); +} } // namespace openPMD #include "openPMD/UndefDatatypeMacros.hpp" diff --git a/include/openPMD/backend/PatchRecordComponent.hpp b/include/openPMD/backend/PatchRecordComponent.hpp index fed17dfd3b..eedd84e158 100644 --- a/include/openPMD/backend/PatchRecordComponent.hpp +++ b/include/openPMD/backend/PatchRecordComponent.hpp @@ -122,7 +122,8 @@ template inline void PatchRecordComponent::load(std::shared_ptr data) { Datatype dtype = determineDatatype(); - if (dtype != getDatatype()) + // Attention: Do NOT use operator==(), doesnt work properly on Windows! + if (!isSame(dtype, getDatatype())) throw std::runtime_error( "Type conversion during particle patch loading not yet " "implemented"); @@ -160,10 +161,7 @@ template inline void PatchRecordComponent::store(uint64_t idx, T data) { Datatype dtype = determineDatatype(); - if (dtype != getDatatype() && !isSameInteger(getDatatype()) && - !isSameFloatingPoint(getDatatype()) && - !isSameComplexFloatingPoint(getDatatype()) && - !isSameChar(getDatatype())) + if (!isSame(dtype, getDatatype())) { std::ostringstream oss; oss << "Datatypes of patch data (" << dtype << ") and dataset (" @@ -190,10 +188,7 @@ template inline void PatchRecordComponent::store(T data) { Datatype dtype = determineDatatype(); - if (dtype != getDatatype() && !isSameInteger(getDatatype()) && - !isSameFloatingPoint(getDatatype()) && - !isSameComplexFloatingPoint(getDatatype()) && - !isSameChar(getDatatype())) + if (!isSame(dtype, getDatatype())) { std::ostringstream oss; oss << "Datatypes of patch data (" << dtype << ") and dataset (" diff --git a/src/IO/ADIOS/ADIOS2PreloadAttributes.cpp b/src/IO/ADIOS/ADIOS2PreloadAttributes.cpp index 2b9bb02c2d..a9adbb74d6 100644 --- a/src/IO/ADIOS/ADIOS2PreloadAttributes.cpp +++ b/src/IO/ADIOS/ADIOS2PreloadAttributes.cpp @@ -248,7 +248,7 @@ PreloadAdiosAttributes::getAttribute(std::string const &name) const } AttributeLocation const &location = it->second; Datatype determinedDatatype = determineDatatype(); - if (location.dt != determinedDatatype) + if (!isSame(location.dt, determinedDatatype)) { std::stringstream errorMsg; errorMsg << "[ADIOS2] Wrong datatype for attribute: " << name diff --git a/src/IO/JSON/JSONIOHandlerImpl.cpp b/src/IO/JSON/JSONIOHandlerImpl.cpp index 59541c1e30..551b6c4358 100644 --- a/src/IO/JSON/JSONIOHandlerImpl.cpp +++ b/src/IO/JSON/JSONIOHandlerImpl.cpp @@ -2332,7 +2332,7 @@ auto JSONIOHandlerImpl::verifyDataset( } Datatype dt = stringToDatatype(j["datatype"].get()); VERIFY_ALWAYS( - dt == parameters.dtype, + isSame(dt, parameters.dtype), "[JSON] Read/Write request does not fit the dataset's type"); } catch (json::basic_json::type_error &) diff --git a/src/RecordComponent.cpp b/src/RecordComponent.cpp index c6eef23313..fbf2ad0d88 100644 --- a/src/RecordComponent.cpp +++ b/src/RecordComponent.cpp @@ -657,7 +657,7 @@ void RecordComponent::verifyChunk( if (empty()) throw std::runtime_error( "Chunks cannot be written for an empty RecordComponent."); - if (dtype != getDatatype()) + if (!isSame(dtype, getDatatype())) { std::ostringstream oss; oss << "Datatypes of chunk data (" << dtype @@ -833,21 +833,19 @@ void RecordComponent::loadChunk(std::shared_ptr data, Offset o, Extent e) * JSON/TOML backends as they might implicitly turn a LONG into an INT in a * constant component. The frontend needs to catch such edge cases. * Ref. `if (constant())` branch. + * + * Attention: Do NOT use operator==(), doesnt work properly on Windows! */ - if (dtype != getDatatype() && !constant()) - if (!isSameInteger(getDatatype()) && - !isSameFloatingPoint(getDatatype()) && - !isSameComplexFloatingPoint(getDatatype()) && - !isSameChar(getDatatype())) - { - std::string const data_type_str = datatypeToString(getDatatype()); - std::string const requ_type_str = - datatypeToString(determineDatatype()); - std::string err_msg = - "Type conversion during chunk loading not yet implemented! "; - err_msg += "Data: " + data_type_str + "; Load as: " + requ_type_str; - throw std::runtime_error(err_msg); - } + if (!isSame(dtype, getDatatype()) && !constant()) + { + std::string const data_type_str = datatypeToString(getDatatype()); + std::string const requ_type_str = + datatypeToString(determineDatatype()); + std::string err_msg = + "Type conversion during chunk loading not yet implemented! "; + err_msg += "Data: " + data_type_str + "; Load as: " + requ_type_str; + throw std::runtime_error(err_msg); + } uint8_t dim = getDimensionality(); diff --git a/src/Series.cpp b/src/Series.cpp index a826303193..f7ed6964b8 100644 --- a/src/Series.cpp +++ b/src/Series.cpp @@ -1950,12 +1950,11 @@ void Series::readOneIterationFileBased(std::string const &filePath) readBase(); - using DT = Datatype; aRead.name = "iterationEncoding"; IOHandler()->enqueue(IOTask(this, aRead)); IOHandler()->flush(internal::defaultFlushParams); IterationEncoding encoding_out; - if (*aRead.dtype == DT::STRING) + if (isSame(*aRead.dtype, Datatype::STRING)) { std::string encoding = Attribute(Attribute::from_any, *aRead.m_resource) .get(); @@ -2010,7 +2009,7 @@ void Series::readOneIterationFileBased(std::string const &filePath) aRead.name = "iterationFormat"; IOHandler()->enqueue(IOTask(this, aRead)); IOHandler()->flush(internal::defaultFlushParams); - if (*aRead.dtype == DT::STRING) + if (isSame(*aRead.dtype, Datatype::STRING)) { setWritten(false, Attributable::EnqueueAsynchronously::No); setIterationFormat(Attribute(Attribute::from_any, *aRead.m_resource) diff --git a/src/binding/python/RecordComponent.cpp b/src/binding/python/RecordComponent.cpp index 232df5861d..382783a730 100644 --- a/src/binding/python/RecordComponent.cpp +++ b/src/binding/python/RecordComponent.cpp @@ -489,7 +489,14 @@ inline void store_chunk( check_buffer_is_contiguous(a); - // dtype_from_numpy(a.dtype()) + if (!dtype_to_numpy(r.getDatatype()).is(a.dtype())) + { + std::stringstream err; + err << "Attempting store from Python array of type '" + << dtype_from_numpy(a.dtype()) + << "' into Record Component of type '" << r.getDatatype() << "'."; + throw error::WrongAPIUsage(err.str()); + } switchDatasetType( r.getDatatype(), r, a, offset, extent); } @@ -770,6 +777,15 @@ inline void load_chunk( check_buffer_is_contiguous(a); + if (!dtype_to_numpy(r.getDatatype()).is(a.dtype())) + { + std::stringstream err; + err << "Attempting load into Python array of type '" + << dtype_from_numpy(a.dtype()) + << "' from Record Component of type '" << r.getDatatype() << "'."; + throw error::WrongAPIUsage(err.str()); + } + switchDatasetType( r.getDatatype(), r, a, offset, extent); } diff --git a/test/python/unittest/API/APITest.py b/test/python/unittest/API/APITest.py index c22551a074..00146026f2 100644 --- a/test/python/unittest/API/APITest.py +++ b/test/python/unittest/API/APITest.py @@ -2209,7 +2209,6 @@ def testError(self): def testCustomGeometries(self): DS = io.Dataset - DT = io.Datatype sample_data = np.ones([10], dtype=np.int_) write = io.Series("../samples/custom_geometries_python.json", @@ -2217,25 +2216,25 @@ def testCustomGeometries(self): E = write.iterations[0].meshes["E"] E.set_attribute("geometry", "other:customGeometry") E_x = E["x"] - E_x.reset_dataset(DS(DT.LONG, [10])) + E_x.reset_dataset(DS(np.dtype(np.int_), [10])) E_x[:] = sample_data B = write.iterations[0].meshes["B"] B.set_geometry("customGeometry") B_x = B["x"] - B_x.reset_dataset(DS(DT.LONG, [10])) + B_x.reset_dataset(DS(np.dtype(np.int_), [10])) B_x[:] = sample_data e_energyDensity = write.iterations[0].meshes["e_energyDensity"] e_energyDensity.set_geometry("other:customGeometry") e_energyDensity_x = e_energyDensity[io.Mesh_Record_Component.SCALAR] - e_energyDensity_x.reset_dataset(DS(DT.LONG, [10])) + e_energyDensity_x.reset_dataset(DS(np.dtype(np.int_), [10])) e_energyDensity_x[:] = sample_data e_chargeDensity = write.iterations[0].meshes["e_chargeDensity"] e_chargeDensity.set_geometry(io.Geometry.other) e_chargeDensity_x = e_chargeDensity[io.Mesh_Record_Component.SCALAR] - e_chargeDensity_x.reset_dataset(DS(DT.LONG, [10])) + e_chargeDensity_x.reset_dataset(DS(np.dtype(np.int_), [10])) e_chargeDensity_x[:] = sample_data self.assertTrue(write)