Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,13 @@ option(LLAMA_SANITIZE_UNDEFINED "llama: enable undefined sanitizer" OFF)
option(LLAMA_BUILD_COMMON "llama: build common utils library" ${LLAMA_STANDALONE})

# extra artifacts
option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
option(LLAMA_BUILD_TOOLS "llama: build tools" ${LLAMA_STANDALONE})
option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
option(LLAMA_BUILD_SERVER "llama: build server example" ${LLAMA_STANDALONE})
option(LLAMA_BUILD_UI "llama: build the embedded Web UI for server" ON)
option(LLAMA_USE_PREBUILT_UI "llama: use prebuilt UI from HF Bucket when available (requires LLAMA_BUILD_UI=ON)" ON)
option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
option(LLAMA_BUILD_TOOLS "llama: build tools" ${LLAMA_STANDALONE})
option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
option(LLAMA_BUILD_SERVER "llama: build server example" ${LLAMA_STANDALONE})
option(LLAMA_BUILD_APP "llama: build the unified binary" OFF)
option(LLAMA_BUILD_UI "llama: build the embedded Web UI for server" ON)
option(LLAMA_USE_PREBUILT_UI "llama: use prebuilt UI from HF Bucket when available (requires LLAMA_BUILD_UI=ON)" ON)

# Backward compat: when old var is set but new one isn't, forward the value
if(DEFINED LLAMA_BUILD_WEBUI)
Expand All @@ -120,8 +121,9 @@ if(DEFINED LLAMA_USE_PREBUILT_WEBUI)
set(LLAMA_USE_PREBUILT_UI ${LLAMA_USE_PREBUILT_WEBUI})
message(DEPRECATION "LLAMA_USE_PREBUILT_WEBUI is deprecated, use LLAMA_USE_PREBUILT_UI instead")
endif()
option(LLAMA_TOOLS_INSTALL "llama: install tools" ${LLAMA_TOOLS_INSTALL_DEFAULT})
option(LLAMA_TESTS_INSTALL "llama: install tests" ON)

option(LLAMA_TOOLS_INSTALL "llama: install tools" ${LLAMA_TOOLS_INSTALL_DEFAULT})
option(LLAMA_TESTS_INSTALL "llama: install tests" ON)

# 3rd party libs
option(LLAMA_OPENSSL "llama: use openssl to support HTTPS" ON)
Expand Down Expand Up @@ -226,6 +228,10 @@ if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_TOOLS)
add_subdirectory(tools)
endif()

if (LLAMA_BUILD_APP)
add_subdirectory(app)
endif()

# Automatically add all files from the 'licenses' directory
file(GLOB EXTRA_LICENSES "${CMAKE_SOURCE_DIR}/licenses/LICENSE-*")

Expand Down
11 changes: 11 additions & 0 deletions app/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
set(TARGET llama-app)

add_executable(${TARGET} llama.cpp)
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama)

target_link_libraries(${TARGET} PRIVATE llama-server-impl llama-cli-impl llama-completion-impl llama-bench-impl)
target_compile_features(${TARGET} PRIVATE cxx_std_17)

if(LLAMA_TOOLS_INSTALL)
install(TARGETS ${TARGET} RUNTIME)
endif()
67 changes: 67 additions & 0 deletions app/llama.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#include <cstdio>
#include <string>
#include <vector>

int llama_server(int argc, char ** argv);
int llama_cli(int argc, char ** argv);

// hidden
int llama_completion(int argc, char ** argv);
int llama_bench(int argc, char ** argv);
static int help(int argc, char ** argv);

struct command {
const char * name;
const char * desc;
std::vector<std::string> aliases;
bool hidden;
int (*func)(int, char **);
};

static const command cmds[] = {
{"serve", "HTTP API server", {"server"}, false, llama_server },
{"cli", "Command-line interactive interface", {"client"}, false, llama_cli },
{"completion", "Text completion", {"complete"}, true, llama_completion },
{"bench", "Benchmarking tool", {}, true, llama_bench },
{"help", "Show available commands", {}, true, help },
};

static int help(int argc, char ** argv) {
const bool show_all = argc >= 2 && std::string(argv[1]) == "all";

printf("Usage: llama <command> [options]\n\nAvailable commands:\n");

for (const auto & cmd : cmds) {
if (show_all || !cmd.hidden) {
printf(" %-15s %s\n", cmd.name, cmd.desc);
}
}
printf("\nRun 'llama <command> --help' for command-specific usage.\n");

return 0;
}

static bool matches(const std::string & arg, const command & cmd) {
if (arg == cmd.name) {
return true;
}
for (const auto & alias : cmd.aliases) {
if (arg == alias) {
return true;
}
}
return false;
}

