Skip to content

Commit 5803c8d

Browse files
authored
tests: allow exporting graph ops from HF file without downloading weights (ggml-org#21182)
* tests: allow exporting graph ops from HF file without downloading weights
* use unique_ptr for llama_context in HF metadata case
* fix missing non-required tensors falling back to type f32
* use unique pointers where possible
* use no_alloc instead of fixing f32 fallback
* fix missing space
1 parent 63f8fe0 commit 5803c8d

7 files changed

Lines changed: 169 additions & 12 deletions

File tree

common/arg.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -537,9 +537,11 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
537537
} catch (const std::exception & e) {
538538
LOG_WRN("HF cache migration failed: %s\n", e.what());
539539
}
540+
// export_graph_ops loads only metadata
541+
const bool skip_model_download = ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS;
540542

541543
// maybe handle remote preset
542-
if (!params.model.hf_repo.empty()) {
544+
if (!params.model.hf_repo.empty() && !skip_model_download) {
543545
std::string cli_hf_repo = params.model.hf_repo;
544546
bool has_preset = common_params_handle_remote_preset(params, ctx_arg.ex);
545547

@@ -570,7 +572,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
570572
}
571573

572574
// handle model and download
573-
{
575+
if (!skip_model_download) {
574576
auto res = common_params_handle_model(params.model, params.hf_token, params.offline);
575577
if (params.no_mmproj) {
576578
params.mmproj = {};
@@ -591,7 +593,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
591593

592594
// model is required (except for server)
593595
// TODO @ngxson : maybe show a list of available models in CLI in this case
594-
if (params.model.path.empty() && ctx_arg.ex != LLAMA_EXAMPLE_SERVER && !params.usage && !params.completion) {
596+
if (params.model.path.empty() && ctx_arg.ex != LLAMA_EXAMPLE_SERVER && !skip_model_download && !params.usage && !params.completion) {
595597
throw std::invalid_argument("error: --model is required\n");
596598
}
597599

common/common.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1442,6 +1442,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
14421442

14431443
mparams.progress_callback = params.load_progress_callback;
14441444
mparams.progress_callback_user_data = params.load_progress_callback_user_data;
1445+
mparams.no_alloc = params.no_alloc;
14451446

14461447
return mparams;
14471448
}

common/common.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -679,6 +679,7 @@ struct common_params {
679679
// return false from callback to abort model loading or true to continue
680680
llama_progress_callback load_progress_callback = NULL;
681681
void * load_progress_callback_user_data = NULL;
682+
bool no_alloc = false; // Don't allocate model buffers
682683
};
683684

684685
// call once at the start of a program if it uses libcommon

tests/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,3 +287,7 @@ target_include_directories(test-alloc PRIVATE ${PROJECT_SOURCE_DIR}/ggml/src)
287287

288288
llama_build(export-graph-ops.cpp)
289289
target_include_directories(export-graph-ops PRIVATE ${PROJECT_SOURCE_DIR}/ggml/src)
290+
if (TARGET gguf-model-data)
291+
target_link_libraries(export-graph-ops PRIVATE gguf-model-data)
292+
target_compile_definitions(export-graph-ops PRIVATE LLAMA_HF_FETCH)
293+
endif()

tests/export-graph-ops.cpp

Lines changed: 63 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,26 @@
11
#include "arg.h"
22
#include "common.h"
33
#include "log.h"
4-
#include "llama.h"
4+
#include "llama-cpp.h"
55
#include "../src/llama-ext.h"
66
#include "ggml.h"
7+
#include "gguf-model-data.h"
8+
#include "gguf.h"
9+
#include "ggml-backend.h"
10+
#include "download.h"
711

812
#include <array>
913
#include <vector>
1014
#include <set>
1115
#include <fstream>
1216
#include <iostream>
17+
#include <random>
18+
19+
// Tensor-data callback handed to llama_model_init_from_user.
// Intentionally empty: exporting graph ops does not need the weights,
// so no tensor contents are ever written.
static void set_tensor_data(struct ggml_tensor * tensor, void * userdata) {
    (void) tensor;
    (void) userdata;
}
1324

1425
struct input_tensor {
1526
ggml_type type;
@@ -132,9 +143,52 @@ int main(int argc, char ** argv) {
132143

133144
params.warmup = false;
134145

135-
auto init_result = common_init_from_params(params);
146+
llama_context * ctx;
147+
common_init_result_ptr init_result;
148+
llama_context_ptr ctx2;
149+
llama_model_ptr model;
150+
151+
if (params.model.hf_repo.empty()) {
152+
init_result = common_init_from_params(params);
153+
154+
ctx = init_result->context();
155+
} else {
156+
#ifdef LLAMA_HF_FETCH
157+
auto [hf_repo, hf_quant] = common_download_split_repo_tag(params.model.hf_repo);
158+
if (hf_quant.empty() || hf_quant == "latest") {
159+
hf_quant = "Q4_K_M";
160+
}
161+
162+
gguf_context_ptr gguf_ctx = gguf_fetch_gguf_ctx(hf_repo, hf_quant);
163+
if (!gguf_ctx) {
164+
LOG_ERR("failed to fetch GGUF metadata from %s\n", hf_repo.c_str());
165+
return 1;
166+
}
167+
168+
llama_model_params model_params = llama_model_default_params();
169+
model_params.devices = params.devices.data();
170+
model_params.no_alloc = true;
171+
172+
model.reset(llama_model_init_from_user(gguf_ctx.get(), set_tensor_data, nullptr, model_params));
136173

137-
llama_context * ctx = init_result->context();
174+
if (!model) {
175+
LOG_ERR("failed to create llama_model from %s\n", hf_repo.c_str());
176+
return 1;
177+
}
178+
179+
llama_context_params ctx_params = llama_context_default_params();
180+
ctx2.reset(llama_init_from_model(model.get(), ctx_params));
181+
ctx = ctx2.get();
182+
183+
if (!ctx) {
184+
LOG_ERR("failed to create llama_context\n");
185+
return 1;
186+
}
187+
#else
188+
LOG_ERR("export-graph-ops compiled without HF fetch support\n");
189+
return 1;
190+
#endif
191+
}
138192

139193
const uint32_t n_seqs = llama_n_seq_max(ctx);
140194
const uint32_t n_tokens = std::min(llama_n_ctx(ctx), llama_n_ubatch(ctx));
@@ -143,13 +197,15 @@ int main(int argc, char ** argv) {
143197

144198
auto * gf_pp = llama_graph_reserve(ctx, n_tokens, n_seqs, n_tokens);
145199
if (!gf_pp) {
146-
throw std::runtime_error("failed to reserve prompt processing graph");
200+
LOG_ERR("failed to reserve prompt processing graph\n");
201+
return 1;
147202
}
148203
extract_graph_ops(gf_pp, "pp", tests);
149204

150205
auto * gf_tg = llama_graph_reserve(ctx, n_seqs, n_seqs, n_seqs);
151206
if (!gf_tg) {
152-
throw std::runtime_error("failed to reserve token generation graph");
207+
LOG_ERR("failed to reserve token generation graph\n");
208+
return 1;
153209
}
154210
extract_graph_ops(gf_tg, "tg", tests);
155211

@@ -158,7 +214,8 @@ int main(int argc, char ** argv) {
158214
std::ofstream f(params.out_file);
159215

160216
if (!f.is_open()) {
161-
throw std::runtime_error("Unable to open output file");
217+
LOG_ERR("unable to open output file: %s\n", params.out_file.c_str());
218+
return 1;
162219
}
163220

164221
for (const auto& test : tests) {

tests/gguf-model-data.cpp

Lines changed: 88 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "gguf-model-data.h"
55

66
#include "common.h"
7+
#include "ggml-cpp.h"
78
#include "gguf.h"
89

910
#include <algorithm>
@@ -531,14 +532,18 @@ static std::optional<gguf_remote_model> fetch_and_parse(
531532
return std::nullopt;
532533
}
533534

535+
static std::string get_cache_file_path(const std::string& cdir, const std::string& repo_part, const std::string& filename) {
536+
std::string fname_part = sanitize_for_path(filename);
537+
return cdir + "/" + repo_part + "--" + fname_part + ".partial";
538+
}
539+
534540
// Try cache first, then fetch and parse a single GGUF shard.
535541
static std::optional<gguf_remote_model> fetch_or_cached(
536542
const std::string & repo,
537543
const std::string & filename,
538544
const std::string & cdir,
539545
const std::string & repo_part) {
540-
std::string fname_part = sanitize_for_path(filename);
541-
std::string cache_path = cdir + "/" + repo_part + "--" + fname_part + ".partial";
546+
std::string cache_path = get_cache_file_path(cdir, repo_part, filename);
542547

543548
{
544549
std::vector<char> cached;
@@ -611,3 +616,84 @@ std::optional<gguf_remote_model> gguf_fetch_model_meta(
611616

612617
return model_opt;
613618
}
619+
620+
// Opens the cached GGUF file at `path` with no_alloc = true and returns its gguf context
// (null on failure). The ggml_context that gguf_init_from_file fills in through
// gguf_init_params::ctx is handed to `tensors`, which then owns it.
static gguf_context_ptr open_cached_gguf_meta(const std::string & path, ggml_context_ptr & tensors) {
    ggml_context * raw_tensors = nullptr;
    gguf_init_params init_params{true, &raw_tensors};
    gguf_context_ptr meta{gguf_init_from_file(path.c_str(), init_params)};
    tensors.reset(raw_tensors);
    return meta;
}

// Fetches (or reuses from the cache) the GGUF file selected by `repo` and `quant` and
// returns a gguf_context built from the cached file.
//
// When the model reports more than one split, every remaining shard is fetched as well and
// its tensors are appended to the first shard's context, so the caller receives a single
// context covering all shards.
//
// repo      - repository identifier passed to detect_gguf_filename / fetch_or_cached
// quant     - quantization tag used to pick the file inside the repository
// cache_dir - cache directory; an empty string selects get_default_cache_dir()
//
// Returns a null pointer on any failure; a message is printed to stderr in that case.
gguf_context_ptr gguf_fetch_gguf_ctx(
        const std::string & repo,
        const std::string & quant,
        const std::string & cache_dir) {
    const std::string cdir      = cache_dir.empty() ? get_default_cache_dir() : cache_dir;
    const std::string repo_part = sanitize_for_path(repo);

    std::string split_prefix;
    const std::string filename = detect_gguf_filename(repo, quant, split_prefix);

    if (filename.empty()) {
        // every other failure path reports what went wrong; do the same here
        fprintf(stderr, "gguf_fetch: failed to detect GGUF filename for %s (quant %s)\n",
                repo.c_str(), quant.c_str());
        return nullptr;
    }

    auto model_opt = fetch_or_cached(repo, filename, cdir, repo_part);
    if (!model_opt.has_value()) {
        fprintf(stderr, "gguf_fetch: failed to fetch %s\n", filename.c_str());
        return nullptr;
    }

    const auto & model = model_opt.value();

    const std::string cache_path = get_cache_file_path(cdir, repo_part, filename);

    // Declared before `ctx` so it is destroyed after it, matching the shard handling below.
    ggml_context_ptr ggml_ctx_ptr;
    gguf_context_ptr ctx = open_cached_gguf_meta(cache_path, ggml_ctx_ptr);

    if (ctx == nullptr) {
        fprintf(stderr, "gguf_fetch: gguf_init_from_file failed\n");
        return nullptr;
    }

    // If the model is split across multiple files we need to fetch the remaining shards metadata
    if (model.n_split > 1) {
        if (split_prefix.empty()) {
            fprintf(stderr, "gguf_fetch: model reports %u splits but filename has no split pattern\n", model.n_split);
            return nullptr;
        }

        fprintf(stderr, "gguf_fetch: split model with %u shards, fetching remaining %u...\n",
                model.n_split, model.n_split - 1);

        // The total is the same for every shard name, so format it once.
        const int n_split = static_cast<int>(model.n_split);
        char total_buf[6];
        snprintf(total_buf, sizeof(total_buf), "%05d", n_split);

        // Shard 1 is the file loaded above; shard names are 1-based.
        for (int i = 2; i <= n_split; i++) {
            char num_buf[6];
            snprintf(num_buf, sizeof(num_buf), "%05d", i);
            const std::string shard_name = split_prefix + "-" + num_buf + "-of-" + total_buf + ".gguf";

            auto shard = fetch_or_cached(repo, shard_name, cdir, repo_part);
            if (!shard.has_value()) {
                fprintf(stderr, "gguf_fetch: failed to fetch shard %d: %s\n", i, shard_name.c_str());
                return nullptr;
            }

            // Load tensors from shard and add to main gguf_context
            const std::string shard_path = get_cache_file_path(cdir, repo_part, shard_name);
            ggml_context_ptr shard_ggml_ctx_ptr;
            gguf_context_ptr shard_ctx = open_cached_gguf_meta(shard_path, shard_ggml_ctx_ptr);

            if (shard_ctx == nullptr) {
                fprintf(stderr, "gguf_fetch: shard gguf_init_from_file failed\n");
                return nullptr;
            }

            // NOTE(review): gguf_add_tensor may reject a tensor name that is already present -
            // confirm that shards never repeat a name.
            ggml_context * shard_tensors = shard_ggml_ctx_ptr.get();
            for (ggml_tensor * t = ggml_get_first_tensor(shard_tensors); t; t = ggml_get_next_tensor(shard_tensors, t)) {
                gguf_add_tensor(ctx.get(), t);
            }
        }

        // All shard tensors now live in `ctx`; presumably this keeps consumers from looking
        // for further shard files - TODO confirm against the loader.
        gguf_set_val_u16(ctx.get(), "split.count", 1);
    }

    return ctx;
}

tests/gguf-model-data.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

3-
#include "ggml.h"
3+
#include "ggml-cpp.h"
4+
#include "gguf.h"
45

56
#include <cstdint>
67
#include <optional>
@@ -40,3 +41,8 @@ std::optional<gguf_remote_model> gguf_fetch_model_meta(
4041
const std::string & repo,
4142
const std::string & quant = "Q8_0",
4243
const std::string & cache_dir = ""); // empty = default
44+
45+
gguf_context_ptr gguf_fetch_gguf_ctx(
46+
const std::string & repo,
47+
const std::string & quant = "Q8_0",
48+
const std::string & cache_dir = "");

0 commit comments

Comments
 (0)