Skip to content

Commit aef3e1e

Browse files
CUDA header declarations for (LayerNorm) forward and backward (#66)
* feat(cuda): add attention forward backward kernel declarations (#64) * docs: report [run_20260530_165216] (~791 tok/s) Includes metrics for generalization gap, throughput (~791 tok/s), and gradient norms. Parameters: 6.68M | lr: 1e-3 | batch: 16 | steps: 6000 - Achieved best validation loss of 4.1319 at step 3900 * docs:report [run_20260530_165216](~791 tok/s) (#61) Includes metrics for generalization gap, throughput (~791 tok/s), and gradient norms. Parameters: 6.68M | lr: 1e-3 | batch: 16 | steps: 6000 - Achieved best validation loss of 4.1319 at step 3900 Co-authored-by: Max <eamon5174@gmail.com> * feat(cuda): add attention forward and backward kernel declarations Introduces the header declarations for `attention_forward` and `attention_backward` operations inside the `quadtrix::cuda` namespace. Configured with support for custom CUDA streams and head partitioning. --------- Co-authored-by: Max <eamon5174@gmail.com> * feat(cuda): add checkpoint metadata struct and stub functions * feat(cuda): introduce core type definitions and error handling utilities - Defines `DType` and `DeviceKind` enums supporting standard types (F32, F16, BF16, I32, U8). - Implements `dtype_name` and `dtype_size` metadata helper functions. - Adds an explicit `Status` struct for non-throwing error propagation alongside `checked_mul` for safe allocation size computation. - Introduces `check_cuda` and `abort_on_cuda` error macros and handling mechanisms, exposed via the `QUADTRIX_CUDA_CHECK` macro. * feat(cuda): add TokenBatchView struct and DataLoader stub class * feat(cuda): add GeLU activation forward and backward declarations - Introduces the `GeluMode` enum to toggle between `Exact` and `Approximate` mathematical variants. - Declares the `gelu_forward` and `gelu_backward` kernel entrypoints. - Configures both signatures with optional stream execution and a default mode of `GeluMode::Approximate`. 
* feat(cuda): add gradient norm calculation and clipping interfaces --------- Co-authored-by: Max <eamon5174@gmail.com>
1 parent f1cd13d commit aef3e1e

5 files changed

Lines changed: 231 additions & 0 deletions

File tree

CUDA/includes/checkpoint.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#pragma once
2+
3+
#include "tensor.cuh"
4+
5+
namespace quadtrix {
6+
namespace cuda {
7+
8+
// Model-shape fields stored alongside a checkpoint.
// Every field is zero in a default-constructed value.
struct CheckpointMetadata {
    int vocab_size{0};
    int max_sequence_length{0};
    int num_layers{0};
    int num_heads{0};
    int channels{0};
};
15+
16+
// Stub: both arguments are ignored and the function always returns false.
// NOTE(review): presumably the first argument is a file path and the second
// the destination to fill in - confirm once an implementation lands.
inline bool load_checkpoint_metadata(const char*, CheckpointMetadata*) {
17+
return false;
18+
}
19+
20+
// Stub: both arguments are ignored and the function always returns false.
// NOTE(review): presumably the first argument is the output path and the
// second the tensor to persist - confirm once an implementation lands.
inline bool save_tensor_checkpoint(const char*, const TensorView&) {
21+
return false;
22+
}
23+
24+
} // namespace cuda
25+
} // namespace quadtrix

CUDA/includes/common.h

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
#pragma once
2+
3+
#include <cuda_runtime.h>
4+
5+
#include <cstddef>
6+
#include <cstdint>
7+
#include <cstdio>
8+
#include <cstdlib>
9+
#include <limits>
10+
11+
namespace quadtrix {
12+
namespace cuda {
13+
14+
// Element types understood by dtype_name() and dtype_size().
// The numeric values are spelled out; they equal the implicit numbering.
enum class DType : std::uint8_t {
    F32 = 0,
    F16 = 1,
    BF16 = 2,
    I32 = 3,
    U8 = 4,
};
21+
22+
// Identifies which kind of device something belongs to.
// The numeric values are spelled out; they equal the implicit numbering.
enum class DeviceKind : std::uint8_t {
    CPU = 0,
    CUDA = 1,
};
26+
27+
struct Status {
28+
bool ok;
29+
cudaError_t cuda_error;
30+
const char* message;
31+
32+
static Status success() {
33+
return {true, cudaSuccess, "ok"};
34+
}
35+
36+
static Status failure(cudaError_t error, const char* message) {
37+
return {false, error, message};
38+
}
39+
};
40+
41+
// Returns a short lowercase label for `dtype`.
// A value matching none of the enumerators yields "unknown".
inline const char* dtype_name(DType dtype) {
    const char* label = "unknown";
    switch (dtype) {
        case DType::F32:
            label = "f32";
            break;
        case DType::F16:
            label = "f16";
            break;
        case DType::BF16:
            label = "bf16";
            break;
        case DType::I32:
            label = "i32";
            break;
        case DType::U8:
            label = "u8";
            break;
    }
    return label;
}
56+
57+
// Returns the size in bytes of one element of `dtype`.
// A value matching none of the enumerators prints a diagnostic to stderr
// and aborts the process.
inline std::size_t dtype_size(DType dtype) {
    // 0 is not a valid element size, so it doubles as "no case matched".
    std::size_t bytes = 0;
    switch (dtype) {
        case DType::F32:
        case DType::I32:
            bytes = 4;
            break;
        case DType::F16:
        case DType::BF16:
            bytes = 2;
            break;
        case DType::U8:
            bytes = 1;
            break;
    }

    if (bytes == 0) {
        std::fprintf(stderr, "Unknown CUDA dtype value %u\n", static_cast<unsigned int>(dtype));
        std::abort();
    }
    return bytes;
}
74+
75+
// Multiplies `lhs` by `rhs` with overflow detection.
// On success writes the product to `*out` and returns true.
// If the product would not fit in std::size_t, returns false and leaves
// `*out` untouched. `out` is dereferenced without a null check.
inline bool checked_mul(std::size_t lhs, std::size_t rhs, std::size_t* out) {
    const std::size_t limit = std::numeric_limits<std::size_t>::max();
    // The lhs == 0 case short-circuits first, so the division is never by zero.
    const bool overflows = (lhs != 0) && (rhs > limit / lhs);
    if (overflows) {
        return false;
    }
    *out = lhs * rhs;
    return true;
}
82+
83+
// Converts a CUDA return code into a Status without aborting.
// On failure, logs file, line, expression text and the CUDA error string to
// stderr, then returns a failure Status whose message is `expression`.
// On cudaSuccess nothing is printed and Status::success() is returned.
inline Status check_cuda(cudaError_t error, const char* expression, const char* file, int line) {
    if (error != cudaSuccess) {
        const char* description = cudaGetErrorString(error);
        std::fprintf(
            stderr,
            "CUDA error at %s:%d: %s failed with %s\n",
            file,
            line,
            expression,
            description);
        return Status::failure(error, expression);
    }
    return Status::success();
}
97+
98+
// Terminates the process when `error` is not cudaSuccess.
// Before aborting, logs file, line, expression text and the CUDA error string
// to stderr. Returns normally, printing nothing, on cudaSuccess.
inline void abort_on_cuda(cudaError_t error, const char* expression, const char* file, int line) {
    if (error != cudaSuccess) {
        const char* description = cudaGetErrorString(error);
        std::fprintf(
            stderr,
            "Fatal CUDA error at %s:%d: %s failed with %s\n",
            file,
            line,
            expression,
            description);
        std::abort();
    }
}
112+
113+
} // namespace cuda
114+
} // namespace quadtrix
115+
116+
// Evaluates `expr` once and forwards its result to check_cuda together with
// the stringified expression and the call site's __FILE__/__LINE__.
// The macro expands to an expression whose value is check_cuda's return value.
#define QUADTRIX_CUDA_CHECK(expr) \
117+
::quadtrix::cuda::check_cuda((expr), #expr, __FILE__, __LINE__)
118+
119+
// Evaluates `expr` once and forwards its result to abort_on_cuda together
// with the stringified expression and the call site's __FILE__/__LINE__.
// abort_on_cuda returns void, so the macro yields no value.
#define QUADTRIX_CUDA_ABORT(expr) \
120+
::quadtrix::cuda::abort_on_cuda((expr), #expr, __FILE__, __LINE__)

CUDA/includes/dataloader.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#pragma once
2+
3+
#include <cstddef>
4+
#include <cstdint>
5+
6+
namespace quadtrix {
7+
namespace cuda {
8+
9+
// Describes one batch through two read-only int32 pointers and two extents.
// A default-constructed view has null pointers and zero extents.
// NOTE(review): memory layout and ownership of the pointed-to buffers are not
// visible in this header - confirm against the producer.
struct TokenBatchView {
    const std::int32_t* inputs{nullptr};
    const std::int32_t* targets{nullptr};
    int batch_size{0};
    int sequence_length{0};
};
15+
16+
class DataLoader {
17+
public:
18+
DataLoader() = default;
19+
20+
bool next(TokenBatchView* batch) {
21+
if (batch != nullptr) {
22+
*batch = {};
23+
}
24+
return false;
25+
}
26+
};
27+
28+
} // namespace cuda
29+
} // namespace quadtrix

CUDA/includes/gelu.cuh

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#pragma once
2+
3+
#include "tensor.cuh"
4+
5+
#include <cuda_runtime.h>
6+
7+
#include <cstdint>
8+
9+
namespace quadtrix {
10+
namespace cuda {
11+
12+
// Selects which GeLU variant gelu_forward / gelu_backward should use.
// The numeric values are spelled out; they equal the implicit numbering.
enum class GeluMode : std::uint8_t {
    Exact = 0,
    Approximate = 1,
};
16+
17+
// Declaration only; the definition is not part of this header.
// Takes a read-only `input` view and an `output` view passed by value.
// `mode` defaults to GeluMode::Approximate and `stream` defaults to nullptr.
// NOTE(review): presumably writes GeLU(input) into `output` on `stream` and
// reports failures through the returned Status - confirm against the
// definition, including any shape/dtype requirements on the two views.
Status gelu_forward(
18+
const TensorView& input,
19+
TensorView output,
20+
GeluMode mode = GeluMode::Approximate,
21+
cudaStream_t stream = nullptr);
22+
23+
// Declaration only; the definition is not part of this header.
// Takes read-only `grad_output` and `input` views and a `grad_input` view
// passed by value. `mode` defaults to GeluMode::Approximate and `stream`
// defaults to nullptr.
// NOTE(review): presumably writes the gradient with respect to `input` into
// `grad_input`, and `mode` should match the one used in the forward pass -
// confirm against the definition.
Status gelu_backward(
24+
const TensorView& grad_output,
25+
const TensorView& input,
26+
TensorView grad_input,
27+
GeluMode mode = GeluMode::Approximate,
28+
cudaStream_t stream = nullptr);
29+
30+
} // namespace cuda
31+
} // namespace quadtrix

CUDA/includes/global_norm.cuh

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#pragma once
2+
3+
#include "tensor.cuh"
4+
5+
#include <cuda_runtime.h>
6+
7+
namespace quadtrix {
8+
namespace cuda {
9+
10+
// Declaration only; the definition is not part of this header.
// Takes a read-only `grads` view and a `partial_sums` view passed by value.
// `stream` defaults to nullptr.
// NOTE(review): presumably accumulates squared gradient values into
// `partial_sums`, leaving the final reduction and square root to the caller -
// confirm the required size of `partial_sums` against the definition.
Status global_norm_squared(
11+
const TensorView& grads,
12+
TensorView partial_sums,
13+
cudaStream_t stream = nullptr);
14+
15+
// Declaration only; the definition is not part of this header.
// Takes a `grads` view passed by value plus the already-computed
// `global_norm` and the `max_norm` limit. `stream` defaults to nullptr.
// NOTE(review): presumably scales `grads` in place by
// clip_scale(global_norm, max_norm) - confirm against the definition.
Status clip_gradients_by_global_norm(
16+
TensorView grads,
17+
float global_norm,
18+
float max_norm,
19+
cudaStream_t stream = nullptr);
20+
21+
// Returns the factor that brings `global_norm` down to `max_norm`.
// The result is max_norm / global_norm when global_norm is both greater than
// max_norm and strictly positive; otherwise it is 1.0f (no scaling).
inline float clip_scale(float global_norm, float max_norm) {
    const bool exceeds_limit = global_norm > max_norm;
    // The positivity check keeps the division away from a zero denominator.
    const bool is_positive = global_norm > 0.0f;
    if (exceeds_limit && is_positive) {
        return max_norm / global_norm;
    }
    return 1.0f;
}
24+
25+
} // namespace cuda
26+
} // namespace quadtrix

0 commit comments

Comments
 (0)