Skip to content

Commit 90aa83c

Browse files
authored
common: add bounds check in common_init_result::sampler to prevent segfault on failed model load (#21082)
* common: add bounds check in common_init_result::sampler to prevent segfault on failed model load * Revert a308e58 * Add regression test * Remove regression test for init-fail sampler check
1 parent fcc2d59 commit 90aa83c

2 files changed

Lines changed: 4 additions & 7 deletions

File tree

common/common.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1243,6 +1243,9 @@ llama_context * common_init_result::context() {
12431243
}
12441244

12451245
common_sampler * common_init_result::sampler(llama_seq_id seq_id) {
1246+
if (seq_id < 0 || seq_id >= (int) pimpl->samplers.size()) {
1247+
return nullptr;
1248+
}
12461249
return pimpl->samplers[seq_id].get();
12471250
}
12481251

tools/completion/completion.cpp

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -146,19 +146,13 @@ int main(int argc, char ** argv) {
146146

147147
ctx = llama_init->context();
148148
model = llama_init->model();
149+
smpl = llama_init->sampler(0);
149150

150151
if (ctx == NULL) {
151152
LOG_ERR("%s: error: unable to create context\n", __func__);
152153
return 1;
153154
}
154155

155-
if (model == NULL) {
156-
LOG_ERR("%s: error: unable to load model\n", __func__);
157-
return 1;
158-
}
159-
160-
smpl = llama_init->sampler(0);
161-
162156
llama_memory_t mem = llama_get_memory(ctx);
163157
const llama_vocab * vocab = llama_model_get_vocab(model);
164158

0 commit comments

Comments
 (0)