int main(int argc, char ** argv) {
const std::string arg = argc >= 2 ? argv[1] : "help";

for (const auto & cmd : cmds) {
if (matches(arg, cmd)) {
return cmd.func(argc - 1, argv + 1);
}
}

fprintf(stderr, "error: unknown command '%s'\n", arg.c_str());
return 1;
}
3 changes: 2 additions & 1 deletion ggml/src/ggml-cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ if (CUDAToolkit_FOUND)
# 80 == Ampere, asynchronous data loading, faster tensor core instructions
# 86 == RTX 3000, needs CUDA v11.1
# 89 == RTX 4000, needs CUDA v11.8
# 90 == Hopper H100/200, needs CUDA v11.8
# 120 == Blackwell, needs CUDA v12.8, FP4 tensor cores
#
# XX-virtual == compile CUDA code as PTX, do JIT compilation to binary code on first run
Expand All @@ -33,7 +34,7 @@ if (CUDAToolkit_FOUND)
list(APPEND CMAKE_CUDA_ARCHITECTURES 75-virtual 80-virtual 86-real)

if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.8")
list(APPEND CMAKE_CUDA_ARCHITECTURES 89-real)
list(APPEND CMAKE_CUDA_ARCHITECTURES 89-real 90-virtual)
endif()

if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8")
Expand Down
32 changes: 14 additions & 18 deletions ggml/src/ggml-cuda/binbcast.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
#include <cstdint>
#include <utility>

template<typename T, size_t>
using type_for_index = T;

static __device__ __forceinline__ float op_repeat(const float a, const float b) {
return b;
GGML_UNUSED(a);
Expand Down Expand Up @@ -52,6 +55,7 @@ static __global__ void k_bin_bcast(const src0_t * src0,
const int s12,
const int s13,
src1_ptrs... src1s) {
ggml_cuda_pdl_lc();
const uint32_t i0s = blockDim.x * blockIdx.x + threadIdx.x;
const uint32_t i1 = (blockDim.y * blockIdx.y + threadIdx.y);
const uint32_t i2 = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3);
Expand All @@ -72,6 +76,7 @@ static __global__ void k_bin_bcast(const src0_t * src0,
const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
dst_t * dst_row = dst + i_dst;

ggml_cuda_pdl_sync();
for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
const uint32_t i10 = fastmodulo(i0, ne10);

Expand Down Expand Up @@ -141,6 +146,7 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0,

const int i10 = fastmodulo(i0, ne10);

ggml_cuda_pdl_sync();
float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;
if constexpr (sizeof...(src1_ptrs) > 0) {
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));
Expand Down Expand Up @@ -282,35 +288,24 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
const uint3 ne1_fastdiv = init_fastdiv_values((uint32_t) ne1);
const uint3 ne2_fastdiv = init_fastdiv_values((uint32_t) ne2);

if constexpr (sizeof...(I) > 0) {
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t><<<block_num, block_size, 0, stream>>>(
{
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)block_num, block_size, 0, stream);
ggml_cuda_kernel_launch(k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t, type_for_index<const src1_t *, I>...>, launch_params,
src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11,
ne12, ne13,
/*s0,*/ s1, s2, s3,
s00, s01, s02, s03,
s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
} else {
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
<<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv,
ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13,
/*s0,*/ s1, s2, s3,
s00, s01, s02, s03,
s10, s11, s12, s13);
}
} else {
const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3);
if constexpr (sizeof...(I) > 0) {
k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
{
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream);
ggml_cuda_kernel_launch(k_bin_bcast<bin_op, src0_t, src1_t, dst_t, type_for_index<const src1_t *, I>...>, launch_params,
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
/*s0,*/ s1, s2, s3,
s00 ,s01, s02, s03,
s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
} else {
k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
/*s0,*/ s1, s2, s3,
s00, s01, s02, s03,
s10, s11, s12, s13);
s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
}
}
}
Expand All @@ -333,6 +328,7 @@ static __global__ void k_repeat_back(
}

