Commit cfe9838

fit-params : refactor + add option to output estimated memory per device (#22171)
* fit-params : add option to output estimated memory per device
* cont : minor
* cont : refactor
* cont : move fit params implementation to libcommon
* cont : header
* cont : headers
* cont : codeowners
1 parent ff6b106 commit cfe9838

19 files changed

Lines changed: 1123 additions & 980 deletions
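
For orientation, a minimal usage sketch (not part of the commit): it sets the new fit_params_print field next to the pre-existing fit_params flag on common_params and lets common_init_result run the fitting step, as the hunks below show; the model path and the surrounding program are assumptions.

    // Sketch only: field and type names are taken from the hunks below,
    // everything else (main(), the model path) is assumed for illustration.
    #include "common.h"

    int main() {
        common_params params;
        params.model.path       = "model.gguf"; // hypothetical model file
        params.fit_params       = true;         // existing: fit unset model/context params to free device memory
        params.fit_params_print = true;         // new in this commit: print the estimated required memory per device

        // common_init_result's constructor (see the common/common.cpp hunk below)
        // calls common_fit_params() when params.fit_params is set.
        common_init_result result(params);
        return 0;
    }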

CODEOWNERS

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@
 /ci/ @ggerganov
 /cmake/ @ggerganov
 /common/ @ggml-org/llama-common
+/common/fit.* @JohannesGaessler
 /common/jinja/ @CISC
 /common/ngram-map.* @srogmann
 /convert_*.py @CISC

common/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -73,6 +73,8 @@ add_library(${TARGET}
     debug.h
     download.cpp
     download.h
+    fit.cpp
+    fit.h
     hf-cache.cpp
     hf-cache.h
     http.h

common/arg.cpp

Lines changed: 14 additions & 0 deletions
@@ -2426,6 +2426,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             }
         }
     ).set_env("LLAMA_ARG_FIT"));
+    add_opt(common_arg(
+        { "-fitp", "--fit-print" }, "[on|off]",
+        string_format("print the estimated required memory ('on' or 'off', default: '%s')", params.fit_params_print ? "on" : "off"),
+        [](common_params & params, const std::string & value) {
+            if (is_truthy(value)) {
+                params.fit_params_print = true;
+            } else if (is_falsey(value)) {
+                params.fit_params_print = false;
+            } else {
+                throw std::runtime_error(
+                    string_format("error: unknown value for --fit-print: '%s'\n", value.c_str()));
+            }
+        }
+    ).set_examples({LLAMA_EXAMPLE_FIT_PARAMS}).set_env("LLAMA_ARG_FIT_ESTIMATE"));
     add_opt(common_arg(
         { "-fitt", "--fit-target" }, "MiB0,MiB1,MiB2,...",
         string_format("target margin per device for --fit, comma-separated list of values, "

common/common.cpp

Lines changed: 2 additions & 1 deletion
@@ -3,6 +3,7 @@
 
 #include "build-info.h"
 #include "common.h"
+#include "fit.h"
 #include "log.h"
 #include "llama.h"
 #include "sampling.h"
@@ -1147,7 +1148,7 @@ common_init_result::common_init_result(common_params & params) :
 
     if (params.fit_params) {
         LOG_INF("%s: fitting params to device memory, for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on\n", __func__);
-        llama_params_fit(params.model.path.c_str(), &mparams, &cparams,
+        common_fit_params(params.model.path.c_str(), &mparams, &cparams,
             params.tensor_split,
             params.tensor_buft_overrides.data(),
             params.fit_params_target.data(),

common/common.h

Lines changed: 6 additions & 5 deletions
@@ -420,11 +420,12 @@ struct common_params {
     // offload params
     std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
 
-    int32_t n_gpu_layers = -1; // number of layers to store in VRAM, -1 is auto, <= -2 is all
-    int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
-    float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
-    bool fit_params = true; // whether to fit unset model/context parameters to free device memory
-    int32_t fit_params_min_ctx = 4096; // minimum context size to set when trying to reduce memory use
+    int32_t n_gpu_layers = -1;         // number of layers to store in VRAM, -1 is auto, <= -2 is all
+    int32_t main_gpu = 0;              // the GPU that is used for scratch and small tensors
+    float tensor_split[128] = {0};     // how split tensors should be distributed across GPUs
+    bool fit_params = true;            // whether to fit unset model/context parameters to free device memory
+    bool fit_params_print = false;     // print the estimated required memory to run the model
+    int32_t fit_params_min_ctx = 4096; // minimum context size to set when trying to reduce memory use
 
     // margin per device in bytes for fitting parameters to free memory:
     std::vector<size_t> fit_params_target = std::vector<size_t>(llama_max_devices(), 1024 * 1024 * 1024);
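
As the context lines above show, fit_params_target defaults to a 1 GiB margin for every device slot reported by llama_max_devices(). A hedged sketch of adjusting that margin alongside the new print flag; the 512 MiB figure is only an example, not a recommendation from the commit:

    // Sketch (assumption, not from the commit): tighten the per-device margin
    // and request that the memory estimate be printed.
    common_params params;
    params.fit_params_print = true;                              // field added in this commit
    params.fit_params_target.assign(llama_max_devices(),         // one margin per device
                                    (size_t) 512 * 1024 * 1024); // 512 MiB instead of the 1 GiB default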
