Skip to content

Commit 994118a

Browse files
authored
model: move load_hparams and load_tensors to per-model definition (ggml-org#22004)
* git-friendly migration
* add build_graph
* nits
* exclude old code from build
* wip
* add llm_arch_model_i
* prepare downstream functions
* nits
* nits
* wip
* wip
* add back create_tensor_qkv
* fix files missing include
* enforce one llm_build per arch
* cmake: use glob
* missing model params
* nits
* wip
* wip (2)
* wip (3)
* test-llama-archs is happy
* improve switch case
* move more stuff into llm_arch_model_i
* fix downstream code
* nits
* nits (2)
* fix order
* llama_model_base
* LLAMA_LOAD_LOCALS
* small fix
* fix build errors
* auto
* rm migration script and ifdef
1 parent c84e6d6 commit 994118a

129 files changed

Lines changed: 11431 additions & 8881 deletions

File tree

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

src/llama-model.cpp

Lines changed: 1115 additions & 8087 deletions
Large diffs are not rendered by default.

src/llama-model.h

Lines changed: 78 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -577,14 +577,8 @@ struct llama_model {
577577
int64_t t_load_us = 0;
578578
int64_t t_start_us = 0;
579579

580-
explicit llama_model(const struct llama_model_params & params);
581-
~llama_model();
582-
583-
void load_stats (llama_model_loader & ml);
584-
void load_arch (llama_model_loader & ml);
585-
void load_hparams(llama_model_loader & ml);
586-
void load_vocab (llama_model_loader & ml);
587-
bool load_tensors(llama_model_loader & ml); // returns false if cancelled by progress_callback
580+
explicit llama_model(const llama_model_params & params);
581+
virtual ~llama_model();
588582

589583
std::string arch_name() const;
590584
std::string type_name() const;
@@ -620,21 +614,94 @@ struct llama_model {
620614

621615
ggml_tensor * get_rope_factors(const llama_cparams & cparams, int il) const;
622616

623-
// TODO: move this to new llm_arch_model_i interface
624617
llama_memory_i * create_memory(const llama_memory_params & params, const llama_cparams & cparams) const;
625618

626-
// TODO: move this to new llm_arch_model_i interface
627619
ggml_cgraph * build_graph(const llm_graph_params & params) const;
628620

629-
private:
621+
virtual void load_stats (llama_model_loader & ml) = 0;
622+
virtual void load_hparams(llama_model_loader & ml) = 0;
623+
virtual void load_vocab (llama_model_loader & ml) = 0;
624+
virtual bool load_tensors(llama_model_loader & ml) = 0; // returns false if cancelled by progress_callback
625+
626+
// model must define these
627+
virtual void load_arch_hparams(llama_model_loader & ml) = 0;
628+
virtual void load_arch_tensors(llama_model_loader & ml) = 0;
629+
virtual std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const = 0;
630+
631+
protected:
630632
llama_model_params params;
631633

632634
struct impl;
633635
std::unique_ptr<impl> pimpl;
634636
};
635637

638+
llama_model * llama_model_create(llm_arch arch, const llama_model_params & params);
639+
llama_model * llama_model_create(llama_model_loader & ml, const llama_model_params & params);
640+
641+
// model must inherit from this
642+
struct llama_model_base : public llama_model {
643+
friend struct llama_model;
644+
645+
llama_model * model;
646+
llama_model_loader * ml = nullptr;
647+
const LLM_TN tn;
648+
649+
// llama_model_loader is not yet defined at this point, so we will set it after construction
650+
const int TENSOR_DUPLICATED;
651+
const int TENSOR_NOT_REQUIRED;
652+
const int TENSOR_SKIP;
653+
const int TENSOR_SKIP_IF_VIRTUAL;
654+
655+
explicit llama_model_base(const llama_model_params & params);
656+
virtual ~llama_model_base() = default;
657+
658+
ggml_tensor * create_tensor(llama_model_loader & ml, const LLM_TN_IMPL & tn, const std::initializer_list<int64_t> & ne, int flags);
659+
660+
// convenience overload of create_tensor that doesn't require llama_model_loader
661+
ggml_tensor * create_tensor(const LLM_TN_IMPL & tn, const std::initializer_list<int64_t> & ne, int flags);
662+
663+
// helper: try merged gate_up_exps first, fall back to separate gate and up
664+
void create_tensor_gate_up_exps(llama_layer & layer, int bid, int64_t n_embd_,
665+
int64_t n_ff_, int64_t n_expert_, int flags);
666+
667+
// helper: try to load merged qkv first, fall back to separate q, k, v
668+
void create_tensor_qkv(llama_layer & layer, int bid,
669+
int64_t n_embd_, int64_t n_embd_q_, int64_t n_embd_k_, int64_t n_embd_v_,
670+
int flags);
671+
672+
void load_stats (llama_model_loader & ml) override;
673+
void load_hparams(llama_model_loader & ml) override;
674+
void load_vocab (llama_model_loader & ml) override;
675+
bool load_tensors(llama_model_loader & ml) override;
676+
677+
// model must define these
678+
void load_arch_hparams(llama_model_loader & ml) override = 0;
679+
void load_arch_tensors(llama_model_loader & ml) override = 0;
680+
std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override = 0;
681+
};
682+
636683
const char * llm_type_name(llm_type type);
637684

685+
// convenience macro for loading local variables for load_tensors() in llama_model_base
686+
// note: cast to int64_t since we will use these for the tensor dimensions
687+
#define LLAMA_LOAD_LOCALS \
688+
const int n_layer = hparams.n_layer; GGML_UNUSED(n_layer); \
689+
const int64_t n_head = hparams.n_head(); GGML_UNUSED(n_head); \
690+
const int64_t n_head_kv = hparams.n_head_kv(); GGML_UNUSED(n_head_kv); \
691+
const int64_t n_embd = hparams.n_embd; GGML_UNUSED(n_embd); \
692+
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); GGML_UNUSED(n_embd_k_gqa); \
693+
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); GGML_UNUSED(n_embd_v_gqa); \
694+
const int64_t n_embd_head_k = hparams.n_embd_head_k(); GGML_UNUSED(n_embd_head_k); \
695+
const int64_t n_embd_head_v = hparams.n_embd_head_v(); GGML_UNUSED(n_embd_head_v); \
696+
const int64_t n_ff = hparams.n_ff(); GGML_UNUSED(n_ff); \
697+
const int64_t n_embd_gqa = n_embd_v_gqa; GGML_UNUSED(n_embd_gqa); \
698+
const int64_t n_vocab = vocab.n_tokens(); GGML_UNUSED(n_vocab); \
699+
const int64_t n_token_types = vocab.n_token_types(); GGML_UNUSED(n_token_types); \
700+
const int64_t n_rot = hparams.n_rot(); GGML_UNUSED(n_rot); \
701+
const int64_t n_expert = hparams.n_expert; GGML_UNUSED(n_expert); \
702+
const int64_t n_expert_used = hparams.n_expert_used; GGML_UNUSED(n_expert_used); \
703+
const int64_t n_ctx_train = hparams.n_ctx_train; GGML_UNUSED(n_ctx_train);
704+
638705
// For internal test use
639706
// TODO: remove
640707
const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_tensor_map(const llama_model * model);

