Hi GGML,

I have prepared a brief summary of the work I've done on a branch where I explored refactoring the ggml-cuda backend. I was surprised how tightly coupled everything is there: almost all of the logic is concentrated within a single *.cu and *.cuh file. I analyzed the feasibility of adopting a non-monolithic architecture, the goal of which is to separate responsibilities across distinct modules (files). Below is a summary outlining the most significant changes.

A few questions for you: is this concept valuable enough to submit as a PR and to ask you to invest your time in reviewing it? Or is the monolithic design intentional? If it is, then it is clear this direction is not expected and not worth pursuing.
TL;DR
refactor-backend is a CUDA-backend refactor branch that reorganises the ggml/src/ggml-cuda backend by breaking the monolithic common.cuh and ggml-cuda.cu into a set of focused modules, introduces a new lightweight per-op execution context, and adds developer documentation.
What the refactor aims for
- Separation of concerns. Separate *.cu translation units are coupled only to the parts they need. For instance, a change in graph.cu (responsible for graph processing) triggers a rebuild of only a single object file, not the whole backend with all kernel ops. In other words, the compile/rebuild time of the backend is reduced.
- Encapsulation: not everyone, everywhere can modify the CUDA driver state/entities (cudaStream_t, cudaGraph_t, ...). Improved RAII (Resource Acquisition Is Initialization).

Functional behaviour of CUDA ops is intended to be unchanged; the bulk of the line count comes from a mechanical signature/include rewrite plus moving code between files. The existing interfaces from ggml-cuda.cu and common.cuh are not modified, only distributed across separate files.

The trade-offs in summary:
- Increased accessibility for newcomers.
- Increased encapsulation / decreased risk thanks to a clearer separation of concerns.
- Faster build times / better dev UX.
- Risk of lower extensibility in the future (i.e. if we separate something now and need it together later, recombining it will be more complicated).

Which of the above design points are expected and worth merging/following?
What the refactor does (high-level)
- Splits the giant ggml-cuda.cu (~5.3k lines removed from it) into per-concern translation units: device.cu, buffer.cu, graph.cu, mul-mat.cu, compute-forward.cu, fusion.cu, multi-device-nccl.cu, registry.cu, concurrent-event.cu.
- Removes common.cuh and replaces it with several smaller, purpose-built headers: cuda-defs.cuh, kernel-utils.cuh, kernel-ops-context.cuh, backend-context.cuh, cuda-pool.cuh, concurrent-event.cuh, graph.cuh, type-traits.cuh, mmq-types.cuh.
- Introduces ggml_cuda_kernel_ops_context, a small, const-friendly view (device id, stream number, cudaStream_t, cublasHandle_t, optional pool access) that replaces ggml_backend_cuda_context & in every CUDA op entry point. This is why ~125 *.cu / *.cuh op files carry a small, mechanical change: they all need only a read-only stream or device id, not the whole context.
- Adds CUDA backend documentation (Mermaid diagrams) under ggml/src/ggml-cuda/docs/.
CUDA backend — new modules split out of ggml-cuda.cu and common.cuh
ggml/src/ggml-cuda/device.cu / device.cuh -> (logic and entities associated with ggml_backend_cuda_device)
ggml/src/ggml-cuda/buffer.cu / buffer.cuh -> (logic and entities associated with ggml_backend_cuda_buffer)
ggml/src/ggml-cuda/graph.cu / graph.cuh -> (logic and entities associated with graph processing and caching)
ggml/src/ggml-cuda/mul-mat.cu / mul-mat.cuh -> (orchestration for mul-mat, associated with ggml_cuda_op_mul_mat)
ggml/src/ggml-cuda/compute-forward.cu / compute-forward.cuh -> (dispatch of ops)
ggml/src/ggml-cuda/fusion.cu / fusion.cuh -> (logic and entities associated with fusion of ops)
ggml/src/ggml-cuda/multi-device-nccl.cu / multi-device-nccl.cuh -> (logic and entities associated with ggml_backend_cuda_comm_context and NCCL support)
ggml/src/ggml-cuda/concurrent-event.cu / concurrent-event.cuh -> (ggml_cuda_concurrent_event struct that encapsulates cudaEvent_t and associated logic)
ggml/src/ggml-cuda/registry.cu -> (code responsible for backend registration)
ggml/src/ggml-cuda/backend-context.cuh -> (main entity that owns and encapsulates graphs, streams, and other driver-related objects)
ggml/src/ggml-cuda/kernel-ops-context.cuh -> (ggml_cuda_kernel_ops_context, the lightweight per-op view)
ggml/src/ggml-cuda/kernel-utils.cuh -> (shared device/kernel helper utilities)
ggml/src/ggml-cuda/cuda-defs.cuh -> (base definitions and the CUDA/HIP/MUSA vendor abstraction)
ggml/src/ggml-cuda/cuda-pool.cuh -> (the ggml_cuda_pool interface and pool allocator types)
ggml/src/ggml-cuda/type-traits.cuh -> (type trait specializations)
ggml/src/ggml-cuda/mmq-types.cuh -> (types shared by the quantize/MMQ paths)
Important: the content of the files above is only moved from ggml-cuda.cu and common.cuh, NOT newly created or redesigned. Only private and public type scopes are added.
Documentation
- ggml/src/ggml-cuda/docs/01-core-file-dependencies.md
- ggml/src/ggml-cuda/docs/02-class-struct-dependencies.md
- ggml/src/ggml-cuda/docs/03-op-dispatch-and-modules.md
- ggml/src/ggml-cuda/docs/04-matmul-and-attention-internals.md
- ggml/src/ggml-cuda/docs/05-backend-lifecycle.md
- ggml/src/ggml-cuda/docs/06-header-dependency-map.md

The corresponding diagrams are described in the Mermaid diagrams section below.

Files removed (1)
- ggml/src/ggml-cuda/common.cuh — replaced by the new focused headers listed above.

Heaviest single file change
- ggml/src/ggml-cuda/ggml-cuda.cu — ~5,300 lines removed, content moved into the new modules listed above.

Op dispatch / context plumbing (mechanical signature change)
For each CUDA op, the entry point changes from taking the full ggml_backend_cuda_context & to taking the new const ggml_cuda_kernel_ops_context &, and the corresponding *.cuh swaps #include "common.cuh" for the smaller kernel-ops-context.cuh (or kernel-utils.cuh where needed); a sketch of this change follows the list below. The same shape of change applies to these (modified) header/source pairs:
- acc, add-id, unary, clamp, scale, fill, roll, set, set-rows, softcap, softmax, snake, sum, sumrows, mean, count-equal, cumsum, diag, diagmask, tri, tsembd, pad, pad_reflect_1d, upscale, arange, concat, binbcast, argmax, argsort, top-k, topk-moe, solve_tri, getrows, cpy, convert, quantize, dequantize, vecdotq, reduce_rows
- norm, cross-entropy-loss, opt-step-adamw, opt-step-sgd, out-prod
- conv2d, conv2d-dw, conv2d-transpose, conv-transpose-1d, im2col, pool2d
- ssm-conv, ssm-scan, wkv, gla, gated_delta_net
- rope, set-rows, set
- fattn, fattn-tile, fattn-wmma-f16, fattn-vec.cuh, fattn-mma-f16.cuh, fattn-common.cuh
- mmf, mmid, mmq, mmvf, mmvq, mma.cuh, cp-async.cuh
- allreduce

(Full list visible via git diff --name-status master...refactor-backend.)
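The mechanical change has roughly this shape, using scale as the example (a sketch of the pattern described above, not a verbatim diff from the branch):

```cpp
// Before: every op entry point received the full, mutable backend context.
#include "common.cuh"

void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

// After: the op receives only the small, read-only view it actually needs.
#include "kernel-ops-context.cuh"

void ggml_cuda_op_scale(const ggml_cuda_kernel_ops_context & ctx, ggml_tensor * dst);
```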
Notable design points introduced
ggml_cuda_kernel_ops_context (kernel-ops-context.cuh) is a small POD-like view holding device, stream_no, cudaStream_t, cublasHandle_t, and a back-pointer to the owning ggml_backend_cuda_context for pool access. Op functions take it by const &, which is what enables the const-correctness changes throughout the backend. A sketch of the struct follows.
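A minimal sketch of what such a view can look like, based on the fields and accessors described in this post (member names are illustrative):

```cpp
// kernel-ops-context.cuh (sketch): a cheap, read-only view that op entry
// points receive instead of the full mutable backend context.
struct ggml_cuda_kernel_ops_context {
    cudaStream_t   stream()        const { return stream_; }
    cublasHandle_t cublas_handle() const { return cublas_; }
    int            device_id()     const { return device_; }
    int            stream_no()     const { return stream_no_; }

    // Scratch allocations still go through the owning backend's pool.
    ggml_cuda_pool & pool() const;

  private:
    int            device_;
    int            stream_no_;
    cudaStream_t   stream_;
    cublasHandle_t cublas_;
    ggml_backend_cuda_context * backend_for_pool_; // optional pool access
};
```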
ggml_backend_cuda_context (backend-context.cuh) becomes the heavy owner: streams, cuBLAS handles, the CUDA-graph cache, concurrent-event tracking, the device id, and scratch pools. Forward declarations of ggml_cuda_graph / ggml_cuda_graph_cache keep the include surface small.
ggml_cuda_concurrent_event (concurrent-event.{cu,cuh}) encapsulates per-tensor cross-stream/copy event handling that previously lived inline.
CUDA-graph capture and caching live in the graph.{cu,cuh} module and are accessed through methods on ggml_backend_cuda_context.
NCCL support lives in multi-device-nccl.{cu,cuh}; backend registration lives in registry.cu.
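For example, keeping the heavy graph types as forward declarations in backend-context.cuh means translation units that never touch CUDA graphs do not pull in the graph machinery (a sketch, assuming the names above):

```cpp
// backend-context.cuh (sketch): forward declarations keep the include surface small.
struct ggml_cuda_graph;        // defined in graph.cuh
struct ggml_cuda_graph_cache;  // defined in graph.cuh

struct ggml_backend_cuda_context {
    // ... streams, cuBLAS handles, pools ...

    // Only callers that actually use CUDA graphs need the full definition,
    // so only they include graph.cuh.
    ggml_cuda_graph * cuda_graph(const void * key);
};
```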
Mermaid diagrams

GGML-CUDA Complete Header Dependency Map
This diagram shows every .cuh header file and its direct #include dependencies on other .cuh files within the ggml-cuda directory. Files are grouped by functional area. (Mermaid source: see ggml/src/ggml-cuda/docs/06-header-dependency-map.md on the branch.)
Summary Statistics

| Group | Count | Contents |
| --- | --- | --- |
| Foundation layer | 5 .cuh | cuda-defs, kernel-ops-context, kernel-utils, cuda-pool, concurrent-event |
| Backend infrastructure | 7 .cuh | graph, backend-context, buffer, multi-device-nccl, device, compute-forward, fusion |
| Matmul primitives | 9 .cuh | MMA tiles, vec dot, type traits, quantize, convert, copy helpers |
| Matmul headers | 6 .cuh | MMQ, MMF, MMVQ, MMVF, MMID, mul-mat |
| Flash attention | 6 .cuh | Dispatch + 4 paths + common |
| Operator headers | 51 .cuh | Most include kernel-ops-context.cuh; cpy, mean, and set use the full backend context |
| Fusion / internal AR | 2 .cuh | snake.cuh (fused matmul helper); allreduce.cuh (host-side internal AR API) |
| Total | 86 .cuh | Current non-template CUDA headers in ggml/src/ggml-cuda/ |
| Top-level .cu | 72 | Implementation files outside template-instances/ |
| Template instances .cu | 119 | Auto-generated type specializations |
| Grand total .cu | 191 | |
Standalone Headers
type-traits.cuh has no #include directives and only declares type trait specializations.
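For context, a header consisting purely of trait specializations looks roughly like this (the trait name and fields are illustrative, not copied from the branch):

```cpp
// type-traits.cuh style (sketch): no includes, only compile-time lookup
// tables mapping a quantization type to its block geometry. It relies on
// the includer having already pulled in ggml_type (from ggml.h), which is
// how the header itself stays include-free.
template <ggml_type type> struct ggml_cuda_type_traits;

template <> struct ggml_cuda_type_traits<GGML_TYPE_Q4_0> {
    static constexpr int qk = 32; // values per quantized block
    static constexpr int qr = 2;  // elements unpacked per data byte
};
```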
GGML-CUDA Core File Dependencies
This diagram shows the current include dependencies between the core infrastructure files of the ggml-cuda backend. Arrows point from the including file to the included file. (Mermaid source: see ggml/src/ggml-cuda/docs/01-core-file-dependencies.md on the branch.)
Core Include Chain
The current foundation is split between two context layers. Most operator headers include kernel-ops-context.cuh, which exposes the lightweight stream/cuBLAS/pool view used by ordinary kernel entry points. Heavier paths such as matmul, flash attention, fusion, set.cuh, and compute-forward.cuh include backend-context.cuh when they need the full backend context.
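In header terms, the two layers look roughly like this (a sketch; softmax stands in for any ordinary op):

```cpp
// softmax.cuh (sketch) — an ordinary op: the lightweight view is enough.
#include "kernel-ops-context.cuh"
void ggml_cuda_op_soft_max(const ggml_cuda_kernel_ops_context & ctx, ggml_tensor * dst);

// set.cuh (sketch) — one of the few ops that needs the full backend context.
#include "backend-context.cuh"
void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
```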
GGML-CUDA Class & Struct Dependencies
This diagram shows the relationships between key structs and classes in the ggml-cuda backend, including ownership, usage, and inheritance. (Mermaid source: see ggml/src/ggml-cuda/docs/02-class-struct-dependencies.md on the branch.)

Key Relationships
ggml_backend_cuda_context is the central backend state (declared in backend-context.cuh). It owns CUDA streams, cuBLAS handles, memory pools, and
stream concurrency state; CUDA graph cache ownership is enabled under USE_CUDA_GRAPH.
ggml_cuda_kernel_ops_context is the lightweight dispatch view declared in kernel-ops-context.cuh. Most operator entry points take this instead of the
full backend context.
ggml_cuda_pool is an abstract interface with two concrete implementations: ggml_cuda_pool_leg (legacy, free-list based) and ggml_cuda_pool_vmm
(virtual memory managed).
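The pool interface is small; a sketch consistent with the methods described here (signatures are my reconstruction, not a verbatim copy):

```cpp
// cuda-pool.cuh (sketch): device scratch allocator behind a virtual interface,
// so callers don't care whether the legacy free-list or the VMM pool backs it.
struct ggml_cuda_pool {
    virtual ~ggml_cuda_pool() = default;

    // Returns a device pointer of at least `size` bytes; reports the size
    // actually reserved through `actual_size` so callers can reuse the slack.
    virtual void * alloc(size_t size, size_t * actual_size) = 0;
    virtual void   free(void * ptr, size_t size) = 0;
};
```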
ggml_cuda_graph / ggml_cuda_graph_cache manage CUDA graph capture
and replay for reduced launch overhead.
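Capture/replay follows the usual CUDA-graph pattern; a hypothetical flow over the cached object (the method names are my reconstruction of the pattern, not the branch's exact API):

```cpp
// Sketch: replay a cached CUDA graph instead of re-launching every kernel.
ggml_cuda_graph * g = cache.get(key);          // per-graph cache entry
if (g != nullptr && g->is_enabled()) {
    if (!g->has_instance()) {
        g->begin_capture(stream);              // record the kernel launches...
        // ... enqueue the compute graph's kernels on `stream` ...
        g->end_capture(stream);
        g->instantiate_or_update();            // build or patch the executable graph
    }
    g->launch(stream);                         // replay with a single API call
}
```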
ggml_cuda_concurrent_event is defined in concurrent-event.cuh and
encapsulates fork/join synchronization state plus stream mapping/original-order
helpers used by graph execution.
ggml_backend_cuda_comm_context is defined in multi-device-nccl.cu when
NCCL is enabled and stores the backend/NCCL communicator lists used by
all-reduce.
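A sketch of that container, assuming only the two lists named above:

```cpp
// multi-device-nccl.cu (sketch, only compiled when NCCL is enabled):
// one backend plus one NCCL communicator per participating device.
#include <vector>
#include <nccl.h>
#include "ggml-backend.h" // for ggml_backend_t

struct ggml_backend_cuda_comm_context {
    std::vector<ggml_backend_t> backends; // per-device CUDA backends
    std::vector<ncclComm_t>     comms;    // matching NCCL communicators
};
```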
GGML-CUDA Op Dispatch & Module Organization
This diagram shows how operations flow from the top-level backend entry point through graph compute, compute-forward dispatch, and into individual operator modules. (Mermaid source: see ggml/src/ggml-cuda/docs/03-op-dispatch-and-modules.md on the branch.)

Execution Flow
- ggml-cuda.cu creates the backend and registers it with ggml.
- graph.cu receives a compute graph and orchestrates execution:
  - Runs ggml_backend_cuda_graph_optimize() to apply fusion passes.
  - Optionally captures/replays CUDA graphs for reduced overhead.
  - Dispatches each node via ggml_cuda_compute_forward().
- fusion.cu merges compatible adjacent ops (e.g., matmul + bias + activation) into fused kernels, rewriting the graph in place. It may call snake.cu (ggml_cuda_op_snake_fused) for snake-style matmul fusion paths.
- compute-forward.cu is a large switch over ggml_op that routes each node to the appropriate operator implementation. View-like ops return without a kernel launch; ops not listed in the switch return false (see the sketch after this list).
- Operator modules each implement one or more ggml ops as CUDA kernels. allreduce.cu is built in the same target and implements an internal multi-GPU all-reduce pipeline (ggml_cuda_ar_*); the wiring from multi-device-nccl.cu is optional and evolves independently of the main graph dispatch path.
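The dispatch shape, reduced to its skeleton (illustrative cases only; the real switch covers every supported ggml_op):

```cpp
// compute-forward.cu (sketch): route one graph node to its operator module.
bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    switch (dst->op) {
        // View-like ops only change metadata: no kernel launch is needed.
        case GGML_OP_NONE:
        case GGML_OP_RESHAPE:
        case GGML_OP_VIEW:
        case GGML_OP_PERMUTE:
        case GGML_OP_TRANSPOSE:
            return true;
        // Ordinary ops get the lightweight per-op view, not the full context.
        case GGML_OP_SCALE:
            ggml_cuda_op_scale(ctx.kernel_ops_context(), dst);
            return true;
        default:
            return false; // op not supported by the CUDA backend
    }
}
```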