Skip to content

Commit d7c6264

Browse files
authored
Fix data race on should_stop_ flag in LLM runner (#18652)
should_stop_ is written from the caller thread via stop() and read from the inference thread in the generate loop. Accessing a plain bool from two threads without synchronization is a data race, which is undefined behavior per the C++ standard; in practice the compiler is free to cache the flag in a register and hoist the load out of the loop, so the stop request may never be observed by the reading thread. Change bool to std::atomic&lt;bool&gt; with relaxed memory ordering, which is sufficient for a simple cancellation flag (no other data is published through it) and has negligible overhead.
1 parent 56d6e4d commit d7c6264

4 files changed

Lines changed: 6 additions & 12 deletions

File tree

extension/llm/runner/text_decoder_runner.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,6 @@ class ET_EXPERIMENTAL TextDecoderRunner {
6969
return method_name_;
7070
}
7171

72-
inline void stop() {
73-
should_stop_ = true;
74-
}
75-
7672
/**
7773
* Sample the next token from the logits tensor.
7874
* @param logits_tensor The logits tensor.
@@ -98,7 +94,6 @@ class ET_EXPERIMENTAL TextDecoderRunner {
9894
Module* module_;
9995
IOManager* io_manager_;
10096
std::string method_name_;
101-
bool should_stop_{false};
10297
};
10398

10499
} // namespace llm

extension/llm/runner/text_llm_runner.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,6 @@ Error TextLLMRunner::generate(
108108
// return a response token.
109109

110110
stats_->inference_start_ms = time_in_ms();
111-
shouldStop_ = false;
112111

113112
int64_t max_context_len = metadata_.at(kMaxContextLen);
114113

extension/llm/runner/text_llm_runner.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,6 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner {
161161
void stop() override;
162162

163163
private:
164-
bool shouldStop_{false};
165-
166164
// Components
167165
std::unique_ptr<::tokenizers::Tokenizer> tokenizer_;
168166
std::unordered_map<std::string, int64_t> metadata_;

extension/llm/runner/text_token_generator.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
// Generate tokens in a loop.
1010
#pragma once
1111

12+
#include <atomic>
13+
1214
#include <executorch/extension/llm/runner/stats.h>
1315
#include <executorch/extension/llm/runner/text_decoder_runner.h>
1416
#include <executorch/extension/tensor/tensor.h>
@@ -95,7 +97,7 @@ class ET_EXPERIMENTAL TextTokenGenerator {
9597
resize_tensor_ptr(tokens_managed, token_shape));
9698
}
9799

98-
should_stop_ = false;
100+
should_stop_.store(false, std::memory_order_relaxed);
99101

100102
// Generate our tokens
101103
while (pos < start_pos + max_new_tokens) {
@@ -136,7 +138,7 @@ class ET_EXPERIMENTAL TextTokenGenerator {
136138
}
137139
token_callback(std::move(*decode_result));
138140

139-
if (should_stop_) {
141+
if (should_stop_.load(std::memory_order_relaxed)) {
140142
break;
141143
}
142144

@@ -154,7 +156,7 @@ class ET_EXPERIMENTAL TextTokenGenerator {
154156
* Stop the generation loop.
155157
*/
156158
inline void stop() {
157-
should_stop_ = true;
159+
should_stop_.store(true, std::memory_order_relaxed);
158160
}
159161

160162
/**
@@ -188,7 +190,7 @@ class ET_EXPERIMENTAL TextTokenGenerator {
188190
bool ignore_eos_ = false;
189191

190192
// state machine
191-
bool should_stop_ = false;
193+
std::atomic<bool> should_stop_{false};
192194

193195
// stats
194196
Stats* stats_;

0 commit comments

Comments (0)