@@ -45,10 +45,15 @@ static CmdBufType cmdBufType = CmdBufType::Terrible;
4545#else
4646static CmdBufType cmdBufType = CmdBufType::None;
4747#endif // defined(WEBNN_ENABLE_WIRE)
48- static webnn::wire::WireServer* wireServer = nullptr ;
49- static webnn::wire::WireClient* wireClient = nullptr ;
50- static utils::TerribleCommandBuffer* c2sBuf = nullptr ;
51- static utils::TerribleCommandBuffer* s2cBuf = nullptr ;
48+
49+ class WireHelper {
50+ public:
51+ std::unique_ptr<webnn::wire::WireServer> wireServer;
52+ std::unique_ptr<webnn::wire::WireClient> wireClient;
53+ std::unique_ptr<utils::TerribleCommandBuffer> c2sBuf;
54+ std::unique_ptr<utils::TerribleCommandBuffer> s2cBuf;
55+ };
56+ static WireHelper wireHelper;
5257
5358static wnn::Instance clientInstance;
5459static std::unique_ptr<webnn::native::Instance> nativeInstance;
@@ -70,34 +75,35 @@ wnn::Context CreateCppContext(wnn::ContextOptions const* options) {
7075 break ;
7176
7277 case CmdBufType::Terrible: {
73- c2sBuf = new utils::TerribleCommandBuffer ();
74- s2cBuf = new utils::TerribleCommandBuffer ();
78+ wireHelper. c2sBuf . reset ( new utils::TerribleCommandBuffer () );
79+ wireHelper. s2cBuf . reset ( new utils::TerribleCommandBuffer () );
7580
7681 webnn::wire::WireServerDescriptor serverDesc = {};
7782 serverDesc.procs = &backendProcs;
78- serverDesc.serializer = s2cBuf;
83+ serverDesc.serializer = wireHelper. s2cBuf . get () ;
7984
80- wireServer = new webnn::wire::WireServer (serverDesc);
81- c2sBuf->SetHandler (wireServer);
85+ wireHelper. wireServer . reset ( new webnn::wire::WireServer (serverDesc) );
86+ wireHelper. c2sBuf ->SetHandler (wireHelper. wireServer . get () );
8287
8388 webnn::wire::WireClientDescriptor clientDesc = {};
84- clientDesc.serializer = c2sBuf;
89+ clientDesc.serializer = wireHelper. c2sBuf . get () ;
8590
86- wireClient = new webnn::wire::WireClient (clientDesc);
91+ wireHelper. wireClient . reset ( new webnn::wire::WireClient (clientDesc) );
8792 procs = webnn::wire::client::GetProcs ();
88- s2cBuf->SetHandler (wireClient);
93+ wireHelper. s2cBuf ->SetHandler (wireHelper. wireClient . get () );
8994
9095#ifdef ENABLE_INJECT_CONTEXT
91- auto contextReservation = wireClient->ReserveContext ();
92- wireServer->InjectContext (backendContext, contextReservation.id ,
93- contextReservation.generation );
96+ auto contextReservation = wireHelper. wireClient ->ReserveContext ();
97+ wireHelper. wireServer ->InjectContext (backendContext, contextReservation.id ,
98+ contextReservation.generation );
9499
95100 context = contextReservation.context ;
96101#else
97102 webnnProcSetProcs (&procs);
98- auto instanceReservation = wireClient->ReserveInstance ();
99- wireServer->InjectInstance (nativeInstance->Get (), instanceReservation.id ,
100- instanceReservation.generation );
103+ webnn::wire::ReservedInstance instanceReservation =
104+ wireHelper.wireClient ->ReserveInstance ();
105+ wireHelper.wireServer ->InjectInstance (nativeInstance->Get (), instanceReservation.id ,
106+ instanceReservation.generation );
101107 // Keep the reference instread of using Acquire.
102108 // TODO:: make the instance in the client as singleton object.
103109 clientInstance = wnn::Instance (instanceReservation.instance );
@@ -115,10 +121,11 @@ wnn::Context CreateCppContext(wnn::ContextOptions const* options) {
115121
116122void DoFlush () {
117123 if (cmdBufType == CmdBufType::Terrible) {
118- bool c2sSuccess = c2sBuf->Flush ();
119- bool s2cSuccess = s2cBuf->Flush ();
124+ bool c2sSuccess = wireHelper. c2sBuf ->Flush ();
125+ bool s2cSuccess = wireHelper. s2cBuf ->Flush ();
120126
121- ASSERT (c2sSuccess && s2cSuccess);
127+ DAWN_ASSERT (c2sSuccess);
128+ DAWN_ASSERT (s2cSuccess);
122129 }
123130}
124131
@@ -365,25 +372,30 @@ namespace utils {
365372 }
366373
367374 bool LoadAndPreprocessImage (const ExampleBase* example, std::vector<float >& processedPixels) {
375+ DAWN_ASSERT (example != nullptr );
368376 // Read an image.
369377 int imageWidth, imageHeight, imageChannels = 0 ;
370- uint8_t * inputPixels =
371- stbi_load (example->mImagePath .c_str (), &imageWidth, &imageHeight, &imageChannels, 0 );
372- if (inputPixels == 0 ) {
378+ std::unique_ptr< uint8_t > inputPixels (
379+ stbi_load (example->mImagePath .c_str (), &imageWidth, &imageHeight, &imageChannels, 0 )) ;
380+ if (inputPixels == nullptr ) {
373381 dawn::ErrorLog () << " Failed to load and preprocess the image at "
374382 << example->mImagePath ;
375383 return false ;
376384 }
377385 // Resize the image with model's input size
378386 const size_t imageSize = imageHeight * imageWidth * imageChannels;
379- float * floatPixels = ( float *) malloc ( imageSize * sizeof ( float ) );
387+ std::vector< float > floatPixels ( imageSize);
380388 for (size_t i = 0 ; i < imageSize; ++i) {
381- floatPixels[i] = inputPixels[i];
389+ floatPixels[i] = inputPixels. get () [i];
382390 }
383- float * resizedPixels = (float *)malloc (example->mModelHeight * example->mModelWidth *
384- example->mModelChannels * sizeof (float ));
385- stbir_resize_float (floatPixels, imageWidth, imageHeight, 0 , resizedPixels,
386- example->mModelWidth , example->mModelHeight , 0 , example->mModelChannels );
391+ std::vector<float > resizedPixels (example->mModelHeight * example->mModelWidth *
392+ example->mModelChannels );
393+ if (stbir_resize_float (floatPixels.data (), imageWidth, imageHeight, 0 , resizedPixels.data (),
394+ example->mModelWidth , example->mModelHeight , 0 ,
395+ example->mModelChannels ) == 0 ) {
396+ dawn::ErrorLog () << " Failed to resize the image." ;
397+ return false ;
398+ };
387399
388400 // Reoder the image to NCHW/NHWC layout.
389401 for (size_t c = 0 ; c < example->mModelChannels ; ++c) {
@@ -404,8 +416,6 @@ namespace utils {
404416 }
405417 }
406418 }
407- free (resizedPixels);
408- free (floatPixels);
409419 return true ;
410420 }
411421
0 commit comments