Skip to content

Commit c41280a

Browse files
authored
Merge pull request #275 from mingmingtasd/memory_leak
Fix dml/dmlx memory leaks for C++ examples
2 parents 31d48a4 + 7b070e1 commit c41280a

1 file changed

Lines changed: 42 additions & 32 deletions

File tree

examples/SampleUtils.cpp

Lines changed: 42 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,15 @@ static CmdBufType cmdBufType = CmdBufType::Terrible;
4545
#else
4646
static CmdBufType cmdBufType = CmdBufType::None;
4747
#endif // defined(WEBNN_ENABLE_WIRE)
48-
static webnn::wire::WireServer* wireServer = nullptr;
49-
static webnn::wire::WireClient* wireClient = nullptr;
50-
static utils::TerribleCommandBuffer* c2sBuf = nullptr;
51-
static utils::TerribleCommandBuffer* s2cBuf = nullptr;
48+
49+
class WireHelper {
50+
public:
51+
std::unique_ptr<webnn::wire::WireServer> wireServer;
52+
std::unique_ptr<webnn::wire::WireClient> wireClient;
53+
std::unique_ptr<utils::TerribleCommandBuffer> c2sBuf;
54+
std::unique_ptr<utils::TerribleCommandBuffer> s2cBuf;
55+
};
56+
static WireHelper wireHelper;
5257

5358
static wnn::Instance clientInstance;
5459
static std::unique_ptr<webnn::native::Instance> nativeInstance;
@@ -70,34 +75,35 @@ wnn::Context CreateCppContext(wnn::ContextOptions const* options) {
7075
break;
7176

7277
case CmdBufType::Terrible: {
73-
c2sBuf = new utils::TerribleCommandBuffer();
74-
s2cBuf = new utils::TerribleCommandBuffer();
78+
wireHelper.c2sBuf.reset(new utils::TerribleCommandBuffer());
79+
wireHelper.s2cBuf.reset(new utils::TerribleCommandBuffer());
7580

7681
webnn::wire::WireServerDescriptor serverDesc = {};
7782
serverDesc.procs = &backendProcs;
78-
serverDesc.serializer = s2cBuf;
83+
serverDesc.serializer = wireHelper.s2cBuf.get();
7984

80-
wireServer = new webnn::wire::WireServer(serverDesc);
81-
c2sBuf->SetHandler(wireServer);
85+
wireHelper.wireServer.reset(new webnn::wire::WireServer(serverDesc));
86+
wireHelper.c2sBuf->SetHandler(wireHelper.wireServer.get());
8287

8388
webnn::wire::WireClientDescriptor clientDesc = {};
84-
clientDesc.serializer = c2sBuf;
89+
clientDesc.serializer = wireHelper.c2sBuf.get();
8590

86-
wireClient = new webnn::wire::WireClient(clientDesc);
91+
wireHelper.wireClient.reset(new webnn::wire::WireClient(clientDesc));
8792
procs = webnn::wire::client::GetProcs();
88-
s2cBuf->SetHandler(wireClient);
93+
wireHelper.s2cBuf->SetHandler(wireHelper.wireClient.get());
8994

9095
#ifdef ENABLE_INJECT_CONTEXT
91-
auto contextReservation = wireClient->ReserveContext();
92-
wireServer->InjectContext(backendContext, contextReservation.id,
93-
contextReservation.generation);
96+
auto contextReservation = wireHelper.wireClient->ReserveContext();
97+
wireHelper.wireServer->InjectContext(backendContext, contextReservation.id,
98+
contextReservation.generation);
9499

95100
context = contextReservation.context;
96101
#else
97102
webnnProcSetProcs(&procs);
98-
auto instanceReservation = wireClient->ReserveInstance();
99-
wireServer->InjectInstance(nativeInstance->Get(), instanceReservation.id,
100-
instanceReservation.generation);
103+
webnn::wire::ReservedInstance instanceReservation =
104+
wireHelper.wireClient->ReserveInstance();
105+
wireHelper.wireServer->InjectInstance(nativeInstance->Get(), instanceReservation.id,
106+
instanceReservation.generation);
101107
// Keep the reference instread of using Acquire.
102108
// TODO:: make the instance in the client as singleton object.
103109
clientInstance = wnn::Instance(instanceReservation.instance);
@@ -115,10 +121,11 @@ wnn::Context CreateCppContext(wnn::ContextOptions const* options) {
115121

116122
void DoFlush() {
117123
if (cmdBufType == CmdBufType::Terrible) {
118-
bool c2sSuccess = c2sBuf->Flush();
119-
bool s2cSuccess = s2cBuf->Flush();
124+
bool c2sSuccess = wireHelper.c2sBuf->Flush();
125+
bool s2cSuccess = wireHelper.s2cBuf->Flush();
120126

121-
ASSERT(c2sSuccess && s2cSuccess);
127+
DAWN_ASSERT(c2sSuccess);
128+
DAWN_ASSERT(s2cSuccess);
122129
}
123130
}
124131

@@ -365,25 +372,30 @@ namespace utils {
365372
}
366373

367374
bool LoadAndPreprocessImage(const ExampleBase* example, std::vector<float>& processedPixels) {
375+
DAWN_ASSERT(example != nullptr);
368376
// Read an image.
369377
int imageWidth, imageHeight, imageChannels = 0;
370-
uint8_t* inputPixels =
371-
stbi_load(example->mImagePath.c_str(), &imageWidth, &imageHeight, &imageChannels, 0);
372-
if (inputPixels == 0) {
378+
std::unique_ptr<uint8_t> inputPixels(
379+
stbi_load(example->mImagePath.c_str(), &imageWidth, &imageHeight, &imageChannels, 0));
380+
if (inputPixels == nullptr) {
373381
dawn::ErrorLog() << "Failed to load and preprocess the image at "
374382
<< example->mImagePath;
375383
return false;
376384
}
377385
// Resize the image with model's input size
378386
const size_t imageSize = imageHeight * imageWidth * imageChannels;
379-
float* floatPixels = (float*)malloc(imageSize * sizeof(float));
387+
std::vector<float> floatPixels(imageSize);
380388
for (size_t i = 0; i < imageSize; ++i) {
381-
floatPixels[i] = inputPixels[i];
389+
floatPixels[i] = inputPixels.get()[i];
382390
}
383-
float* resizedPixels = (float*)malloc(example->mModelHeight * example->mModelWidth *
384-
example->mModelChannels * sizeof(float));
385-
stbir_resize_float(floatPixels, imageWidth, imageHeight, 0, resizedPixels,
386-
example->mModelWidth, example->mModelHeight, 0, example->mModelChannels);
391+
std::vector<float> resizedPixels(example->mModelHeight * example->mModelWidth *
392+
example->mModelChannels);
393+
if (stbir_resize_float(floatPixels.data(), imageWidth, imageHeight, 0, resizedPixels.data(),
394+
example->mModelWidth, example->mModelHeight, 0,
395+
example->mModelChannels) == 0) {
396+
dawn::ErrorLog() << "Failed to resize the image.";
397+
return false;
398+
};
387399

388400
// Reoder the image to NCHW/NHWC layout.
389401
for (size_t c = 0; c < example->mModelChannels; ++c) {
@@ -404,8 +416,6 @@ namespace utils {
404416
}
405417
}
406418
}
407-
free(resizedPixels);
408-
free(floatPixels);
409419
return true;
410420
}
411421

0 commit comments

Comments
 (0)