src/llama-quant.cpp

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -882,13 +882,18 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
882882
fname_inp, splits, /*file*/ nullptr, use_mmap, /*use_direct_io*/ false, /*check_tensors*/ true, /*no_alloc*/ false, kv_overrides, nullptr);
883883
ml.init_mappings(false); // no prefetching
884884

885-
llama_model model(llama_model_default_params());
885+
auto mparams = llama_model_default_params();
886+
std::unique_ptr<llama_model> model_ptr(llama_model_create(ml, mparams));
886887

887-
model.load_arch (ml);
888-
model.load_hparams(ml);
889-
model.load_stats (ml);
888+
auto * model = dynamic_cast<llama_model_base *>(model_ptr.get());
889+
if (model == nullptr) {
890+
GGML_ABORT("fatal error: model does not implement llama_model_base");
891+
}
892+
893+
model->load_hparams(ml);
894+
model->load_stats (ml);
890895

891-
quantize_state_impl qs(model, params);
896+
quantize_state_impl qs(*model, params);
892897

893898
if (params->only_copy) {
894899
ftype = ml.ftype;
@@ -1023,7 +1028,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
10231028
}
10241029
gguf_add_tensor(ctx_outs[i_split].get(), tensor);
10251030

1026-
metadata[i].allows_quantization = tensor_allows_quantization(params, model.arch, tensor);
1031+
metadata[i].allows_quantization = tensor_allows_quantization(params, model->arch, tensor);
10271032

10281033
if (metadata[i].allows_quantization) {
10291034
metadata[i].target_type = llama_tensor_get_type(qs, params, tensor, default_type, metadata[i]);
@@ -1331,9 +1336,9 @@ void llama_quant_free(quantize_state_impl * qs) {
13311336

13321337
llama_model * llama_quant_model_from_metadata(const llama_quant_model_desc * desc) {
13331338
struct llama_model_params mparams = llama_model_default_params();
1334-
auto * model = new llama_model(mparams);
1335-
1336-
model->arch = llm_arch_from_string(desc->architecture);
1339+
auto arch = llm_arch_from_string(desc->architecture);
1340+
auto * model = llama_model_create(arch, mparams);
1341+
model->arch = arch;
13371342

13381343
// infer llm_type: only LLM_TYPE_70B matters for quantization logic
13391344
if (model->arch == LLM_ARCH_LLAMA && desc->n_layer == 80 && desc->n_head != desc->n_head_kv) {

0 commit comments

Comments
 (0)