@@ -577,14 +577,8 @@ struct llama_model {
577577 int64_t t_load_us = 0 ;
578578 int64_t t_start_us = 0 ;
579579
580- explicit llama_model (const struct llama_model_params & params);
581- ~llama_model ();
582-
583- void load_stats (llama_model_loader & ml);
584- void load_arch (llama_model_loader & ml);
585- void load_hparams (llama_model_loader & ml);
586- void load_vocab (llama_model_loader & ml);
587- bool load_tensors (llama_model_loader & ml); // returns false if cancelled by progress_callback
580+ explicit llama_model (const llama_model_params & params);
581+ virtual ~llama_model ();
588582
589583 std::string arch_name () const ;
590584 std::string type_name () const ;
@@ -620,21 +614,94 @@ struct llama_model {
620614
621615 ggml_tensor * get_rope_factors (const llama_cparams & cparams, int il) const ;
622616
623- // TODO: move this to new llm_arch_model_i interface
624617 llama_memory_i * create_memory (const llama_memory_params & params, const llama_cparams & cparams) const ;
625618
626- // TODO: move this to new llm_arch_model_i interface
627619 ggml_cgraph * build_graph (const llm_graph_params & params) const ;
628620
629- private:
621+ virtual void load_stats (llama_model_loader & ml) = 0;
622+ virtual void load_hparams (llama_model_loader & ml) = 0;
623+ virtual void load_vocab (llama_model_loader & ml) = 0;
624+ virtual bool load_tensors (llama_model_loader & ml) = 0; // returns false if cancelled by progress_callback
625+
626+ // model must define these
627+ virtual void load_arch_hparams (llama_model_loader & ml) = 0;
628+ virtual void load_arch_tensors (llama_model_loader & ml) = 0;
629+ virtual std::unique_ptr<llm_graph_context> build_arch_graph (const llm_graph_params & params) const = 0;
630+
631+ protected:
630632 llama_model_params params;
631633
632634 struct impl ;
633635 std::unique_ptr<impl> pimpl;
634636};
635637
638+ llama_model * llama_model_create (llm_arch arch, const llama_model_params & params); // NOTE(review): returns a raw pointer - presumably a newly created model that the caller owns; confirm against the definition
639+ llama_model * llama_model_create (llama_model_loader & ml, const llama_model_params & params); // overload that takes a loader instead of an explicit arch; presumably the arch is read from `ml` - confirm
640+
641+ // base struct that concrete models must inherit from: it declares overrides for the generic load_stats/load_hparams/load_vocab/load_tensors steps and re-declares the arch-specific hooks as pure virtual
642+ struct llama_model_base : public llama_model {
643+ friend struct llama_model ;
644+
645+ llama_model * model; // NOTE(review): no in-class initializer (unlike `ml` below), so it must be set by the constructor; what it points to is not visible in this header
646+ llama_model_loader * ml = nullptr ;
647+ const LLM_TN tn;
648+
649+ // llama_model_loader is not yet defined at this point, so these flags cannot be given in-class initializers here; being const members, they have to be initialized by the constructor (they cannot be assigned after construction)
650+ const int TENSOR_DUPLICATED ;
651+ const int TENSOR_NOT_REQUIRED ;
652+ const int TENSOR_SKIP ;
653+ const int TENSOR_SKIP_IF_VIRTUAL ;
654+
655+ explicit llama_model_base (const llama_model_params & params);
656+ virtual ~llama_model_base () = default ;
657+
658+ ggml_tensor * create_tensor (llama_model_loader & ml, const LLM_TN_IMPL & tn, const std::initializer_list<int64_t > & ne, int flags);
659+
660+ // convenience overload of create_tensor that doesn't take a llama_model_loader argument (presumably it uses the `ml` member instead - confirm in the definition)
661+ ggml_tensor * create_tensor (const LLM_TN_IMPL & tn, const std::initializer_list<int64_t > & ne, int flags);
662+
663+ // helper: try the merged gate_up_exps tensor first, fall back to separate gate and up tensors; results go into `layer`
664+ void create_tensor_gate_up_exps (llama_layer & layer, int bid, int64_t n_embd_,
665+ int64_t n_ff_, int64_t n_expert_, int flags);
666+
667+ // helper: try to load the merged qkv tensor first, fall back to separate q, k, v tensors; results go into `layer`
668+ void create_tensor_qkv (llama_layer & layer, int bid,
669+ int64_t n_embd_, int64_t n_embd_q_, int64_t n_embd_k_, int64_t n_embd_v_,
670+ int flags);
671+
672+ void load_stats (llama_model_loader & ml) override ;
673+ void load_hparams (llama_model_loader & ml) override ;
674+ void load_vocab (llama_model_loader & ml) override ;
675+ bool load_tensors (llama_model_loader & ml) override ;
676+
677+ // arch-specific hooks: they stay pure virtual here, so every concrete model must define these
678+ void load_arch_hparams (llama_model_loader & ml) override = 0;
679+ void load_arch_tensors (llama_model_loader & ml) override = 0;
680+ std::unique_ptr<llm_graph_context> build_arch_graph (const llm_graph_params & params) const override = 0;
681+ };
682+
636683const char * llm_type_name (llm_type type);
637684
685+ // convenience macro that declares the commonly used const locals for load_tensors() in llama_model_base; the values are read from `hparams` and `vocab`, which must be in scope where the macro is expanded, and each local is passed to GGML_UNUSED (presumably to silence unused-variable warnings)
686+ // note: the locals are declared as int64_t since we will use these for the tensor dimensions (n_layer is the exception and stays int); n_embd_gqa is an alias of n_embd_v_gqa
687+ #define LLAMA_LOAD_LOCALS \
688+ const int n_layer = hparams.n_layer; GGML_UNUSED (n_layer); \
689+ const int64_t n_head = hparams.n_head(); GGML_UNUSED (n_head); \
690+ const int64_t n_head_kv = hparams.n_head_kv(); GGML_UNUSED (n_head_kv); \
691+ const int64_t n_embd = hparams.n_embd; GGML_UNUSED (n_embd); \
692+ const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); GGML_UNUSED (n_embd_k_gqa); \
693+ const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); GGML_UNUSED (n_embd_v_gqa); \
694+ const int64_t n_embd_head_k = hparams.n_embd_head_k(); GGML_UNUSED (n_embd_head_k); \
695+ const int64_t n_embd_head_v = hparams.n_embd_head_v(); GGML_UNUSED (n_embd_head_v); \
696+ const int64_t n_ff = hparams.n_ff(); GGML_UNUSED (n_ff); \
697+ const int64_t n_embd_gqa = n_embd_v_gqa; GGML_UNUSED (n_embd_gqa); \
698+ const int64_t n_vocab = vocab.n_tokens(); GGML_UNUSED (n_vocab); \
699+ const int64_t n_token_types = vocab.n_token_types(); GGML_UNUSED (n_token_types); \
700+ const int64_t n_rot = hparams.n_rot(); GGML_UNUSED (n_rot); \
701+ const int64_t n_expert = hparams.n_expert; GGML_UNUSED (n_expert); \
702+ const int64_t n_expert_used = hparams.n_expert_used; GGML_UNUSED (n_expert_used); \
703+ const int64_t n_ctx_train = hparams.n_ctx_train; GGML_UNUSED (n_ctx_train);
704+
638705// For internal test use
639706// TODO: remove
640707const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_tensor_map (const llama_model * model);
0 commit comments