Skip to content

Commit bf76ac7

Browse files
authored
common : only load backends when required (ggml-org#22290)
* common : only load backends when required
  Signed-off-by: Adrien Gallouët <angt@huggingface.co>
* llama : call ggml_backend_load_all() directly from llama_backend_init()
  Signed-off-by: Adrien Gallouët <angt@huggingface.co>
* Add ggml_backend_load_all() where llama_backend_init() is not used
  Signed-off-by: Adrien Gallouët <angt@huggingface.co>
---------
Signed-off-by: Adrien Gallouët <angt@huggingface.co>
1 parent a09a00e commit bf76ac7

6 files changed

Lines changed: 19 additions & 3 deletions

File tree

common/arg.cpp

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -248,6 +248,8 @@ std::vector<std::string> common_arg::get_env() const {
248248

249249
// Helper function to parse tensor buffer override strings
250250
static void parse_tensor_buffer_overrides(const std::string & value, std::vector<llama_model_tensor_buft_override> & overrides) {
251+
ggml_backend_load_all();
252+
251253
std::map<std::string, ggml_backend_buffer_type_t> buft_list;
252254
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
253255
auto * dev = ggml_backend_dev_get(i);
@@ -803,6 +805,7 @@ static std::vector<ggml_backend_dev_t> parse_device_list(const std::string & val
803805
if (dev_names.size() == 1 && dev_names[0] == "none") {
804806
devices.push_back(nullptr);
805807
} else {
808+
ggml_backend_load_all();
806809
for (const auto & device : dev_names) {
807810
auto * dev = ggml_backend_dev_by_name(device.c_str());
808811
if (!dev || ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU) {
@@ -820,6 +823,7 @@ static void add_rpc_devices(const std::string & servers) {
820823
if (rpc_servers.empty()) {
821824
throw std::invalid_argument("no RPC servers specified");
822825
}
826+
ggml_backend_load_all();
823827
ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC");
824828
if (!rpc_reg) {
825829
throw std::invalid_argument("failed to find RPC backend");
@@ -1016,9 +1020,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
10161020

10171021
params.use_color = tty_can_use_colors();
10181022

1019-
// load dynamic backends
1020-
ggml_backend_load_all();
1021-
10221023
common_params_context ctx_arg(params);
10231024
ctx_arg.print_usage = print_usage;
10241025
ctx_arg.ex = ex;
@@ -2275,6 +2276,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
22752276
{"--list-devices"},
22762277
"print list of available devices and exit",
22772278
[](common_params &) {
2279+
ggml_backend_load_all();
22782280
std::vector<ggml_backend_dev_t> devices;
22792281
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
22802282
auto * dev = ggml_backend_dev_get(i);

examples/save-load-state/save-load-state.cpp

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,9 @@ int main(int argc, char ** argv) {
4141
std::string result3;
4242

4343
// init
44+
45+
ggml_backend_load_all();
46+
4447
auto llama_init = common_init_from_params(params);
4548

4649
auto * model = llama_init->model();

src/llama.cpp

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -89,6 +89,10 @@ void llama_backend_init(void) {
8989
struct ggml_context * ctx = ggml_init(params);
9090
ggml_free(ctx);
9191
}
92+
93+
if (!ggml_backend_reg_count()) {
94+
ggml_backend_load_all();
95+
}
9296
}
9397

9498
void llama_numa_init(enum ggml_numa_strategy numa) {

tests/test-state-restore-fragmented.cpp

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,9 @@ int main(int argc, char ** argv) {
2929
}
3030

3131
// init
32+
33+
ggml_backend_load_all();
34+
3235
common_init_result_ptr llama_init = common_init_from_params(params);
3336

3437
llama_model * model = llama_init->model();

tools/mtmd/debug/mtmd-debug.cpp

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -68,6 +68,8 @@ int main(int argc, char ** argv) {
6868
return 1;
6969
}
7070

71+
ggml_backend_load_all();
72+
7173
LOG_INF("%s: loading model: %s\n", __func__, params.model.path.c_str());
7274

7375
mtmd::context_ptr ctx_mtmd;

tools/mtmd/mtmd-cli.cpp

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -295,6 +295,8 @@ int main(int argc, char ** argv) {
295295
return 1;
296296
}
297297

298+
ggml_backend_load_all();
299+
298300
mtmd_cli_context ctx(params);
299301
LOG_INF("%s: loading model: %s\n", __func__, params.model.path.c_str());
300302

0 commit comments

Comments
 (0)