2323// IWYU pragma: end_exports
2424
2525#include " compression/compress.h"
26- #include " hwy/contrib/thread_pool/thread_pool .h"
26+ #include " util/threading_context .h"
2727
2828#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_INL_H_
2929
@@ -98,25 +98,26 @@ void ForeachActivationType3(D d) {
9898
9999// Generates inputs: deterministic, within max SfpStream range.
100100template <typename MatT>
101- MatStorageT<MatT> GenerateMat (const Extents2D& extents,
102- const Allocator& allocator, MatPadding padding,
103- hwy::ThreadPool& pool) {
101+ MatStorageT<MatT> GenerateMat (const Extents2D& extents, MatPadding padding,
102+ ThreadingContext& ctx) {
104103 gcpp::CompressWorkingSet ws;
105- ws.tls .resize (pool. NumWorkers ());
106- MatStorageT<float > raw (" raw" , extents, allocator, MatPadding::kPacked );
107- MatStorageT<MatT> compressed (" mat" , extents, allocator, padding);
104+ ws.tls .resize (ctx. pools . MaxWorkers ());
105+ MatStorageT<float > raw (" raw" , extents, ctx. allocator , MatPadding::kPacked );
106+ MatStorageT<MatT> compressed (" mat" , extents, ctx. allocator , padding);
108107 const float scale = SfpStream::kMax / extents.Area ();
109- pool.Run (0 , extents.rows , [&](const size_t r, size_t thread) {
110- float * HWY_RESTRICT row = raw.Row (r);
111- for (size_t c = 0 ; c < extents.cols ; c++) {
112- float f = static_cast <float >(r * extents.cols + c) * scale;
113- if ((r + c) & 1 ) f = -f; // Also generate some negative values.
114- row[c] = f;
115- }
116- Compress (raw.Row (r), raw.Cols (), ws.tls [thread],
117- MakeSpan (compressed.Row (r), extents.cols ),
118- /* packed_ofs=*/ 0 );
119- });
108+ ParallelFor (ParallelismStrategy::kFlat , extents.rows , ctx, /* cluster_idx=*/ 0 ,
109+ Callers::kTest , [&](size_t r, size_t thread) {
110+ float * HWY_RESTRICT row = raw.Row (r);
111+ for (size_t c = 0 ; c < extents.cols ; c++) {
112+ float f = static_cast <float >(r * extents.cols + c) * scale;
113+ if ((r + c) & 1 )
114+ f = -f; // Also generate some negative values.
115+ row[c] = f;
116+ }
117+ Compress (raw.Row (r), raw.Cols (), ws.tls [thread],
118+ MakeSpan (compressed.Row (r), extents.cols ),
119+ /* packed_ofs=*/ 0 );
120+ });
120121
121122 compressed.SetScale (0 .6f ); // Arbitrary value, different from 1.
122123 return compressed;
@@ -126,25 +127,26 @@ MatStorageT<MatT> GenerateMat(const Extents2D& extents,
126127// `f` swaps `r` and `c`.
127128template <typename MatT>
128129MatStorageT<MatT> GenerateTransposedMat (const Extents2D extents,
129- const Allocator& allocator,
130130 MatPadding padding,
131- hwy::ThreadPool& pool ) {
131+ ThreadingContext& ctx ) {
132132 gcpp::CompressWorkingSet ws;
133- ws.tls .resize (pool. NumWorkers ());
134- MatStorageT<float > raw (" raw" , extents, allocator, MatPadding::kPacked );
135- MatStorageT<MatT> compressed (" trans" , extents, allocator, padding);
133+ ws.tls .resize (ctx. pools . MaxWorkers ());
134+ MatStorageT<float > raw (" raw" , extents, ctx. allocator , MatPadding::kPacked );
135+ MatStorageT<MatT> compressed (" trans" , extents, ctx. allocator , padding);
136136 const float scale = SfpStream::kMax / extents.Area ();
137- pool.Run (0 , extents.rows , [&](const size_t r, size_t thread) {
138- float * HWY_RESTRICT row = raw.Row (r);
139- for (size_t c = 0 ; c < extents.cols ; c++) {
140- float f = static_cast <float >(c * extents.rows + r) * scale;
141- if ((r + c) & 1 ) f = -f; // Also generate some negative values.
142- row[c] = f;
143- }
144- Compress (raw.Row (r), raw.Cols (), ws.tls [thread],
145- MakeSpan (compressed.Row (r), extents.cols ),
146- /* packed_ofs=*/ 0 );
147- });
137+ ParallelFor (ParallelismStrategy::kFlat , extents.rows , ctx, /* cluster_idx=*/ 0 ,
138+ Callers::kTest , [&](size_t r, size_t thread) {
139+ float * HWY_RESTRICT row = raw.Row (r);
140+ for (size_t c = 0 ; c < extents.cols ; c++) {
141+ float f = static_cast <float >(c * extents.rows + r) * scale;
142+ if ((r + c) & 1 )
143+ f = -f; // Also generate some negative values.
144+ row[c] = f;
145+ }
146+ Compress (raw.Row (r), raw.Cols (), ws.tls [thread],
147+ MakeSpan (compressed.Row (r), extents.cols ),
148+ /* packed_ofs=*/ 0 );
149+ });
148150
149151 // Arbitrary value, different from 1, must match `GenerateMat`.
150152 compressed.SetScale (0 .6f );
0 commit comments