Skip to content

Commit 6dae13f

Browse files
authored
fix: max hit (#1193)
max_hit should only be restricted to PipelineTask --- https://github.com/MaaEnd/MaaEnd/actions/runs/22839301359?pr=1112 ## Summary by Sourcery 将 `max_hit` 的跟踪和限制仅应用于 `PipelineTask` 的识别流程,而不是共享的 `TaskBase` 实现。 Bug Fixes: - 确保只有在 `PipelineTask` 中达到 `max_hit` 时才跳过识别尝试,从而避免对其他任务类型施加意外的限制。 Enhancements: - 当在 `PipelineTask` 中达到 `max_hit` 以及发生识别命中时,记录 debug 和 info 日志消息,以提高可观测性。 <details> <summary>Original summary in English</summary> ## Summary by Sourcery Limit max_hit tracking and enforcement to PipelineTask recognition flow instead of the shared TaskBase implementation. Bug Fixes: - Ensure recognition attempts are skipped when max_hit is reached only within PipelineTask, preventing unintended limits on other task types. Enhancements: - Log debug and info messages when max_hit is reached and when recognition hits occur within PipelineTask for better observability. </details>
1 parent b63eeff commit 6dae13f

4 files changed

Lines changed: 26 additions & 18 deletions

File tree

source/MaaFramework/Task/Context.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,16 @@ bool& Context::need_to_stop()
393393
return *need_to_stop_;
394394
}
395395

396+
bool Context::check_hit_count(const PipelineData& data)
397+
{
398+
size_t current_hit = get_hit_count(data.name);
399+
if (current_hit >= static_cast<size_t>(data.max_hit)) {
400+
LogDebug << "max_hit reached" << VAR(data.name) << VAR(current_hit) << VAR(data.max_hit);
401+
return false;
402+
}
403+
return true;
404+
}
405+
396406
void Context::increment_hit_count(const std::string& node_name)
397407
{
398408
task_state_->hit_count[node_name]++;

source/MaaFramework/Task/Context.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ class Context
6969
std::vector<cv::Mat> get_images(const std::vector<std::string>& names);
7070

7171
bool& need_to_stop();
72+
bool check_hit_count(const PipelineData& data);
7273
void increment_hit_count(const std::string& node_name);
7374

7475
private:

source/MaaFramework/Task/PipelineTask.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,8 +282,22 @@ RecoResult PipelineTask::recognize_list(const cv::Mat& image, const std::vector<
282282
recognizer.prefetch_batch_ocr(batch_plan->entries);
283283
}
284284

285+
if (!pipeline_data.enabled) {
286+
LogDebug << "node disabled" << pipeline_data.name << VAR(pipeline_data.enabled);
287+
return { };
288+
}
289+
290+
if (!context_->check_hit_count(pipeline_data)) {
291+
continue;
292+
}
293+
285294
RecoResult result = run_recognition(image, pipeline_data, ocr_cache);
286295

296+
if (result.box) {
297+
LogInfo << "reco hit" << VAR(result.name) << VAR(result.box);
298+
context_->increment_hit_count(pipeline_data.name);
299+
}
300+
287301
if (context_->need_to_stop()) {
288302
LogWarn << "need_to_stop";
289303
break;
@@ -323,8 +337,7 @@ std::optional<PipelineTask::BatchOCRPlan> PipelineTask::prepare_batch_ocr(const
323337
continue;
324338
}
325339

326-
size_t current_hit = context_->get_hit_count(data.name);
327-
if (current_hit >= static_cast<size_t>(data.max_hit)) {
340+
if (!context_->check_hit_count(data)) {
328341
continue;
329342
}
330343

source/MaaFramework/Task/TaskBase.cpp

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -61,17 +61,6 @@ RecoResult TaskBase::run_recognition(const cv::Mat& image, const PipelineData& d
6161
return { };
6262
}
6363

64-
if (!data.enabled) {
65-
LogDebug << "node disabled" << data.name << VAR(data.enabled);
66-
return { };
67-
}
68-
69-
size_t current_hit = context_->get_hit_count(data.name);
70-
if (current_hit >= static_cast<size_t>(data.max_hit)) {
71-
LogDebug << "max_hit reached" << VAR(data.name) << VAR(current_hit) << VAR(data.max_hit);
72-
return { };
73-
}
74-
7564
Recognizer recognizer(tasker_, *context_, image, std::move(ocr_cache));
7665

7766
json::value cb_detail {
@@ -93,11 +82,6 @@ RecoResult TaskBase::run_recognition(const cv::Mat& image, const PipelineData& d
9382
cb_detail["reco_details"] = result;
9483
notify(result.box ? MaaMsg_Node_Recognition_Succeeded : MaaMsg_Node_Recognition_Failed, cb_detail);
9584

96-
if (result.box) {
97-
LogInfo << "reco hit" << VAR(result.name) << VAR(result.box);
98-
context_->increment_hit_count(data.name);
99-
}
100-
10185
return result;
10286
}
10387

0 commit comments

Comments
 (0)