T sum = 0;
ggml_cuda_pdl_sync();
for (int64_t i3 = tid3; i3 < ne03; i3 += ne3) {
for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
Expand Down
86 changes: 86 additions & 0 deletions ggml/src/ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "ggml-cuda.h"

#include <cstdint>
#include <cstdlib>
#include <memory>

#if defined(GGML_USE_HIP)
Expand All @@ -27,6 +28,7 @@
#include <cstdio>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#if defined(GGML_USE_HIP)
Expand All @@ -50,6 +52,7 @@
#define GGML_CUDA_CC_TURING 750
#define GGML_CUDA_CC_AMPERE 800
#define GGML_CUDA_CC_ADA_LOVELACE 890
#define GGML_CUDA_CC_HOPPER 900
// While BW spans CC 1000, 1100 & 1200, we are integrating Tensor Core instructions available to 1200 family, see
// https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html#blackwell-sm120-gemms
#define GGML_CUDA_CC_BLACKWELL 1200
Expand Down Expand Up @@ -107,6 +110,24 @@
# define GGML_CUDA_USE_CUB
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070

// PDL host-side support (cudaLaunchKernelEx) requires CUDART >= 11.8 and excludes HIP/MUSA.
// __CUDA_ARCH__ is undefined in host passes; GPU arch check happens in device-side code.
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11080
# define GGML_CUDA_USE_PDL
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11080

static __device__ __forceinline__ void ggml_cuda_pdl_sync() {
#if defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
cudaGridDependencySynchronize();
#endif // defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
}

static __device__ __forceinline__ void ggml_cuda_pdl_lc() {
#if defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
cudaTriggerProgrammaticLaunchCompletion();
#endif // defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER
}

#ifdef __CUDA_ARCH_LIST__
constexpr bool ggml_cuda_has_arch_impl(int) {
return false;
Expand Down Expand Up @@ -165,6 +186,7 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in

#define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)


#if CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA)
static const char * cublas_get_error_str(const cublasStatus_t err) {
return cublasGetStatusString(err);
Expand Down Expand Up @@ -1487,3 +1509,67 @@ struct ggml_cuda_mm_fusion_args_device {
const void * gate_bias = nullptr;
ggml_glu_op glu_op;
};

struct ggml_cuda_kernel_launch_params {
dim3 block_nums;
dim3 block_dims;
size_t shmem;
cudaStream_t stream;

// size_t shmem
ggml_cuda_kernel_launch_params(const dim3& block_nums_, const dim3& block_dims_, const size_t shmem_, const cudaStream_t stream_)
: block_nums(block_nums_), block_dims(block_dims_), shmem(shmem_), stream(stream_) {}

// Some call sites pass ints instead of the required size_t. This 2nd constructor casts int->size_t to avoid these -Wnarrowing warnings.
ggml_cuda_kernel_launch_params(const dim3& block_nums_, const dim3& block_dims_, const int shmem_, const cudaStream_t stream_)
: block_nums(block_nums_), block_dims(block_dims_), shmem((size_t)shmem_), stream(stream_) {}
};

#if defined(GGML_CUDA_USE_PDL)
struct ggml_cuda_pdl_config {
cudaLaunchAttribute attr;
cudaLaunchConfig_t cfg;

ggml_cuda_pdl_config(const ggml_cuda_kernel_launch_params & params) {
attr.id = cudaLaunchAttributeProgrammaticStreamSerialization;
attr.val.programmaticStreamSerializationAllowed = 1;

cfg = {};
cfg.gridDim = params.block_nums;
cfg.blockDim = params.block_dims;
cfg.dynamicSmemBytes = params.shmem;
cfg.stream = params.stream;
cfg.attrs = &attr;
cfg.numAttrs = 1;
}

// Delete due to &attr
ggml_cuda_pdl_config(const ggml_cuda_pdl_config&) = delete;
ggml_cuda_pdl_config& operator=(const ggml_cuda_pdl_config&) = delete;
ggml_cuda_pdl_config& operator=(ggml_cuda_pdl_config&&) = delete;

};
#endif //defined(GGML_CUDA_USE_PDL)


template<typename Kernel, typename... Args>
static __inline__ void ggml_cuda_kernel_launch(Kernel kernel, const ggml_cuda_kernel_launch_params & launch_params, Args&&... args) {
#if defined(GGML_CUDA_USE_PDL)

static const bool env_pdl_enabled = []() {
const char * env = getenv("GGML_CUDA_PDL");
return env == nullptr || std::atoi(env) != 0;
}();

if (env_pdl_enabled && ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= GGML_CUDA_CC_HOPPER) {
auto pdl_cfg = ggml_cuda_pdl_config(launch_params);

CUDA_CHECK(cudaLaunchKernelEx(&pdl_cfg.cfg, kernel, std::forward<Args>(args)... ));
return;
}
#endif //defined(GGML_CUDA_USE_PDL)

kernel<<<launch_params.block_nums, launch_params.block_dims, launch_params.shmem, launch_params.stream>>>(std::forward<Args>(args)... );
CUDA_CHECK(cudaGetLastError());
}

5 changes: 3 additions & 2 deletions ggml/src/ggml-cuda/concat.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) concat_f32_cont

const int64_t n = ne0 * ne1 * ne2;

ggml_cuda_pdl_sync();
for (int64_t i = (int64_t) blockIdx.x * blockDim.x + threadIdx.x; i < n; i += (int64_t) blockDim.x * gridDim.x) {
if constexpr (dim == 0) {
const int64_t row = i / ne0;
Expand Down Expand Up @@ -64,8 +65,8 @@ static void concat_f32_cuda(const float * x,
const int num_blocks = (n + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE;

if (dim == 0) {
concat_f32_cont<0>
<<<num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2);
const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream);
ggml_cuda_kernel_launch(concat_f32_cont<0>, launch_params,x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2);
return;
}
if (dim == 1) {
Expand Down
Loading
Loading