1313#include " Support/Pipeline.h"
1414#include " llvm/ADT/APFloat.h"
1515#include " llvm/ADT/APInt.h"
16+ #include " llvm/ADT/bit.h"
1617#include " llvm/Support/Error.h"
1718#include " llvm/Support/raw_ostream.h"
1819#include < cmath>
20+ #include < cstring>
1921#include < sstream>
2022
2123constexpr uint16_t Float16BitSign = 0x8000 ;
@@ -280,33 +282,33 @@ static bool testBufferFloatULP(offloadtest::Buffer *B1, offloadtest::Buffer *B2,
280282 return false ;
281283}
282284
283- template <typename T>
284- static std::string bitPatternAsHex64 (const T &Val,
285- offloadtest::Rule ComparisonRule) {
285+ template <typename T> static uint64_t toBitPattern (const T &Val) {
286286 static_assert (sizeof (T) <= sizeof (uint64_t ), " Type too large for Hex64" );
287+ uint64_t Bits = 0 ;
288+ memcpy (&Bits, &Val, sizeof (T));
289+ return Bits;
290+ }
287291
292+ template <typename T> static std::string formatAsHex (const T &Val) {
288293 std::ostringstream Oss;
289- if (ComparisonRule == offloadtest::Rule::BufferExact)
290- Oss << " 0x" << std::hex << Val;
291- else
292- Oss << std::hexfloat << Val;
294+ Oss << " 0x" << std::hex << toBitPattern (Val);
293295 return Oss.str ();
294296}
295297
296298template <typename T>
297- static void formatBuffer (llvm::ArrayRef<T> Arr, offloadtest::Rule Rule,
299+ static void formatBuffer (llvm::ArrayRef<T> Arr,
298300 llvm::raw_svector_ostream &Result) {
299301 if (Arr.empty ())
300302 return ;
301303
302- Result << " [ " << bitPatternAsHex64 (Arr[0 ], Rule );
304+ Result << " [ " << formatAsHex (Arr[0 ]);
303305 for (size_t I = 1 ; I < Arr.size (); ++I)
304- Result << " , " << bitPatternAsHex64 (Arr[I], Rule );
306+ Result << " , " << formatAsHex (Arr[I]);
305307 Result << " ]" ;
306308}
307309
308310template <typename T>
309- static void formatBufferArray (offloadtest::Buffer *B, offloadtest::Rule Rule,
311+ static void formatBufferArray (offloadtest::Buffer *B,
310312 llvm::raw_svector_ostream &Result) {
311313 assert (B->ArraySize > 1 && " Buffer must be an array to format as array" );
312314 for (const auto &DataPtr : B->Data ) {
@@ -315,62 +317,57 @@ static void formatBufferArray(offloadtest::Buffer *B, offloadtest::Rule Rule,
315317 Result << " - " ;
316318 formatBuffer (llvm::ArrayRef<T>(reinterpret_cast <T *>(DataPtr.get ()),
317319 B->Size / sizeof (T)),
318- Rule, Result);
320+ Result);
319321 }
320322}
321323
322- template <typename T>
323- static std::string formatBuffer (offloadtest::Buffer *B,
324- offloadtest::Rule Rule) {
324+ template <typename T> static std::string formatBuffer (offloadtest::Buffer *B) {
325325 llvm::SmallString<256 > Str;
326326 llvm::raw_svector_ostream Result (Str);
327327
328328 if (B->ArraySize > 1 )
329- formatBufferArray<T>(B, Rule, Result);
329+ formatBufferArray<T>(B, Result);
330330 else
331331 formatBuffer (llvm::ArrayRef<T>(reinterpret_cast <T *>(B->Data .back ().get ()),
332332 B->Size / sizeof (T)),
333- Rule, Result);
333+ Result);
334334
335335 return std::string (Result.str ());
336336}
337337
338- static const std::string getBufferStr (offloadtest::Buffer *B,
339- offloadtest::Rule Rule) {
338+ static const std::string getBufferStr (offloadtest::Buffer *B) {
340339 using DF = offloadtest::DataFormat;
341340 switch (B->Format ) {
342341 case DF::Hex8:
343- return formatBuffer<llvm::yaml::Hex8>(B, Rule );
342+ return formatBuffer<llvm::yaml::Hex8>(B);
344343 case DF::Hex16:
345- return formatBuffer<llvm::yaml::Hex16>(B, Rule );
344+ return formatBuffer<llvm::yaml::Hex16>(B);
346345 case DF::Hex32:
347- return formatBuffer<llvm::yaml::Hex32>(B, Rule );
346+ return formatBuffer<llvm::yaml::Hex32>(B);
348347 case DF::Hex64:
349- return formatBuffer<llvm::yaml::Hex64>(B, Rule );
348+ return formatBuffer<llvm::yaml::Hex64>(B);
350349 case DF::UInt16:
351- return formatBuffer<uint16_t >(B, Rule );
350+ return formatBuffer<uint16_t >(B);
352351 case DF::UInt32:
353- return formatBuffer<uint32_t >(B, Rule );
352+ return formatBuffer<uint32_t >(B);
354353 case DF::UInt64:
355- return formatBuffer<uint64_t >(B, Rule );
354+ return formatBuffer<uint64_t >(B);
356355 case DF::Int16:
357- return formatBuffer<int16_t >(B, Rule );
356+ return formatBuffer<int16_t >(B);
358357 case DF::Int32:
359- return formatBuffer<int32_t >(B, Rule );
358+ return formatBuffer<int32_t >(B);
360359 case DF::Int64:
361- return formatBuffer<int64_t >(B, Rule );
360+ return formatBuffer<int64_t >(B);
362361 case DF::Float16:
363- return formatBuffer<llvm::yaml::Hex16>(B,
364- Rule); // assuming no native float16
362+ return formatBuffer<llvm::yaml::Hex16>(B); // assuming no native float16
365363 case DF::Float32:
366364 case DF::Depth32:
367- return formatBuffer<float >(B, Rule );
365+ return formatBuffer<float >(B);
368366 case DF::Float64:
369- return formatBuffer<double >(B, Rule );
367+ return formatBuffer<double >(B);
370368 case DF::Bool:
371- return formatBuffer<uint32_t >(B,
372- Rule); // Because sizeof(bool) is 1 but HLSL
373- // represents a bool using 4 bytes.
369+ return formatBuffer<uint32_t >(B); // Because sizeof(bool) is 1 but HLSL
370+ // represents a bool using 4 bytes.
374371 }
375372}
376373
@@ -409,18 +406,20 @@ llvm::Error verifyResult(offloadtest::Result R) {
409406 OS << " Got:\n " ;
410407 YAMLOS << *R.ActualPtr ;
411408
412- // Now print exact hex64 representations of each element of the
409+ // Now print exact hex representations of each element of the
413410 // actual and expected buffers.
414411
415- const std::string ExpectedBufferStr =
416- getBufferStr (R.ExpectedPtr , R.ComparisonRule );
417- const std::string ActualBufferStr =
418- getBufferStr (R.ActualPtr , R.ComparisonRule );
412+ if constexpr (llvm::endianness::native == llvm::endianness::little) {
413+ const std::string ExpectedBufferStr = getBufferStr (R.ExpectedPtr );
414+ const std::string ActualBufferStr = getBufferStr (R.ActualPtr );
419415
420- OS << " Full Hex 64bit representation of Expected Buffer Values:\n "
421- << ExpectedBufferStr << " \n " ;
422- OS << " Full Hex 64bit representation of Actual Buffer Values:\n "
423- << ActualBufferStr << " \n " ;
416+ OS << " Full Hex representation of Expected Buffer Values:\n "
417+ << ExpectedBufferStr << " \n " ;
418+ OS << " Full Hex representation of Actual Buffer Values:\n "
419+ << ActualBufferStr << " \n " ;
420+ } else {
421+ OS << " Hex output is not supported on big-endian hosts.\n " ;
422+ }
424423
425424 return llvm::createStringError (Str.c_str ());
426425}
0 commit comments