Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions ci/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -386,10 +386,10 @@ function gg_run_open_llama_7b_v2 {

(time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log

(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log

function check_ppl {
qnt="$1"
Expand Down Expand Up @@ -520,8 +520,8 @@ function gg_run_pythia_1_4b {

(time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test_60} -ngl 99 -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log

(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log

function check_ppl {
qnt="$1"
Expand Down Expand Up @@ -651,10 +651,10 @@ function gg_run_pythia_2_8b {

(time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log

(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log

function check_ppl {
qnt="$1"
Expand Down
14 changes: 7 additions & 7 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2962,20 +2962,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.endpoint_metrics = true;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ENDPOINT_METRICS"));
add_opt(common_arg(
{"--slots"},
string_format("enable slots monitoring endpoint (default: %s)", params.endpoint_slots ? "enabled" : "disabled"),
[](common_params & params) {
params.endpoint_slots = true;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ENDPOINT_SLOTS"));
add_opt(common_arg(
{"--props"},
string_format("enable changing global properties via POST /props (default: %s)", params.endpoint_props ? "enabled" : "disabled"),
[](common_params & params) {
params.endpoint_props = true;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ENDPOINT_PROPS"));
add_opt(common_arg(
{"--slots"},
string_format("enable slots monitoring endpoint (default: %s)", params.endpoint_slots ? "enabled" : "disabled"),
[](common_params & params) {
params.endpoint_slots = true;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ENDPOINT_SLOTS"));
add_opt(common_arg(
{"--no-slots"},
"disables slots monitoring endpoint",
Expand Down
2 changes: 1 addition & 1 deletion common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ struct common_params {

// "advanced" endpoints are disabled by default for better security
bool webui = true;
bool endpoint_slots = false;
bool endpoint_slots = true;
bool endpoint_props = false; // only control POST requests, not GET
bool endpoint_metrics = false;

Expand Down
25 changes: 23 additions & 2 deletions common/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -426,8 +426,29 @@ uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {

// helpers

llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl) {
return &gsmpl->cur_p;
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort) {
auto * res = &gsmpl->cur_p;

if (do_sort && !res->sorted) {
// remember the selected token before sorting
const llama_token id = res->data[res->selected].id;

std::sort(res->data, res->data + res->size, [](const llama_token_data & a, const llama_token_data & b) {
return a.p > b.p;
});

// restore the selected token after sorting
for (size_t i = 0; i < res->size; ++i) {
if (res->data[i].id == id) {
res->selected = i;
break;
}
}

res->sorted = true;
}

return res;
}

llama_token common_sampler_last(const struct common_sampler * gsmpl) {
Expand Down
4 changes: 3 additions & 1 deletion common/sampling.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
// helpers

// access the internal list of current candidate tokens
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl);
// if do_sort == true, the candidates are guaranteed to be sorted afterwards (in descending order of probability)
// the .sorted flag of the result indicates whether the returned candidates are sorted
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort);

// get the last accepted token
llama_token common_sampler_last(const struct common_sampler * gsmpl);
Expand Down
2 changes: 1 addition & 1 deletion common/speculative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ llama_tokens common_speculative_gen_draft(

common_sampler_sample(smpl, ctx_dft, 0, true);

const auto * cur_p = common_sampler_get_candidates(smpl);
const auto * cur_p = common_sampler_get_candidates(smpl, true);

for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
Expand Down
4 changes: 2 additions & 2 deletions examples/speculative/speculative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ int main(int argc, char ** argv) {
// stochastic verification
common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true);

auto & dist_tgt = *common_sampler_get_candidates(smpl);
auto & dist_tgt = *common_sampler_get_candidates(smpl, true);

float p_tgt = 0.0f;
float p_dft = 0.0f;
Expand Down Expand Up @@ -493,7 +493,7 @@ int main(int argc, char ** argv) {

common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft, true);

const auto * cur_p = common_sampler_get_candidates(drafts[s].smpl);
const auto * cur_p = common_sampler_get_candidates(drafts[s].smpl, true);

for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p->size); ++k) {
LOG_DBG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n",
Expand Down
3 changes: 3 additions & 0 deletions ggml/include/ggml-backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,9 @@ extern "C" {
GGML_API void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
GGML_API ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);

// Split graph without allocating it
GGML_API void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph);

// Allocate and compute graph on the backend scheduler
GGML_API bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph); // returns success
GGML_API enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
Expand Down
4 changes: 3 additions & 1 deletion ggml/src/ggml-backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -902,7 +902,7 @@ static void ggml_backend_sched_set_if_supported(ggml_backend_sched_t sched, stru
}

// assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend
static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
// reset splits
sched->n_splits = 0;
sched->n_graph_inputs = 0;
Expand Down Expand Up @@ -1687,6 +1687,8 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph *
GGML_ASSERT(sched);
GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes + measure_graph->n_leafs);

ggml_backend_sched_reset(sched);

ggml_backend_sched_synchronize(sched);

ggml_backend_sched_split_graph(sched, measure_graph);
Expand Down
Loading
Loading