Commit fac0a6d
fix: add mutex to VoiceActivityDetection to prevent race between generate() and unload()
VoiceActivityDetection inherits from BaseModel but lacks the thread-safety that VisionModel already provides. When generate() runs on a worker thread (via ModelHostObject::promiseHostFunction) and unload() is called from the JS thread, BaseModel::unload() destroys module_ mid-inference, causing SIGILL/SIGSEGV crashes.

This applies the same pattern used by VisionModel:

- Add inference_mutex_ member
- Lock in generate() to protect forward() calls
- Override unload() to acquire the lock before BaseModel::unload()

Fixes #1055
1 parent a617188 commit fac0a6d

2 files changed

Lines changed: 19 additions & 0 deletions

packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.cpp

Lines changed: 6 additions & 0 deletions
@@ -54,8 +54,14 @@ VoiceActivityDetection::preprocess(std::span<float> waveform) const {
   return frameBuffer;
 }
 
+void VoiceActivityDetection::unload() noexcept {
+  std::scoped_lock lock(inference_mutex_);
+  BaseModel::unload();
+}
+
 std::vector<types::Segment>
 VoiceActivityDetection::generate(std::span<float> waveform) const {
+  std::scoped_lock lock(inference_mutex_);
 
   auto windowedInput = preprocess(waveform);
   auto [chunksNumber, remainder] = std::div(

packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.h

Lines changed: 13 additions & 0 deletions
@@ -5,6 +5,7 @@
 #include <executorch/extension/tensor/tensor.h>
 #include <executorch/extension/tensor/tensor_ptr.h>
 #include <executorch/runtime/core/evalue.h>
+#include <mutex>
 #include <span>
 
 #include "rnexecutorch/metaprogramming/ConstructorHelpers.h"
@@ -23,7 +24,19 @@ class VoiceActivityDetection : public BaseModel {
   [[nodiscard("Registered non-void function")]] std::vector<types::Segment>
   generate(std::span<float> waveform) const;
 
+  /**
+   * @brief Thread-safe unload that waits for any in-flight inference to
+   * complete.
+   *
+   * Mirrors VisionModel::unload(). Without this, BaseModel::unload() can
+   * destroy module_ while generate() is still calling forward() on a worker
+   * thread, causing SIGILL / SIGSEGV crashes.
+   */
+  void unload() noexcept;
+
 private:
+  mutable std::mutex inference_mutex_;
+
   std::vector<std::array<float, constants::kPaddedWindowSize>>
   preprocess(std::span<float> waveform) const;
   std::vector<types::Segment> postprocess(const std::vector<float> &scores,
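The locking pattern in this commit can be shown in isolation with a minimal sketch. Note this is an illustrative stand-in, not code from the repo: `ToyModel`, `isLoaded()`, and the doubling loop (a placeholder for `forward()`) are hypothetical. The key points it demonstrates are the `mutable` mutex (so the `const` `generate()` can lock it) and `unload()` taking the same lock, so it blocks until any in-flight inference finishes instead of destroying state underneath it.

```cpp
#include <cassert>
#include <mutex>
#include <vector>

// Hypothetical minimal model illustrating the generate()/unload() pattern.
class ToyModel {
public:
  std::vector<float> generate(const std::vector<float> &input) const {
    // Held for the entire inference; unload() cannot proceed meanwhile.
    std::scoped_lock lock(inference_mutex_);
    if (!loaded_) {
      return {}; // model already unloaded; nothing to run
    }
    std::vector<float> out;
    out.reserve(input.size());
    for (float v : input) {
      out.push_back(v * 2.0f); // stand-in for the real forward() call
    }
    return out;
  }

  void unload() noexcept {
    // Blocks until any in-flight generate() releases the lock.
    std::scoped_lock lock(inference_mutex_);
    loaded_ = false; // stand-in for BaseModel::unload() destroying module_
  }

  bool isLoaded() const { return loaded_; }

private:
  // mutable: generate() is const but still needs to acquire the lock.
  mutable std::mutex inference_mutex_;
  bool loaded_ = true;
};
```

With this arrangement, a `unload()` call racing a worker-thread `generate()` simply serializes: whichever thread wins the lock runs to completion, and a `generate()` that arrives after unloading returns an empty result rather than touching freed state.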
