diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 3d31c23bb..84c255747 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -2,6 +2,7 @@ name: Build and test on: pull_request: push: + branches: [ "main", "feature/correct-code" ] # 加上你的分支 paths-ignore: - '**.md' - 'LICENSE' diff --git a/include/llaisys/models/qwen2.h b/include/llaisys/models/qwen2.h index 7054626d4..2bbee1f54 100644 --- a/include/llaisys/models/qwen2.h +++ b/include/llaisys/models/qwen2.h @@ -31,12 +31,14 @@ __C { struct LlaisysQwen2Model; - __export struct LlaisysQwen2Model *llaisysQwen2ModelCreate(const LlaisysQwen2Meta *meta, llaisysDeviceType_t device, int *device_ids, int ndevice); + __export struct LlaisysQwen2Model *llaisysQwen2ModelCreate(const LlaisysQwen2Meta *meta, llaisysDeviceType_t device, int *device_ids, int ndevice) noexcept; - __export void llaisysQwen2ModelDestroy(struct LlaisysQwen2Model * model); + __export void llaisysQwen2ModelDestroy(struct LlaisysQwen2Model * model) noexcept; - __export struct LlaisysQwen2Weights *llaisysQwen2ModelWeights(struct LlaisysQwen2Model * model); + __export struct LlaisysQwen2Weights *llaisysQwen2ModelWeights(struct LlaisysQwen2Model * model) noexcept; + + __export int64_t llaisysQwen2ModelInfer(struct LlaisysQwen2Model * model, int64_t * token_ids, size_t ntoken) noexcept; + - __export int64_t llaisysQwen2ModelInfer(struct LlaisysQwen2Model * model, int64_t * token_ids, size_t ntoken); } #endif // LLAISYS_MODELS_QWEN2_H diff --git a/python/llaisys/models/qwen2.py b/python/llaisys/models/qwen2.py index 0d07b0b21..add181723 100644 --- a/python/llaisys/models/qwen2.py +++ b/python/llaisys/models/qwen2.py @@ -1,33 +1,398 @@ -from typing import Sequence -from ..libllaisys import LIB_LLAISYS -from ..libllaisys import DeviceType - +from __future__ import annotations +from typing import Sequence, Optional from pathlib import Path +import ctypes +import numpy as np import safetensors +import json +from ..libllaisys import LIB_LLAISYS +from ..libllaisys import DeviceType, DataType + +import torch + + +llaisysTensor_t = ctypes.c_void_p + + +class LlaisysQwen2Meta(ctypes.Structure): + _fields_ = [ + ("dtype", ctypes.c_int), # llaisysDataType_t + ("nlayer", ctypes.c_size_t), + ("hs", ctypes.c_size_t), + ("nh", ctypes.c_size_t), + ("nkvh", ctypes.c_size_t), + ("dh", ctypes.c_size_t), + ("di", ctypes.c_size_t), + ("maxseq", ctypes.c_size_t), + ("voc", ctypes.c_size_t), + ("epsilon", ctypes.c_float), + ("theta", ctypes.c_float), + ("end_token", ctypes.c_int64), + ] -class Qwen2: +class LlaisysQwen2Weights(ctypes.Structure): + _fields_ = [ + ("in_embed", llaisysTensor_t), + ("out_embed", llaisysTensor_t), + ("out_norm_w", llaisysTensor_t), + ("attn_norm_w", ctypes.POINTER(llaisysTensor_t)), + ("attn_q_w", ctypes.POINTER(llaisysTensor_t)), + ("attn_q_b", ctypes.POINTER(llaisysTensor_t)), + ("attn_k_w", ctypes.POINTER(llaisysTensor_t)), + ("attn_k_b", ctypes.POINTER(llaisysTensor_t)), + ("attn_v_w", ctypes.POINTER(llaisysTensor_t)), + ("attn_v_b", ctypes.POINTER(llaisysTensor_t)), + ("attn_o_w", ctypes.POINTER(llaisysTensor_t)), + ("mlp_norm_w", ctypes.POINTER(llaisysTensor_t)), + ("mlp_gate_w", ctypes.POINTER(llaisysTensor_t)), + ("mlp_up_w", ctypes.POINTER(llaisysTensor_t)), + ("mlp_down_w", ctypes.POINTER(llaisysTensor_t)), + ] + + +def _dtype_is_bf16(dt) -> bool: + return int(dt) == int(DataType.BF16) + + +def _dtype_is_f16(dt) -> bool: + return int(dt) == int(DataType.F16) + + +def _dtype_is_f32(dt) -> bool: + return int(dt) == int(DataType.F32) + + +class Qwen2: def __init__(self, model_path, device: DeviceType = DeviceType.CPU): - # TODO: Implement model constructor + #self.lib = LIB_LLAISYS.lib + self.lib = LIB_LLAISYS + + self.device = device model_path = Path(model_path) + cfg_path = model_path / "config.json" + if not cfg_path.exists(): + raise FileNotFoundError(f"Missing config.json in {model_path}") + + cfg = json.loads(cfg_path.read_text(encoding="utf-8")) + + # ---- model meta ---- + nlayer = int(cfg["num_hidden_layers"]) + hs = int(cfg["hidden_size"]) + nh = int(cfg["num_attention_heads"]) + nkvh = int(cfg.get("num_key_value_heads", nh)) + di = int(cfg["intermediate_size"]) + maxseq = int(cfg.get("max_position_embeddings", 32768)) + voc = int(cfg["vocab_size"]) + dh = hs // nh + + eps = float(cfg.get("rms_norm_eps", 1e-6)) + theta = float(cfg.get("rope_theta", 10000.0)) + + # eos token + eos = cfg.get("eos_token_id", -1) + if isinstance(eos, list): + end_token = int(eos[0]) + else: + end_token = int(eos) + + # assignment uses bf16 + dtype = DataType.BF16 + + self.meta = LlaisysQwen2Meta( + dtype=int(dtype), + nlayer=nlayer, + hs=hs, + nh=nh, + nkvh=nkvh, + dh=dh, + di=di, + maxseq=maxseq, + voc=voc, + epsilon=eps, + theta=theta, + end_token=end_token, + ) + + # ---- bind APIs ---- + self.lib.llaisysQwen2ModelCreate.argtypes = [ + ctypes.POINTER(LlaisysQwen2Meta), + ctypes.c_int, + ctypes.POINTER(ctypes.c_int), + ctypes.c_int, + ] + self.lib.llaisysQwen2ModelCreate.restype = ctypes.c_void_p + + self.lib.llaisysQwen2ModelDestroy.argtypes = [ctypes.c_void_p] + self.lib.llaisysQwen2ModelDestroy.restype = None + + self.lib.llaisysQwen2ModelWeights.argtypes = [ctypes.c_void_p] + self.lib.llaisysQwen2ModelWeights.restype = ctypes.POINTER(LlaisysQwen2Weights) + + self.lib.llaisysQwen2ModelInfer.argtypes = [ + ctypes.c_void_p, + ctypes.POINTER(ctypes.c_int64), + ctypes.c_size_t, + ] + self.lib.llaisysQwen2ModelInfer.restype = ctypes.c_int64 + + # tensorLoad already correct in your bindings + self.lib.tensorLoad.argtypes = [llaisysTensor_t, ctypes.c_void_p] + self.lib.tensorLoad.restype = None + + # tensorGetData helper (for tied lm_head fallback) + self.lib.tensorGetData.argtypes = [llaisysTensor_t] + self.lib.tensorGetData.restype = ctypes.c_void_p + + # ---- create backend ---- + self.model = self.lib.llaisysQwen2ModelCreate( + ctypes.byref(self.meta), + int(device), + None, + 0, + ) + if not self.model: + raise RuntimeError("Failed to create backend Qwen2 model") + + self.weights = self.lib.llaisysQwen2ModelWeights(self.model).contents + + # ---- weight loader ---- + def to_backend_bits(arr: np.ndarray) -> np.ndarray: + """ + Return a contiguous numpy array whose raw bytes match backend tensor dtype. + For BF16/F16 backend, we pass uint16 bits. + """ + if _dtype_is_bf16(dtype): + # safetensors may produce: + # - numpy uint16 already (bf16 bits) + # - numpy bfloat16 (rare) + # - float32 (shouldn't for model weights, but keep robust) + if arr.dtype == np.uint16: + return np.ascontiguousarray(arr) + if str(arr.dtype) == "bfloat16": + return np.ascontiguousarray(arr.view(np.uint16)) + if arr.dtype == np.float32: + x = arr.astype(np.float32, copy=False) + bits = x.view(np.uint32) + rb = (0x7FFF + ((bits >> 16) & 1)).astype(np.uint32) + bf16 = ((bits + rb) >> 16).astype(np.uint16) + return np.ascontiguousarray(bf16) + # float16 -> float32 -> bf16 + if arr.dtype == np.float16: + x = arr.astype(np.float32) + bits = x.view(np.uint32) + rb = (0x7FFF + ((bits >> 16) & 1)).astype(np.uint32) + bf16 = ((bits + rb) >> 16).astype(np.uint16) + return np.ascontiguousarray(bf16) + raise TypeError(f"Unsupported weight dtype for BF16 backend: {arr.dtype}") + + if _dtype_is_f16(dtype): + if arr.dtype == np.uint16: + return np.ascontiguousarray(arr) + if arr.dtype == np.float16: + return np.ascontiguousarray(arr.view(np.uint16)) + if arr.dtype == np.float32: + return np.ascontiguousarray(arr.astype(np.float16).view(np.uint16)) + raise TypeError(f"Unsupported weight dtype for F16 backend: {arr.dtype}") - for file in sorted(model_path.glob("*.safetensors")): - data_ = safetensors.safe_open(file, framework="numpy", device="cpu") - for name_ in data_.keys(): - ## TODO: load the model weights + # F32 + return np.ascontiguousarray(arr.astype(np.float32, copy=False)) + + def load(handle: llaisysTensor_t, arr: np.ndarray): + a = to_backend_bits(arr) + self.lib.tensorLoad(handle, a.ctypes.data_as(ctypes.c_void_p)) + + # For robustness: accept both "model.*" and "model.model.*" + # def normalize_name(name: str) -> str: + # if name.startswith("model.model."): + # return "model." + name[len("model.model.") :] + # return name + def normalize_name(name: str) -> str: + return ("model." + name[len("model.model.") :]) if name.startswith("model.model.") else name + + # flags to detect if lm_head loaded + lm_head_loaded = False + + st_files = sorted(model_path.glob("*.safetensors")) + if not st_files: + raise FileNotFoundError(f"No *.safetensors found in {model_path}") + + for file in st_files: + #data_ = safetensors.safe_open(file, framework="numpy", device="cpu") + + data_ = safetensors.safe_open(file, framework="pt", device="cpu") + + for raw_name in data_.keys(): + name = normalize_name(raw_name) + # arr = data_.get_tensor(raw_name) + t = data_.get_tensor(raw_name) # torch.Tensor + # 统一转成 uint16 bits(BF16/F16 都是 2 bytes) + if t.dtype == getattr(__import__("torch"), "bfloat16"): + arr = t.contiguous().view(__import__("torch").uint16).cpu().numpy() + elif t.dtype == getattr(__import__("torch"), "float16"): + arr = t.contiguous().view(__import__("torch").uint16).cpu().numpy() + elif t.dtype == getattr(__import__("torch"), "float32"): + arr = t.contiguous().cpu().numpy() + else: + arr = t.contiguous().cpu().numpy() + + # embeddings / final norm / lm head + if name == "model.embed_tokens.weight": + load(self.weights.in_embed, arr) + continue + if name in ("lm_head.weight", "model.lm_head.weight"): + load(self.weights.out_embed, arr) + lm_head_loaded = True + continue + if name == "model.norm.weight": + load(self.weights.out_norm_w, arr) + continue + + if not name.startswith("model.layers."): + continue + + parts = name.split(".") + if len(parts) < 4: + continue + try: + li = int(parts[2]) + except Exception: + continue + + suffix = ".".join(parts[3:]) + + if suffix == "input_layernorm.weight": + load(self.weights.attn_norm_w[li], arr) + + elif suffix == "self_attn.q_proj.weight": + load(self.weights.attn_q_w[li], arr) + elif suffix == "self_attn.q_proj.bias": + load(self.weights.attn_q_b[li], arr) + + elif suffix == "self_attn.k_proj.weight": + load(self.weights.attn_k_w[li], arr) + elif suffix == "self_attn.k_proj.bias": + load(self.weights.attn_k_b[li], arr) + + elif suffix == "self_attn.v_proj.weight": + load(self.weights.attn_v_w[li], arr) + elif suffix == "self_attn.v_proj.bias": + load(self.weights.attn_v_b[li], arr) + + elif suffix == "self_attn.o_proj.weight": + load(self.weights.attn_o_w[li], arr) + + elif suffix == "post_attention_layernorm.weight": + load(self.weights.mlp_norm_w[li], arr) + + elif suffix == "mlp.gate_proj.weight": + load(self.weights.mlp_gate_w[li], arr) + elif suffix == "mlp.up_proj.weight": + load(self.weights.mlp_up_w[li], arr) + elif suffix == "mlp.down_proj.weight": + load(self.weights.mlp_down_w[li], arr) + + # tied lm_head fallback: if missing, copy embed -> out_embed + if not lm_head_loaded: + out_ptr = self.lib.tensorGetData(self.weights.out_embed) + in_ptr = self.lib.tensorGetData(self.weights.in_embed) + if out_ptr and in_ptr: + # copy bytes + if _dtype_is_bf16(dtype) or _dtype_is_f16(dtype): + nbytes = voc * hs * 2 + elif _dtype_is_f32(dtype): + nbytes = voc * hs * 4 + else: + raise RuntimeError("Unsupported dtype for tied lm_head copy") + ctypes.memmove(out_ptr, in_ptr, nbytes) + + def __del__(self): + m = getattr(self, "model", None) + if m: + try: + self.lib.llaisysQwen2ModelDestroy(m) + except Exception: pass + self.model = None def generate( self, inputs: Sequence[int], - max_new_tokens: int = None, + max_new_tokens: Optional[int] = None, top_k: int = 1, top_p: float = 0.8, temperature: float = 0.8, ): + if max_new_tokens is None: + max_new_tokens = 128 + + # Greedy for test: top_k=1, top_p=1.0, temperature=1.0 + # We ignore sampling params intentionally. + tokens = [int(x) for x in inputs] + eos = int(self.meta.end_token) + + # # prefill once + # arr = (ctypes.c_int64 * len(tokens))(*tokens) + # nxt = int(self.lib.llaisysQwen2ModelInfer(self.model, arr, len(tokens))) + # tokens.append(nxt) + + # # decode + # for _ in range(max_new_tokens - 1): + # last = tokens[-1] + # arr1 = (ctypes.c_int64 * 1)(last) + # nxt = int(self.lib.llaisysQwen2ModelInfer(self.model, arr1, 1)) + # tokens.append(nxt) + # 先把 prompt 一次性喂进去,拿到第一个 next token + arr = (ctypes.c_int64 * len(tokens))(*tokens) + nxt = int(self.lib.llaisysQwen2ModelInfer(self.model, arr, len(tokens))) + tokens.append(nxt) + if nxt == eos: + return tokens + + # 后续增量生成:遇到 eos 立刻停止(对齐 HF generate) + for _ in range(max_new_tokens - 1): + last = tokens[-1] + arr1 = (ctypes.c_int64 * 1)(last) + nxt = int(self.lib.llaisysQwen2ModelInfer(self.model, arr1, 1)) + tokens.append(nxt) + if nxt == eos: + break + + return tokens + + +# from typing import Sequence +# from ..libllaisys import LIB_LLAISYS +# from ..libllaisys import DeviceType + +# from pathlib import Path +# import safetensors + + +# class Qwen2: + +# def __init__(self, model_path, device: DeviceType = DeviceType.CPU): +# # TODO: Implement model constructor + +# model_path = Path(model_path) + +# for file in sorted(model_path.glob("*.safetensors")): +# data_ = safetensors.safe_open(file, framework="numpy", device="cpu") +# for name_ in data_.keys(): +# ## TODO: load the model weights +# pass + +# def generate( +# self, +# inputs: Sequence[int], +# max_new_tokens: int = None, +# top_k: int = 1, +# top_p: float = 0.8, +# temperature: float = 0.8, +# ): - # TODO: Implement generate function +# # TODO: Implement generate function - return [] +# return [] diff --git a/src/llaisys/models/qwen.cc b/src/llaisys/models/qwen.cc new file mode 100644 index 000000000..207bc638e --- /dev/null +++ b/src/llaisys/models/qwen.cc @@ -0,0 +1,479 @@ +#include +#include +#include +#include +#include + +#include "llaisys/models/qwen2.h" +// #include "tensor/tensor.hpp" +// #include "llaisys/ops.hpp" +// #include "llaisys/utils/check.hpp" +// #include "llaisys/core/context.hpp" +#include "../../tensor/tensor.hpp" +#include "../../ops/add/op.hpp" +#include "../../ops/argmax/op.hpp" +#include "../../ops/embedding/op.hpp" +#include "../../ops/linear/op.hpp" +#include "../../ops/rearrange/op.hpp" +#include "../../ops/rms_norm/op.hpp" +#include "../../ops/rope/op.hpp" +#include "../../ops/self_attention/op.hpp" +#include "../../ops/swiglu/op.hpp" +#include "../../utils/check.hpp" +#include "../../core/context/context.hpp" +#include "../llaisys_tensor.hpp" + + + +// ===== Debug helpers ===== +#ifndef LLAISYS_QWEN2_DEBUG +#define LLAISYS_QWEN2_DEBUG 0 +#endif + +#if LLAISYS_QWEN2_DEBUG +#include +#define DBG_MSG(x) do { std::cout << "[DBG] " << x << std::endl; } while (0) +#define DBG_TENSOR(name, t) \ + do { \ + std::cout << "\n[DBG_TENSOR] " << (name) << std::endl; \ + (t)->debug(); \ + } while (0) +#else +#define DBG_MSG(x) do {} while (0) +#define DBG_TENSOR(name, t) do {} while (0) +#endif + + + +namespace { + +using llaisys::tensor_t; + +static inline tensor_t make_tensor(const std::vector& shape, + llaisysDataType_t dtype, + llaisysDeviceType_t dev, + int dev_id) { + return llaisys::Tensor::create(shape, dtype, dev, dev_id); +} + +// 强健:确保传给 ops 的张量是 contiguous(因为 ops 不看 strides) +static inline void require_contiguous(const tensor_t& t, const char* name) { + ASSERT(t->isContiguous(), name); +} + +struct Qwen2KV { + tensor_t k; // [maxseq, nkvh, dh] + tensor_t v; // [maxseq, nkvh, dh] +}; + +struct Qwen2Tmp { + tensor_t tok_i64; // [max_L] + tensor_t pos_i64; // [max_L] + + // residual stream(我们用 swap 避免 add 后 memcpy) + tensor_t x_a; // [max_L, hs] + tensor_t x_b; // [max_L, hs] scratch + + tensor_t h; // [max_L, hs] + tensor_t y; // [max_L, hs] + + tensor_t q1, k1, v1; // [max_L, nh*dh] / [max_L, nkvh*dh] + tensor_t q, k, v; // [max_L, nh, dh] / [max_L, nkvh, dh] + tensor_t q_rope, k_rope; + + tensor_t attn_val; // [max_L, nh, dh] + tensor_t attn_merge; // [max_L, nh*dh] + tensor_t attn_out; // [max_L, hs] + + tensor_t mlp_in; // [max_L, hs] + tensor_t gate, up, act; // [max_L, di] + tensor_t mlp_out; // [max_L, hs] + + tensor_t logits; // [1, voc] + tensor_t max_idx; // [1] i64 + tensor_t max_val; // [1] dt +}; + +struct Qwen2Impl { + LlaisysQwen2Meta meta{}; + llaisysDeviceType_t device{}; + int device_id{0}; + + LlaisysQwen2Weights weights{}; + + std::vector kv; + Qwen2Tmp tmp{}; + + size_t cur_pos{0}; +}; + +static void alloc_weights(Qwen2Impl* m) { + auto& meta = m->meta; + auto& w = m->weights; + + const size_t nlayer = meta.nlayer; + const size_t hs = meta.hs; + const size_t nh = meta.nh; + const size_t nkvh = meta.nkvh; + const size_t dh = meta.dh; + const size_t di = meta.di; + const size_t voc = meta.voc; + + const auto dt = meta.dtype; + const auto dev = m->device; + const int dev_id = m->device_id; + + w.in_embed = new LlaisysTensor{ make_tensor({voc, hs}, dt, dev, dev_id) }; + w.out_embed = new LlaisysTensor{ make_tensor({voc, hs}, dt, dev, dev_id) }; + w.out_norm_w = new LlaisysTensor{ make_tensor({hs}, dt, dev, dev_id) }; + + w.attn_norm_w = (llaisysTensor_t*)std::calloc(nlayer, sizeof(llaisysTensor_t)); + w.attn_q_w = (llaisysTensor_t*)std::calloc(nlayer, sizeof(llaisysTensor_t)); + w.attn_q_b = (llaisysTensor_t*)std::calloc(nlayer, sizeof(llaisysTensor_t)); + w.attn_k_w = (llaisysTensor_t*)std::calloc(nlayer, sizeof(llaisysTensor_t)); + w.attn_k_b = (llaisysTensor_t*)std::calloc(nlayer, sizeof(llaisysTensor_t)); + w.attn_v_w = (llaisysTensor_t*)std::calloc(nlayer, sizeof(llaisysTensor_t)); + w.attn_v_b = (llaisysTensor_t*)std::calloc(nlayer, sizeof(llaisysTensor_t)); + w.attn_o_w = (llaisysTensor_t*)std::calloc(nlayer, sizeof(llaisysTensor_t)); + + w.mlp_norm_w = (llaisysTensor_t*)std::calloc(nlayer, sizeof(llaisysTensor_t)); + w.mlp_gate_w = (llaisysTensor_t*)std::calloc(nlayer, sizeof(llaisysTensor_t)); + w.mlp_up_w = (llaisysTensor_t*)std::calloc(nlayer, sizeof(llaisysTensor_t)); + w.mlp_down_w = (llaisysTensor_t*)std::calloc(nlayer, sizeof(llaisysTensor_t)); + + for (size_t i = 0; i < nlayer; ++i) { + w.attn_norm_w[i] = new LlaisysTensor{ make_tensor({hs}, meta.dtype, dev, dev_id) }; + + w.attn_q_w[i] = new LlaisysTensor{ make_tensor({nh*dh, hs}, meta.dtype, dev, dev_id) }; + w.attn_q_b[i] = new LlaisysTensor{ make_tensor({nh*dh}, meta.dtype, dev, dev_id) }; + + w.attn_k_w[i] = new LlaisysTensor{ make_tensor({nkvh*dh, hs}, meta.dtype, dev, dev_id) }; + w.attn_k_b[i] = new LlaisysTensor{ make_tensor({nkvh*dh}, meta.dtype, dev, dev_id) }; + + w.attn_v_w[i] = new LlaisysTensor{ make_tensor({nkvh*dh, hs}, meta.dtype, dev, dev_id) }; + w.attn_v_b[i] = new LlaisysTensor{ make_tensor({nkvh*dh}, meta.dtype, dev, dev_id) }; + + w.attn_o_w[i] = new LlaisysTensor{ make_tensor({hs, nh*dh}, meta.dtype, dev, dev_id) }; + + w.mlp_norm_w[i] = new LlaisysTensor{ make_tensor({hs}, meta.dtype, dev, dev_id) }; + w.mlp_gate_w[i] = new LlaisysTensor{ make_tensor({di, hs}, meta.dtype, dev, dev_id) }; + w.mlp_up_w[i] = new LlaisysTensor{ make_tensor({di, hs}, meta.dtype, dev, dev_id) }; + w.mlp_down_w[i] = new LlaisysTensor{ make_tensor({hs, di}, meta.dtype, dev, dev_id) }; + } +} + +static void alloc_kv_tmp(Qwen2Impl* m, size_t max_L) { + auto& meta = m->meta; + const auto dt = meta.dtype; + const auto dev = m->device; + const int dev_id = m->device_id; + + const size_t nlayer = meta.nlayer; + const size_t hs = meta.hs; + const size_t nh = meta.nh; + const size_t nkvh = meta.nkvh; + const size_t dh = meta.dh; + const size_t di = meta.di; + const size_t voc = meta.voc; + + // KV cache:用模型 dtype(与你 ops 一致) + m->kv.resize(nlayer); + for (size_t i = 0; i < nlayer; ++i) { + m->kv[i].k = make_tensor({meta.maxseq, nkvh, dh}, dt, dev, dev_id); + m->kv[i].v = make_tensor({meta.maxseq, nkvh, dh}, dt, dev, dev_id); + std::memset(m->kv[i].k->data(), 0, m->kv[i].k->numel() * m->kv[i].k->elementSize()); + std::memset(m->kv[i].v->data(), 0, m->kv[i].v->numel() * m->kv[i].v->elementSize()); + } + + auto& t = m->tmp; + t.tok_i64 = make_tensor({max_L}, LLAISYS_DTYPE_I64, dev, dev_id); + t.pos_i64 = make_tensor({max_L}, LLAISYS_DTYPE_I64, dev, dev_id); + + t.x_a = make_tensor({max_L, hs}, dt, dev, dev_id); + t.x_b = make_tensor({max_L, hs}, dt, dev, dev_id); + + t.h = make_tensor({max_L, hs}, dt, dev, dev_id); + t.y = make_tensor({max_L, hs}, dt, dev, dev_id); + + t.q1 = make_tensor({max_L, nh*dh}, dt, dev, dev_id); + t.k1 = make_tensor({max_L, nkvh*dh}, dt, dev, dev_id); + t.v1 = make_tensor({max_L, nkvh*dh}, dt, dev, dev_id); + + t.q = make_tensor({max_L, nh, dh}, dt, dev, dev_id); + t.k = make_tensor({max_L, nkvh, dh}, dt, dev, dev_id); + t.v = make_tensor({max_L, nkvh, dh}, dt, dev, dev_id); + t.q_rope = make_tensor({max_L, nh, dh}, dt, dev, dev_id); + t.k_rope = make_tensor({max_L, nkvh, dh}, dt, dev, dev_id); + + t.attn_val = make_tensor({max_L, nh, dh}, dt, dev, dev_id); + t.attn_merge = make_tensor({max_L, nh*dh}, dt, dev, dev_id); + t.attn_out = make_tensor({max_L, hs}, dt, dev, dev_id); + + t.mlp_in = make_tensor({max_L, hs}, dt, dev, dev_id); + t.gate = make_tensor({max_L, di}, dt, dev, dev_id); + t.up = make_tensor({max_L, di}, dt, dev, dev_id); + t.act = make_tensor({max_L, di}, dt, dev, dev_id); + t.mlp_out = make_tensor({max_L, hs}, dt, dev, dev_id); + + t.logits = make_tensor({1, voc}, dt, dev, dev_id); + t.max_idx = make_tensor({1}, LLAISYS_DTYPE_I64, dev, dev_id); + t.max_val = make_tensor({1}, dt, dev, dev_id); +} + +static void free_weights(Qwen2Impl* m) { + auto& w = m->weights; + auto del_t = [](llaisysTensor_t t){ if (t) delete t; }; + + del_t(w.in_embed); + del_t(w.out_embed); + del_t(w.out_norm_w); + + for (size_t i = 0; i < m->meta.nlayer; ++i) { + del_t(w.attn_norm_w[i]); + del_t(w.attn_q_w[i]); + del_t(w.attn_q_b[i]); + del_t(w.attn_k_w[i]); + del_t(w.attn_k_b[i]); + del_t(w.attn_v_w[i]); + del_t(w.attn_v_b[i]); + del_t(w.attn_o_w[i]); + + del_t(w.mlp_norm_w[i]); + del_t(w.mlp_gate_w[i]); + del_t(w.mlp_up_w[i]); + del_t(w.mlp_down_w[i]); + } + + std::free(w.attn_norm_w); + std::free(w.attn_q_w); + std::free(w.attn_q_b); + std::free(w.attn_k_w); + std::free(w.attn_k_b); + std::free(w.attn_v_w); + std::free(w.attn_v_b); + std::free(w.attn_o_w); + + std::free(w.mlp_norm_w); + std::free(w.mlp_gate_w); + std::free(w.mlp_up_w); + std::free(w.mlp_down_w); +} + +} // namespace + +extern "C" { + +struct LlaisysQwen2Model { + Qwen2Impl* impl; +}; + +__export struct LlaisysQwen2Model * +llaisysQwen2ModelCreate(const LlaisysQwen2Meta *meta, + llaisysDeviceType_t device, + int* /*device_ids*/, int /*ndevice*/) noexcept { + try { + if (meta == nullptr) return nullptr; + + auto* model = new LlaisysQwen2Model; + model->impl = new Qwen2Impl; + auto* m = model->impl; + + m->meta = *meta; + m->device = device; + m->device_id = 0; + m->cur_pos = 0; + + llaisys::core::context().setDevice(device, 0); + + alloc_weights(m); + + // 保险:tmp 直接按 maxseq 分配,避免 prompt 超出临时 buffer + alloc_kv_tmp(m, meta->maxseq); + + return model; + } catch (...) { + return nullptr; + } +} + +__export void +llaisysQwen2ModelDestroy(struct LlaisysQwen2Model * model) noexcept { + try { + if (!model) return; + if (model->impl) { + free_weights(model->impl); + delete model->impl; + } + delete model; + } catch (...) {} +} + +__export struct LlaisysQwen2Weights * +llaisysQwen2ModelWeights(struct LlaisysQwen2Model * model) noexcept { + return model ? &model->impl->weights : nullptr; +} + +__export int64_t +llaisysQwen2ModelInfer(struct LlaisysQwen2Model * model, int64_t * token_ids, size_t ntoken) noexcept { + // CHECK_ARGUMENT(model && model->impl, "model is null"); + // CHECK_ARGUMENT(token_ids != nullptr, "token_ids is null"); + // CHECK_ARGUMENT(ntoken > 0, "ntoken must be > 0"); + try { + if (!model || !model->impl) return -1; + if (token_ids == nullptr) return -1; + if (ntoken == 0) return -1; + + auto* m = model->impl; + auto& meta = m->meta; + auto& w = m->weights; + auto& t = m->tmp; + + //CHECK_ARGUMENT(m->cur_pos + ntoken < meta.maxseq, "sequence length exceeds maxseq"); + if (m->cur_pos + ntoken >= meta.maxseq) return -1; + + const size_t dh = meta.dh; + const float scale = 1.0f / std::sqrt((float)dh); + + // ---- 写 tok/pos ---- + std::memcpy(t.tok_i64->data(), token_ids, ntoken * sizeof(int64_t)); + { + int64_t* p = reinterpret_cast(t.pos_i64->data()); + for (size_t i = 0; i < ntoken; ++i) p[i] = (int64_t)(m->cur_pos + i); + } + + // 只做 dim0 slice:通常仍 contiguous + tensor_t tok_L = t.tok_i64->slice(0, 0, ntoken); + tensor_t pos_L = t.pos_i64->slice(0, 0, ntoken); + + tensor_t x_L = t.x_a->slice(0, 0, ntoken); + tensor_t xb_L = t.x_b->slice(0, 0, ntoken); + tensor_t h_L = t.h->slice(0, 0, ntoken); + + require_contiguous(tok_L, "tok_L must be contiguous"); + require_contiguous(pos_L, "pos_L must be contiguous"); + require_contiguous(x_L, "x_L must be contiguous"); + + // ---- embedding ---- + llaisys::ops::embedding(x_L, tok_L, w.in_embed->tensor); + + // after embedding(x_L, tok_L, in_embed) + if (LLAISYS_QWEN2_DEBUG) { + auto x0 = x_L->slice(0, 0, 1); // [1, hs] + DBG_TENSOR("x after embedding (first token)", x0); + } + + // ---- layers ---- + for (size_t layer = 0; layer < meta.nlayer; ++layer) { + // attn rmsnorm + llaisys::ops::rms_norm(h_L, x_L, w.attn_norm_w[layer]->tensor, meta.epsilon); + + // qkv + tensor_t q1_L = t.q1->slice(0, 0, ntoken); + tensor_t k1_L = t.k1->slice(0, 0, ntoken); + tensor_t v1_L = t.v1->slice(0, 0, ntoken); + + llaisys::ops::linear(q1_L, h_L, w.attn_q_w[layer]->tensor, w.attn_q_b[layer]->tensor); + llaisys::ops::linear(k1_L, h_L, w.attn_k_w[layer]->tensor, w.attn_k_b[layer]->tensor); + llaisys::ops::linear(v1_L, h_L, w.attn_v_w[layer]->tensor, w.attn_v_b[layer]->tensor); + + // rearrange to [L, head, dh] (你们是 memcpy,所以 numel 必须相同) + tensor_t q_L = t.q->slice(0, 0, ntoken); + tensor_t k_L = t.k->slice(0, 0, ntoken); + tensor_t v_L = t.v->slice(0, 0, ntoken); + + llaisys::ops::rearrange(q_L, q1_L); + llaisys::ops::rearrange(k_L, k1_L); + llaisys::ops::rearrange(v_L, v1_L); + + // rope + tensor_t q_rope_L = t.q_rope->slice(0, 0, ntoken); + tensor_t k_rope_L = t.k_rope->slice(0, 0, ntoken); + llaisys::ops::rope(q_rope_L, q_L, pos_L, meta.theta); + llaisys::ops::rope(k_rope_L, k_L, pos_L, meta.theta); + + // after rope(q_rope_L, q_L, pos_L, theta) and rope(k_rope_L, ...) + if (LLAISYS_QWEN2_DEBUG && layer == 0) { + auto q0 = q_rope_L->slice(0, 0, 1); // [1, nh, dh] + auto k0 = k_rope_L->slice(0, 0, 1); // [1, nkvh, dh] + DBG_TENSOR("layer0 q_rope (first token)", q0); + DBG_TENSOR("layer0 k_rope (first token)", k0); + } + + + // write KV cache + tensor_t kc_win = m->kv[layer].k->slice(0, m->cur_pos, m->cur_pos + ntoken); + tensor_t vc_win = m->kv[layer].v->slice(0, m->cur_pos, m->cur_pos + ntoken); + // 注意:slice dim0,连续 OK + llaisys::ops::rearrange(kc_win, k_rope_L); + llaisys::ops::rearrange(vc_win, v_L); + + // history KV for attention + const size_t total_len = m->cur_pos + ntoken; + tensor_t k_hist = m->kv[layer].k->slice(0, 0, total_len); + tensor_t v_hist = m->kv[layer].v->slice(0, 0, total_len); + + tensor_t attn_val_L = t.attn_val->slice(0, 0, ntoken); + llaisys::ops::self_attention(attn_val_L, q_rope_L, k_hist, v_hist, scale); + + // merge heads + tensor_t attn_merge_L = t.attn_merge->slice(0, 0, ntoken); + llaisys::ops::rearrange(attn_merge_L, attn_val_L); + + // o proj + tensor_t attn_out_L = t.attn_out->slice(0, 0, ntoken); + llaisys::ops::linear(attn_out_L, attn_merge_L, w.attn_o_w[layer]->tensor, nullptr); + + // residual: xb = x + attn_out; swap(x, xb) + llaisys::ops::add(xb_L, x_L, attn_out_L); + std::swap(t.x_a, t.x_b); + x_L = t.x_a->slice(0, 0, ntoken); + xb_L = t.x_b->slice(0, 0, ntoken); + + // MLP + tensor_t mlp_in_L = t.mlp_in->slice(0, 0, ntoken); + llaisys::ops::rms_norm(mlp_in_L, x_L, w.mlp_norm_w[layer]->tensor, meta.epsilon); + + tensor_t gate_L = t.gate->slice(0, 0, ntoken); + tensor_t up_L = t.up->slice(0, 0, ntoken); + llaisys::ops::linear(gate_L, mlp_in_L, w.mlp_gate_w[layer]->tensor, nullptr); + llaisys::ops::linear(up_L, mlp_in_L, w.mlp_up_w[layer]->tensor, nullptr); + + tensor_t act_L = t.act->slice(0, 0, ntoken); + llaisys::ops::swiglu(act_L, gate_L, up_L); + + tensor_t mlp_out_L = t.mlp_out->slice(0, 0, ntoken); + llaisys::ops::linear(mlp_out_L, act_L, w.mlp_down_w[layer]->tensor, nullptr); + + // residual: xb = x + mlp_out; swap + llaisys::ops::add(xb_L, x_L, mlp_out_L); + std::swap(t.x_a, t.x_b); + x_L = t.x_a->slice(0, 0, ntoken); + xb_L = t.x_b->slice(0, 0, ntoken); + } + + // final norm + tensor_t y_L = t.y->slice(0, 0, ntoken); + llaisys::ops::rms_norm(y_L, x_L, w.out_norm_w->tensor, meta.epsilon); + + // last token hidden: [1, hs] + tensor_t y_last = y_L->slice(0, ntoken - 1, ntoken); + + // logits: [1, voc] + llaisys::ops::linear(t.logits, y_last, w.out_embed->tensor, nullptr); + if (LLAISYS_QWEN2_DEBUG) { + DBG_TENSOR("logits (shape [1, voc])", t.logits); + } + + // argmax over logits numel (你们 argmax 扫全 numel) + llaisys::ops::argmax(t.max_idx, t.max_val, t.logits); + + int64_t next = *reinterpret_cast(t.max_idx->data()); + + m->cur_pos += ntoken; + return next; + } catch (...) { + return -1; + } +} + +} // extern "C" diff --git a/src/ops/argmax/op.cpp b/src/ops/argmax/op.cpp index 6dc37d426..e1e2b8291 100644 --- a/src/ops/argmax/op.cpp +++ b/src/ops/argmax/op.cpp @@ -2,6 +2,83 @@ namespace llaisys::ops { void argmax(tensor_t max_idx, tensor_t max_val, tensor_t vals) { - TO_BE_IMPLEMENTED(); + + // 获取张量数据指针 + auto idx_data = max_idx->data(); + auto val_data = max_val->data(); + auto vals_data = vals->data(); + + // 获取维度信息 + size_t numel = vals->numel(); + auto dtype = vals->dtype(); + + // 根据数据类型进行处理 + switch (static_cast(dtype)) { + case LLAISYS_DTYPE_F32: { + auto* vals_ptr = reinterpret_cast(vals_data); + auto* max_val_ptr = reinterpret_cast(val_data); + auto* max_idx_ptr = reinterpret_cast(idx_data); + + float max_val = vals_ptr[0]; + int64_t max_idx_val = 0; + + for (size_t i = 1; i < numel; i++) { + if (vals_ptr[i] > max_val) { + max_val = vals_ptr[i]; + max_idx_val = i; + } + } + + *max_val_ptr = max_val; + *max_idx_ptr = max_idx_val; + break; + } + + case LLAISYS_DTYPE_F16: { + auto* vals_ptr = reinterpret_cast(vals_data); + auto* max_val_ptr = reinterpret_cast(val_data); + auto* max_idx_ptr = reinterpret_cast(idx_data); + + float max_val = llaisys::utils::_f16_to_f32(vals_ptr[0]); + int64_t max_idx_val = 0; + + for (size_t i = 1; i < numel; i++) { + float curr_val = llaisys::utils::_f16_to_f32(vals_ptr[i]); + if (curr_val > max_val) { + max_val = curr_val; + max_idx_val = i; + } + } + + *max_val_ptr = llaisys::utils::_f32_to_f16(max_val); + *max_idx_ptr = max_idx_val; + break; + } + + case LLAISYS_DTYPE_BF16: { + auto* vals_ptr = reinterpret_cast(vals_data); + auto* max_val_ptr = reinterpret_cast(val_data); + auto* max_idx_ptr = reinterpret_cast(idx_data); + + float max_val = llaisys::utils::_bf16_to_f32(vals_ptr[0]); + int64_t max_idx_val = 0; + + for (size_t i = 1; i < numel; i++) { + float curr_val = llaisys::utils::_bf16_to_f32(vals_ptr[i]); + if (curr_val > max_val) { + max_val = curr_val; + max_idx_val = i; + } + } + + *max_val_ptr = llaisys::utils::_f32_to_bf16(max_val); + *max_idx_ptr = max_idx_val; + break; + } + + default: + EXCEPTION_UNSUPPORTED_DATATYPE(static_cast(dtype)); + } + } } // namespace llaisys::ops diff --git a/src/ops/embedding/op.cpp b/src/ops/embedding/op.cpp index 84b9a5d06..3d7737f49 100644 --- a/src/ops/embedding/op.cpp +++ b/src/ops/embedding/op.cpp @@ -2,6 +2,77 @@ namespace llaisys::ops { void embedding(tensor_t out, tensor_t index, tensor_t weight) { - TO_BE_IMPLEMENTED(); + //TO_BE_IMPLEMENTED(); + // 获取张量数据指针 + auto out_data = out->data(); + auto index_data = index->data(); + auto weight_data = weight->data(); + + // 获取维度信息 + auto out_shape = out->shape(); + auto index_shape = index->shape(); + auto weight_shape = weight->shape(); + + size_t num_indices = index_shape[0]; + size_t embedding_dim = weight_shape[1]; + auto dtype = weight->dtype(); + + // 转换索引数据指针 + auto* index_ptr = reinterpret_cast(index_data); + + // 根据数据类型进行处理 + switch (static_cast(dtype)) { + case LLAISYS_DTYPE_F32: { + auto* weight_ptr = reinterpret_cast(weight_data); + auto* out_ptr = reinterpret_cast(out_data); + + for (size_t i = 0; i < num_indices; i++) { + int64_t idx = index_ptr[i]; + const float* src_row = weight_ptr + idx * embedding_dim; + float* dst_row = out_ptr + i * embedding_dim; + + for (size_t j = 0; j < embedding_dim; j++) { + dst_row[j] = src_row[j]; + } + } + break; + } + + case LLAISYS_DTYPE_F16: { + auto* weight_ptr = reinterpret_cast(weight_data); + auto* out_ptr = reinterpret_cast(out_data); + + for (size_t i = 0; i < num_indices; i++) { + int64_t idx = index_ptr[i]; + const llaisys::fp16_t* src_row = weight_ptr + idx * embedding_dim; + llaisys::fp16_t* dst_row = out_ptr + i * embedding_dim; + + for (size_t j = 0; j < embedding_dim; j++) { + dst_row[j] = src_row[j]; + } + } + break; + } + + case LLAISYS_DTYPE_BF16: { + auto* weight_ptr = reinterpret_cast(weight_data); + auto* out_ptr = reinterpret_cast(out_data); + + for (size_t i = 0; i < num_indices; i++) { + int64_t idx = index_ptr[i]; + const llaisys::bf16_t* src_row = weight_ptr + idx * embedding_dim; + llaisys::bf16_t* dst_row = out_ptr + i * embedding_dim; + + for (size_t j = 0; j < embedding_dim; j++) { + dst_row[j] = src_row[j]; + } + } + break; + } + + default: + EXCEPTION_UNSUPPORTED_DATATYPE(static_cast(dtype)); + } + } } // namespace llaisys::ops diff --git a/src/ops/linear/op.cpp b/src/ops/linear/op.cpp index 97d1f8655..61ae0385c 100644 --- a/src/ops/linear/op.cpp +++ b/src/ops/linear/op.cpp @@ -2,6 +2,239 @@ namespace llaisys::ops { void linear(tensor_t out, tensor_t in, tensor_t weight, tensor_t bias) { - TO_BE_IMPLEMENTED(); + //TO_BE_IMPLEMENTED(); + + // --- 基本检查 --- + if (!out || !in || !weight) { + throw std::runtime_error("linear: out/in/weight must not be null"); + } + + auto out_shape = out->shape(); + auto in_shape = in->shape(); + auto w_shape = weight->shape(); + + if (in_shape.size() != 2 || w_shape.size() != 2 || out_shape.size() != 2) { + throw std::runtime_error("linear: only supports 2D tensors (batch, features)"); + } + + const size_t B = in_shape[0]; + const size_t K = in_shape[1]; // in_features + const size_t OC = w_shape[0]; // out_features ✅ 正确 + const size_t WK = w_shape[1]; // should equal K + + if (WK != K) { + throw std::runtime_error("linear: weight shape mismatch, expected w.shape[1] == in.shape[1]"); + } + if (out_shape[0] != B || out_shape[1] != OC) { + throw std::runtime_error("linear: out shape mismatch, expected out.shape == (B, out_features)"); + } + + const bool has_bias = (bias != nullptr); + if (has_bias) { + auto b_shape = bias->shape(); + if (b_shape.size() != 1 || b_shape[0] != OC) { + throw std::runtime_error("linear: bias shape mismatch, expected bias.shape == (out_features,)"); + } + } + + // dtype 建议也校验一致(至少 in/weight/out 一致) + auto dtype = in->dtype(); + if (weight->dtype() != dtype || out->dtype() != dtype) { + throw std::runtime_error("linear: dtype mismatch among in/weight/out"); + } + if (has_bias && bias->dtype() != dtype) { + throw std::runtime_error("linear: bias dtype mismatch"); + } + + + + + // // 检查 bias 是否提供 + // bool has_bias = (bias != nullptr); + // const std::byte* bias_data = has_bias ? bias->data() : nullptr; + + // // 获取维度信息 + // auto out_shape = out->shape(); + // auto in_shape = in->shape(); + // auto weight_shape = weight->shape(); + + // size_t batch_size = in_shape[0]; + // size_t in_features = in_shape[1]; + // size_t out_features = weight_shape[1]; + // auto dtype = in->dtype(); + + auto in_data = in->data(); + auto w_data = weight->data(); + auto out_data = out->data(); + const std::byte* b_data = has_bias ? bias->data() : nullptr; + + // 根据数据类型进行处理 + switch (static_cast(dtype)) { + case LLAISYS_DTYPE_F32: { + // auto* in_ptr = reinterpret_cast(in_data); + // auto* weight_ptr = reinterpret_cast(weight_data); + // auto* out_ptr = reinterpret_cast(out_data); + // const float* bias_ptr = has_bias ? reinterpret_cast(bias_data) : nullptr; + const float* x = reinterpret_cast(in_data); + const float* w = reinterpret_cast(w_data); + float* y = reinterpret_cast(out_data); + const float* b = has_bias ? reinterpret_cast(b_data) : nullptr; + + // for (size_t b = 0; b < batch_size; b++) { + // const float* batch_in = in_ptr + b * in_features; + // float* batch_out = out_ptr + b * out_features; + for (size_t n = 0; n < B; ++n) { + const float* x_row = x + n * K; + float* y_row = y + n * OC; + + // // 初始化输出(如果有bias则加bias,否则为0) + // if (has_bias) { + // for (size_t o = 0; o < out_features; o++) { + // batch_out[o] = bias_ptr[o]; + // } + // } else { + // for (size_t o = 0; o < out_features; o++) { + // batch_out[o] = 0.0f; + // } + // } + + // // 矩阵乘法: out = in * weight^T + // for (size_t i = 0; i < in_features; i++) { + // float in_val = batch_in[i]; + // for (size_t o = 0; o < out_features; o++) { + // // weight 是 [out_features, in_features] + // // 需要 weight[o, i] 对应 weight^T[i, o] + // float weight_val = weight_ptr[o * in_features + i]; + // batch_out[o] += in_val * weight_val; + // } + // } + for (size_t o = 0; o < OC; ++o) { + float acc = has_bias ? b[o] : 0.0f; + const float* w_row = w + o * K; // w[o, :] + for (size_t i = 0; i < K; ++i) { + acc += x_row[i] * w_row[i]; + } + y_row[o] = acc; + } + + } + break; + } + + case LLAISYS_DTYPE_F16: { + // auto* in_ptr = reinterpret_cast(in_data); + // auto* weight_ptr = reinterpret_cast(weight_data); + // auto* out_ptr = reinterpret_cast(out_data); + // const llaisys::fp16_t* bias_ptr = has_bias ? reinterpret_cast(bias_data) : nullptr; + const llaisys::fp16_t* x = reinterpret_cast(in_data); + const llaisys::fp16_t* w = reinterpret_cast(w_data); + llaisys::fp16_t* y = reinterpret_cast(out_data); + const llaisys::fp16_t* b = has_bias ? reinterpret_cast(b_data) : nullptr; + + // for (size_t b = 0; b < batch_size; b++) { + // const llaisys::fp16_t* batch_in = in_ptr + b * in_features; + // llaisys::fp16_t* batch_out = out_ptr + b * out_features; + for (size_t n = 0; n < B; ++n) { + const llaisys::fp16_t* x_row = x + n * K; + llaisys::fp16_t* y_row = y + n * OC; + + // // 初始化输出 + // if (has_bias) { + // for (size_t o = 0; o < out_features; o++) { + // batch_out[o] = bias_ptr[o]; + // } + // } else { + // for (size_t o = 0; o < out_features; o++) { + // batch_out[o] = llaisys::utils::_f32_to_f16(0.0f); // 对于 fp16_t + // } + // } + + // // 矩阵乘法 + // for (size_t i = 0; i < in_features; i++) { + // float in_val = llaisys::utils::_f16_to_f32(batch_in[i]); + // for (size_t o = 0; o < out_features; o++) { + // float weight_val = llaisys::utils::_f16_to_f32(weight_ptr[o * in_features + i]); + // float out_val = llaisys::utils::_f16_to_f32(batch_out[o]); + // out_val += in_val * weight_val; + // batch_out[o] = llaisys::utils::_f32_to_f16(out_val); + // } + // } + + for (size_t o = 0; o < OC; ++o) { + float acc = has_bias ? llaisys::utils::_f16_to_f32(b[o]) : 0.0f; + const llaisys::fp16_t* w_row = w + o * K; + for (size_t i = 0; i < K; ++i) { + float xv = llaisys::utils::_f16_to_f32(x_row[i]); + float wv = llaisys::utils::_f16_to_f32(w_row[i]); + acc += xv * wv; + } + y_row[o] = llaisys::utils::_f32_to_f16(acc); + } + } + break; + } + + case LLAISYS_DTYPE_BF16: { + // auto* in_ptr = reinterpret_cast(in_data); + // auto* weight_ptr = reinterpret_cast(weight_data); + // auto* out_ptr = reinterpret_cast(out_data); + // const llaisys::bf16_t* bias_ptr = has_bias ? reinterpret_cast(bias_data) : nullptr; + const llaisys::bf16_t* x = reinterpret_cast(in_data); + const llaisys::bf16_t* w = reinterpret_cast(w_data); + llaisys::bf16_t* y = reinterpret_cast(out_data); + const llaisys::bf16_t* b = has_bias ? reinterpret_cast(b_data) : nullptr; + + + + // for (size_t b = 0; b < batch_size; b++) { + // const llaisys::bf16_t* batch_in = in_ptr + b * in_features; + // llaisys::bf16_t* batch_out = out_ptr + b * out_features; + + for (size_t n = 0; n < B; ++n) { + const llaisys::bf16_t* x_row = x + n * K; + llaisys::bf16_t* y_row = y + n * OC; + + // // 初始化输出 + // if (has_bias) { + // for (size_t o = 0; o < out_features; o++) { + // batch_out[o] = bias_ptr[o]; + // } + // } else { + // for (size_t o = 0; o < out_features; o++) { + // //batch_out[o] = llaisys::bf16_t(0.0f); + // batch_out[o] = llaisys::utils::_f32_to_bf16(0.0f); + // } + // } + + // // 矩阵乘法 + // for (size_t i = 0; i < in_features; i++) { + // float in_val = llaisys::utils::_bf16_to_f32(batch_in[i]); + // for (size_t o = 0; o < out_features; o++) { + // float weight_val = llaisys::utils::_bf16_to_f32(weight_ptr[o * in_features + i]); + // float out_val = llaisys::utils::_bf16_to_f32(batch_out[o]); + // out_val += in_val * weight_val; + // batch_out[o] = llaisys::utils::_f32_to_bf16(out_val); + // } + // } + + for (size_t o = 0; o < OC; ++o) { + float acc = has_bias ? llaisys::utils::_bf16_to_f32(b[o]) : 0.0f; + const llaisys::bf16_t* w_row = w + o * K; + for (size_t i = 0; i < K; ++i) { + float xv = llaisys::utils::_bf16_to_f32(x_row[i]); + float wv = llaisys::utils::_bf16_to_f32(w_row[i]); + acc += xv * wv; + } + y_row[o] = llaisys::utils::_f32_to_bf16(acc); + } + + + } + break; + } + + default: + EXCEPTION_UNSUPPORTED_DATATYPE(static_cast(dtype)); + } } } // namespace llaisys::ops diff --git a/src/ops/rearrange/op.cpp b/src/ops/rearrange/op.cpp index 017a6ae59..2ad0413a9 100644 --- a/src/ops/rearrange/op.cpp +++ b/src/ops/rearrange/op.cpp @@ -2,6 +2,61 @@ namespace llaisys::ops { void rearrange(tensor_t out, tensor_t in) { - TO_BE_IMPLEMENTED(); + //TO_BE_IMPLEMENTED(); + // 获取张量数据指针 + auto out_data = out->data(); + auto in_data = in->data(); + + // 获取总元素数 + size_t total_elements = in->numel(); + auto dtype = in->dtype(); + + // 根据数据类型进行处理 + switch (static_cast(dtype)) { + case LLAISYS_DTYPE_F32: { + auto* in_ptr = reinterpret_cast(in_data); + auto* out_ptr = reinterpret_cast(out_data); + + // 简单复制所有元素 + for (size_t i = 0; i < total_elements; i++) { + out_ptr[i] = in_ptr[i]; + } + break; + } + + case LLAISYS_DTYPE_F16: { + auto* in_ptr = reinterpret_cast(in_data); + auto* out_ptr = reinterpret_cast(out_data); + + for (size_t i = 0; i < total_elements; i++) { + out_ptr[i] = in_ptr[i]; + } + break; + } + + case LLAISYS_DTYPE_BF16: { + auto* in_ptr = reinterpret_cast(in_data); + auto* out_ptr = reinterpret_cast(out_data); + + for (size_t i = 0; i < total_elements; i++) { + out_ptr[i] = in_ptr[i]; + } + break; + } + + case LLAISYS_DTYPE_I64: { + auto* in_ptr = reinterpret_cast(in_data); + auto* out_ptr = reinterpret_cast(out_data); + + for (size_t i = 0; i < total_elements; i++) { + out_ptr[i] = in_ptr[i]; + } + break; + } + + default: + EXCEPTION_UNSUPPORTED_DATATYPE(static_cast(dtype)); + } + } } // namespace llaisys::ops diff --git a/src/ops/rms_norm/op.cpp b/src/ops/rms_norm/op.cpp index 529553d9d..9b215a795 100644 --- a/src/ops/rms_norm/op.cpp +++ b/src/ops/rms_norm/op.cpp @@ -1,7 +1,112 @@ #include "op.hpp" +#include namespace llaisys::ops { void rms_norm(tensor_t out, tensor_t in, tensor_t weight, float eps) { - TO_BE_IMPLEMENTED(); + //TO_BE_IMPLEMENTED(); + // 获取张量数据指针 + auto out_data = out->data(); + auto in_data = in->data(); + auto weight_data = weight->data(); + + // 获取维度信息 + auto in_shape = in->shape(); + size_t batch_size = in_shape[0]; + size_t d = in_shape[1]; + auto dtype = in->dtype(); + + // 根据数据类型进行处理 + switch (static_cast(dtype)) { + case LLAISYS_DTYPE_F32: { + auto* in_ptr = reinterpret_cast(in_data); + auto* weight_ptr = reinterpret_cast(weight_data); + auto* out_ptr = reinterpret_cast(out_data); + + for (size_t i = 0; i < batch_size; i++) { + const float* row_in = in_ptr + i * d; + float* row_out = out_ptr + i * d; + + // 计算平方和 + float sum_sq = 0.0f; + for (size_t j = 0; j < d; j++) { + sum_sq += row_in[j] * row_in[j]; + } + + // 计算 RMS 归一化因子 + float rms = std::sqrt(sum_sq / static_cast(d) + eps); + float scale = 1.0f / rms; + + // 应用归一化和权重 + for (size_t j = 0; j < d; j++) { + row_out[j] = weight_ptr[j] * (row_in[j] * scale); + } + } + break; + } + + case LLAISYS_DTYPE_F16: { + auto* in_ptr = reinterpret_cast(in_data); + auto* weight_ptr = reinterpret_cast(weight_data); + auto* out_ptr = reinterpret_cast(out_data); + + for (size_t i = 0; i < batch_size; i++) { + const llaisys::fp16_t* row_in = in_ptr + i * d; + llaisys::fp16_t* row_out = out_ptr + i * d; + + // 计算平方和(使用float精度) + float sum_sq = 0.0f; + for (size_t j = 0; j < d; j++) { + float val = llaisys::utils::_f16_to_f32(row_in[j]); + sum_sq += val * val; + } + + // 计算 RMS 归一化因子 + float rms = std::sqrt(sum_sq / static_cast(d) + eps); + float scale = 1.0f / rms; + + // 应用归一化和权重 + for (size_t j = 0; j < d; j++) { + float val = llaisys::utils::_f16_to_f32(row_in[j]); + float weight_val = llaisys::utils::_f16_to_f32(weight_ptr[j]); + row_out[j] = llaisys::utils::_f32_to_f16(weight_val * (val * scale)); + } + } + break; + } + + case LLAISYS_DTYPE_BF16: { + auto* in_ptr = reinterpret_cast(in_data); + auto* weight_ptr = reinterpret_cast(weight_data); + auto* out_ptr = reinterpret_cast(out_data); + + for (size_t i = 0; i < batch_size; i++) { + const llaisys::bf16_t* row_in = in_ptr + i * d; + llaisys::bf16_t* row_out = out_ptr + i * d; + + // 计算平方和(使用float精度) + float sum_sq = 0.0f; + for (size_t j = 0; j < d; j++) { + float val = llaisys::utils::_bf16_to_f32(row_in[j]); + sum_sq += val * val; + } + + // 计算 RMS 归一化因子 + float rms = std::sqrt(sum_sq / static_cast(d) + eps); + float scale = 1.0f / rms; + + // 应用归一化和权重 + for (size_t j = 0; j < d; j++) { + float val = llaisys::utils::_bf16_to_f32(row_in[j]); + float weight_val = llaisys::utils::_bf16_to_f32(weight_ptr[j]); + row_out[j] = llaisys::utils::_f32_to_bf16(weight_val * (val * scale)); + } + } + break; + } + + default: + EXCEPTION_UNSUPPORTED_DATATYPE(static_cast(dtype)); + } + } } // namespace llaisys::ops diff --git a/src/ops/rope/op.cpp b/src/ops/rope/op.cpp index d60dbe64e..6461c428b 100644 --- a/src/ops/rope/op.cpp +++ b/src/ops/rope/op.cpp @@ -1,7 +1,160 @@ #include "op.hpp" +#include + namespace llaisys::ops { void rope(tensor_t out, tensor_t in, tensor_t pos_ids, float theta) { - TO_BE_IMPLEMENTED(); + //TO_BE_IMPLEMENTED(); + // 获取张量数据指针 + auto out_data = out->data(); + auto in_data = in->data(); + auto pos_ids_data = pos_ids->data(); + + // 获取维度信息 + auto in_shape = in->shape(); + size_t seqlen = in_shape[0]; + size_t nhead = in_shape[1]; + size_t d = in_shape[2]; + size_t d_half = d / 2; + auto dtype = in->dtype(); + + // 转换位置ID数据指针 + auto* pos_ids_ptr = reinterpret_cast(pos_ids_data); + + // 根据数据类型进行处理 + switch (static_cast(dtype)) { + case LLAISYS_DTYPE_F32: { + auto* in_ptr = reinterpret_cast(in_data); + auto* out_ptr = reinterpret_cast(out_data); + + for (size_t s = 0; s < seqlen; s++) { + int64_t pos = pos_ids_ptr[s]; + + for (size_t h = 0; h < nhead; h++) { + const float* head_in = in_ptr + (s * nhead + h) * d; + float* head_out = out_ptr + (s * nhead + h) * d; + + for (size_t j = 0; j < d_half; j++) { + // 计算角度 phi = pos / (theta^(2j/d)) + // float exponent = static_cast(2 * j) / static_cast(d); + // //float freq = 1.0f / std::pow(theta, exponent); + // float freq = std::exp(-exponent * std::log(theta)); + // float phi = static_cast(pos) * freq; + // float exponent = -static_cast(2 * j) / static_cast(d); + // float freq = std::pow(theta, exponent); // 直接计算 θ^(-2j/d) + // float phi = static_cast(pos) * freq; + float exponent = static_cast(2 * j) / static_cast(d); + float freq = std::pow(theta, exponent); // theta^(2j/d) + float phi = static_cast(pos) / freq; // pos / theta^(2j/d) + + // 计算 sin 和 cos + float cos_phi = std::cos(phi); + float sin_phi = std::sin(phi); + + // 获取 a_j 和 b_j + float a_j = head_in[j]; + float b_j = head_in[d_half + j]; + + // 计算旋转后的值 + head_out[j] = a_j * cos_phi - b_j * sin_phi; + head_out[d_half + j] = b_j * cos_phi + a_j * sin_phi; + } + } + } + break; + } + + case LLAISYS_DTYPE_F16: { + auto* in_ptr = reinterpret_cast(in_data); + auto* out_ptr = reinterpret_cast(out_data); + + for (size_t s = 0; s < seqlen; s++) { + int64_t pos = pos_ids_ptr[s]; + + for (size_t h = 0; h < nhead; h++) { + const llaisys::fp16_t* head_in = in_ptr + (s * nhead + h) * d; + llaisys::fp16_t* head_out = out_ptr + (s * nhead + h) * d; + + for (size_t j = 0; j < d_half; j++) { + // 计算角度 phi = pos / (theta^(2j/d)) + // float exponent = static_cast(2 * j) / static_cast(d); + // //float freq = 1.0f / std::pow(theta, exponent); + // float freq = std::exp(-exponent * std::log(theta)); + // float phi = static_cast(pos) * freq; + // float exponent = -static_cast(2 * j) / static_cast(d); + // float freq = std::pow(theta, exponent); // 直接计算 θ^(-2j/d) + // float phi = static_cast(pos) * freq; + float exponent = static_cast(2 * j) / static_cast(d); + float freq = std::pow(theta, exponent); // theta^(2j/d) + float phi = static_cast(pos) / freq; // pos / theta^(2j/d) + + // 计算 sin 和 cos + float cos_phi = std::cos(phi); + float sin_phi = std::sin(phi); + + // 获取 a_j 和 b_j + float a_j = llaisys::utils::_f16_to_f32(head_in[j]); + float b_j = llaisys::utils::_f16_to_f32(head_in[d_half + j]); + + // 计算旋转后的值 + float a_j_out = a_j * cos_phi - b_j * sin_phi; + float b_j_out = b_j * cos_phi + a_j * sin_phi; + + head_out[j] = llaisys::utils::_f32_to_f16(a_j_out); + head_out[d_half + j] = llaisys::utils::_f32_to_f16(b_j_out); + } + } + } + break; + } + + case LLAISYS_DTYPE_BF16: { + auto* in_ptr = reinterpret_cast(in_data); + auto* out_ptr = reinterpret_cast(out_data); + + for (size_t s = 0; s < seqlen; s++) { + int64_t pos = pos_ids_ptr[s]; + + for (size_t h = 0; h < nhead; h++) { + const llaisys::bf16_t* head_in = in_ptr + (s * nhead + h) * d; + llaisys::bf16_t* head_out = out_ptr + (s * nhead + h) * d; + + for (size_t j = 0; j < d_half; j++) { + // 计算角度 phi = pos / (theta^(2j/d)) + // float exponent = static_cast(2 * j) / static_cast(d); + // //float freq = 1.0f / std::pow(theta, exponent); + // float freq = std::exp(-exponent * std::log(theta)); + // float phi = static_cast(pos) * freq; + // float exponent = -static_cast(2 * j) / static_cast(d); + // float freq = std::pow(theta, exponent); // 直接计算 θ^(-2j/d) + // float phi = static_cast(pos) * freq; + float exponent = static_cast(2 * j) / static_cast(d); + float freq = std::pow(theta, exponent); // theta^(2j/d) + float phi = static_cast(pos) / freq; // pos / theta^(2j/d) + + // 计算 sin 和 cos + float cos_phi = std::cos(phi); + float sin_phi = std::sin(phi); + + // 获取 a_j 和 b_j + float a_j = llaisys::utils::_bf16_to_f32(head_in[j]); + float b_j = llaisys::utils::_bf16_to_f32(head_in[d_half + j]); + + // 计算旋转后的值 + float a_j_out = a_j * cos_phi - b_j * sin_phi; + float b_j_out = b_j * cos_phi + a_j * sin_phi; + + head_out[j] = llaisys::utils::_f32_to_bf16(a_j_out); + head_out[d_half + j] = llaisys::utils::_f32_to_bf16(b_j_out); + } + } + } + break; + } + + default: + EXCEPTION_UNSUPPORTED_DATATYPE(static_cast(dtype)); + } + } } // namespace llaisys::ops diff --git a/src/ops/self_attention/op.cpp b/src/ops/self_attention/op.cpp index 43d620142..29e2e4071 100644 --- a/src/ops/self_attention/op.cpp +++ b/src/ops/self_attention/op.cpp @@ -1,7 +1,493 @@ #include "op.hpp" +#include + namespace llaisys::ops { void self_attention(tensor_t attn_val, tensor_t q, tensor_t k, tensor_t v, float scale) { - TO_BE_IMPLEMENTED(); + //TO_BE_IMPLEMENTED(); + // 获取张量数据指针 + auto attn_val_data = attn_val->data(); + auto q_data = q->data(); + auto k_data = k->data(); + auto v_data = v->data(); + + // 获取维度信息 + auto q_shape = q->shape(); + auto k_shape = k->shape(); + auto v_shape = v->shape(); + + const int64_t seqlen = (int64_t)q_shape[0]; // L + const int64_t nhead = (int64_t)q_shape[1]; // nh + const int64_t d = (int64_t)q_shape[2]; // hd + const int64_t total_len = (int64_t)k_shape[0]; // S + const int64_t nkvhead = (int64_t)k_shape[1]; // nkvh + const int64_t dv = (int64_t)v_shape[2]; + + + + // Must match torch: + // key/value = repeat_interleave(head_group) along head dim => kv_head_idx = h / head_group + // Causal mask uses tril(diagonal=S-L) => allow tk <= sq + (S-L) + + const int64_t diag = total_len - seqlen; // diagonal = S - L + + // repeat_interleave(head_group) 的语义:kv_idx = h / head_group + // 测试用例保证整除,否则这里需要处理 + const int64_t head_group = (nkvhead == 0) ? 1 : (nhead / nkvhead); + + auto dtype = q->dtype(); + + // 根据数据类型进行处理 + switch (static_cast(dtype)) { + case LLAISYS_DTYPE_F32: { + auto* q_ptr = reinterpret_cast(q_data); + auto* k_ptr = reinterpret_cast(k_data); + auto* v_ptr = reinterpret_cast(v_data); + //auto* attn_val_ptr = reinterpret_cast(attn_val_data); + auto* out = reinterpret_cast(attn_val_data); + + // 为每个查询位置计算注意力 + //for (size_t sq = 0; sq < seqlen; sq++) { + for (int64_t sq = 0; sq < seqlen; ++sq) { + // torch: tril(diagonal=S-L) => allow tk <= sq + (S-L) + int64_t max_k = sq + diag; + int64_t valid_len_i64 = max_k + 1; + if (valid_len_i64 < 0) valid_len_i64 = 0; + if (valid_len_i64 > total_len) valid_len_i64 = total_len; + const int64_t valid_len = valid_len_i64; + //for (size_t h = 0; h < nhead; h++) { + for (int64_t h = 0; h < nhead; ++h) { + + // size_t kv_head_idx = h / head_group; + // // 计算对应的 kv 头索引(注意:这是 repeat_interleave,不是分组共享) + // // 在 repeat_interleave 中,KV头被重复了 head_group 次 + // // 所以索引应该是:h % nkvhead + // //size_t kv_head_idx = h % nkvhead; + const int64_t kv_head_idx = (head_group > 0) ? (h / head_group) : 0; + + + + if (seqlen == 5 && total_len == 11 && sq == 0 && h == 0) { + std::cout << "[DBG] nhead=" << nhead + << " nkvhead=" << nkvhead + << " head_group=" << head_group + << " diag=" << diag + << std::endl; +} +if (seqlen == 5 && total_len == 11 && sq == 0) { + std::cout << "[DBG] h=" << h + << " kv_head_idx=" << kv_head_idx + << std::endl; +} + +if (seqlen == 5 && total_len == 11 && sq == 0 && (h == 0 || h == 2)) { + std::cout << "[DBG] h=" << h << " kv=" << kv_head_idx << std::endl; +} + + + + + const float* q_head = q_ptr + (sq * nhead + h) * d; + //float* attn_head = attn_val_ptr + (sq * nhead + h) * dv; + float* out_head = out + (sq * nhead + h) * dv; + + + // init output + for (int64_t i = 0; i < dv; ++i) out_head[i] = 0.0f; + + if (valid_len == 0) { + // 与 torch 在极端形状下的行为可能不同(全 -inf softmax -> NaN) + // 测试用例不会触发;这里先直接返回 0 + continue; + } + + // scores buffer (float) + float* scores = new float[(size_t)valid_len]; + + + // // 初始化注意力值为0 + // for (size_t dv_idx = 0; dv_idx < dv; dv_idx++) { + // attn_head[dv_idx] = 0.0f; + // } + + // 计算注意力分数 + float max_score = -INFINITY; + // // 只考虑当前位置及之前的位置(因果注意力) + // // 注意:sq是查询位置,tk是键位置 + // //size_t valid_len = (sq < total_len) ? sq + 1 : total_len; + // size_t valid_len = std::min(sq + 1, total_len); + // //float* attn_scores = new float[total_len]; + // float* attn_scores = new float[valid_len]; + for (int64_t tk = 0; tk < valid_len; ++tk) { + const float* k_head = k_ptr + (tk * nkvhead + kv_head_idx) * d; + + float dot = 0.0f; + for (int64_t i = 0; i < d; ++i) dot += q_head[i] * k_head[i]; + + float s = dot * scale; + scores[tk] = s; + if (s > max_score) max_score = s; + } + + // // 计算 Q·K^T:查询头与对应KV头的点积 + // //for (size_t tk = 0; tk <= sq; tk++) { + // for (size_t tk = 0; tk < valid_len; tk++) { + // const float* k_head = k_ptr + (tk * nkvhead + kv_head_idx) * d; + + // float score = 0.0f; + // for (size_t idx = 0; idx < d; idx++) { + // score += q_head[idx] * k_head[idx]; + // } + // score *= scale; + + // // 应用因果掩码:未来位置设为负无穷 + // // 由于我们只循环到 valid_len = min(sq+1, total_len) + // // 所以不需要额外的掩码判断 + + // // if (tk > sq) { + // // score = -INFINITY; + // // //score = -10000.0;//-INFINITY; + // // } + // // 存储分数 + // attn_scores[tk] = score; + // if (score > max_score) { + // max_score = score; + // } + // } + + // // 计算 softmax(数值稳定版本) + // float exp_sum = 0.0f; + // //for (size_t tk = 0; tk <= sq; tk++) { + // for (size_t tk = 0; tk < valid_len; tk++) { + // // 对于因果注意力,未来位置已经在valid_len中排除了 + // // 所以这里直接计算softmax + // float exp_val = std::exp(attn_scores[tk] - max_score); + // exp_sum += exp_val; + // attn_scores[tk] = exp_val; + // } + + // // 归一化并加权求和 + // //for (size_t tk = 0; tk <= sq; tk++) { + // for (size_t tk = 0; tk < valid_len; tk++) { + // float weight = attn_scores[tk] / exp_sum; + // const float* v_head = v_ptr + (tk * nkvhead + kv_head_idx) * dv; + + // for (size_t dv_idx = 0; dv_idx < dv; dv_idx++) { + // attn_head[dv_idx] += weight * v_head[dv_idx]; + // } + // } + + float exp_sum = 0.0f; + for (int64_t tk = 0; tk < valid_len; ++tk) { + float e = std::exp(scores[tk] - max_score); + scores[tk] = e; + exp_sum += e; + } + + const float inv_sum = 1.0f / exp_sum; + for (int64_t tk = 0; tk < valid_len; ++tk) { + const float w = scores[tk] * inv_sum; + const float* v_head = v_ptr + (tk * nkvhead + kv_head_idx) * dv; + for (int64_t i = 0; i < dv; ++i) out_head[i] += w * v_head[i]; + } + + + //delete[] attn_scores; + delete[] scores; + } + } + break; + } + + case LLAISYS_DTYPE_F16: { + auto* q_ptr = reinterpret_cast(q_data); + auto* k_ptr = reinterpret_cast(k_data); + auto* v_ptr = reinterpret_cast(v_data); + //auto* attn_val_ptr = reinterpret_cast(attn_val_data); + auto* out = reinterpret_cast(attn_val_data); + + //for (size_t sq = 0; sq < seqlen; sq++) { + for (int64_t sq = 0; sq < seqlen; ++sq) { + int64_t max_k = sq + diag; + int64_t valid_len_i64 = max_k + 1; + if (valid_len_i64 < 0) valid_len_i64 = 0; + if (valid_len_i64 > total_len) valid_len_i64 = total_len; + const int64_t valid_len = valid_len_i64; + + + + + //for (size_t h = 0; h < nhead; h++) { + for (int64_t h = 0; h < nhead; ++h) { + // //size_t kv_head_idx = h / head_group; + // // 使用取模运算,匹配 repeat_interleave + // size_t kv_head_idx = h % nkvhead; + const int64_t kv_head_idx = (head_group > 0) ? (h / head_group) : 0; + + const llaisys::fp16_t* q_head = q_ptr + (sq * nhead + h) * d; + //llaisys::fp16_t* attn_head = attn_val_ptr + (sq * nhead + h) * dv; + llaisys::fp16_t* out_head = out + (sq * nhead + h) * dv; + + // // 初始化注意力值为0 + // for (size_t dv_idx = 0; dv_idx < dv; dv_idx++) { + // //attn_head[dv_idx] = llaisys::fp16_t(0.0f); + // attn_head[dv_idx] = llaisys::utils::_f32_to_f16(0.0f); + // } + + // float accumulator to reduce quantization error + float* acc = new float[(size_t)dv]; + for (int64_t i = 0; i < dv; ++i) acc[i] = 0.0f; + + if (valid_len == 0) { + for (int64_t i = 0; i < dv; ++i) out_head[i] = llaisys::utils::_f32_to_f16(0.0f); + delete[] acc; + continue; + } + + float* scores = new float[(size_t)valid_len]; + + + + float max_score = -INFINITY; + // //size_t valid_len = (sq < total_len) ? sq + 1 : total_len; + // size_t valid_len = std::min(sq + 1, total_len); + + // //float* attn_scores = new float[total_len]; + // float* attn_scores = new float[valid_len]; + + // // 计算 Q·K^T + // //for (size_t tk = 0; tk <= sq; tk++) { + // for (size_t tk = 0; tk < valid_len; tk++) { + // const llaisys::fp16_t* k_head = k_ptr + (tk * nkvhead + kv_head_idx) * d; + + // float score = 0.0f; + // for (size_t idx = 0; idx < d; idx++) { + // float q_val = llaisys::utils::_f16_to_f32(q_head[idx]); + // float k_val = llaisys::utils::_f16_to_f32(k_head[idx]); + // score += q_val * k_val; + // } + // score *= scale; + + // // if (tk > sq) { + // // score = -INFINITY; + // // } + + // // 存储分数 + // attn_scores[tk] = score; + // if (score > max_score) { + // max_score = score; + // } + // } + + // // 计算 softmax + // float exp_sum = 0.0f; + // //for (size_t tk = 0; tk <= sq; tk++) { + // for (size_t tk = 0; tk < valid_len; tk++) { + // // 对于因果注意力,未来位置已经在valid_len中排除了 + // // 所以这里直接计算softmax + // float exp_val = std::exp(attn_scores[tk] - max_score); + // exp_sum += exp_val; + // attn_scores[tk] = exp_val; + // } + + // // 加权求和 + // //for (size_t tk = 0; tk <= sq; tk++) { + // for (size_t tk = 0; tk < valid_len; tk++) { + // float weight = attn_scores[tk] / exp_sum; + // const llaisys::fp16_t* v_head = v_ptr + (tk * nkvhead + kv_head_idx) * dv; + + // for (size_t dv_idx = 0; dv_idx < dv; dv_idx++) { + // float attn_val_f = llaisys::utils::_f16_to_f32(attn_head[dv_idx]); + // float v_val = llaisys::utils::_f16_to_f32(v_head[dv_idx]); + // attn_val_f += weight * v_val; + // attn_head[dv_idx] = llaisys::utils::_f32_to_f16(attn_val_f); + // } + // } + + // delete[] attn_scores; + max_score = -INFINITY; + for (int64_t tk = 0; tk < valid_len; ++tk) { + const llaisys::fp16_t* k_head = k_ptr + (tk * nkvhead + kv_head_idx) * d; + + float dot = 0.0f; + for (int64_t i = 0; i < d; ++i) { + float qv = llaisys::utils::_f16_to_f32(q_head[i]); + float kv = llaisys::utils::_f16_to_f32(k_head[i]); + dot += qv * kv; + } + + float s = dot * scale; + scores[tk] = s; + if (s > max_score) max_score = s; + } + + float exp_sum = 0.0f; + for (int64_t tk = 0; tk < valid_len; ++tk) { + float e = std::exp(scores[tk] - max_score); + scores[tk] = e; + exp_sum += e; + } + + const float inv_sum = 1.0f / exp_sum; + for (int64_t tk = 0; tk < valid_len; ++tk) { + const float w = scores[tk] * inv_sum; + const llaisys::fp16_t* v_head = v_ptr + (tk * nkvhead + kv_head_idx) * dv; + for (int64_t i = 0; i < dv; ++i) { + float vv = llaisys::utils::_f16_to_f32(v_head[i]); + acc[i] += w * vv; + } + } + + for (int64_t i = 0; i < dv; ++i) out_head[i] = llaisys::utils::_f32_to_f16(acc[i]); + + delete[] scores; + delete[] acc; + } + } + break; + } + + case LLAISYS_DTYPE_BF16: { + auto* q_ptr = reinterpret_cast(q_data); + auto* k_ptr = reinterpret_cast(k_data); + auto* v_ptr = reinterpret_cast(v_data); + //auto* attn_val_ptr = reinterpret_cast(attn_val_data); + auto* out = reinterpret_cast(attn_val_data); + + //for (size_t sq = 0; sq < seqlen; sq++) { + for (int64_t sq = 0; sq < seqlen; ++sq) { + int64_t max_k = sq + diag; + int64_t valid_len_i64 = max_k + 1; + if (valid_len_i64 < 0) valid_len_i64 = 0; + if (valid_len_i64 > total_len) valid_len_i64 = total_len; + const int64_t valid_len = valid_len_i64; + + //for (size_t h = 0; h < nhead; h++) { + for (int64_t h = 0; h < nhead; ++h) { + // //size_t kv_head_idx = h / head_group; + // // 使用取模运算,匹配 repeat_interleave + // size_t kv_head_idx = h % nkvhead; + const int64_t kv_head_idx = (head_group > 0) ? (h / head_group) : 0; + + const llaisys::bf16_t* q_head = q_ptr + (sq * nhead + h) * d; + //llaisys::bf16_t* attn_head = attn_val_ptr + (sq * nhead + h) * dv; + llaisys::bf16_t* out_head = out + (sq * nhead + h) * dv; + + // // 初始化注意力值为0 + // for (size_t dv_idx = 0; dv_idx < dv; dv_idx++) { + // //attn_head[dv_idx] = llaisys::bf16_t(0.0f); + // attn_head[dv_idx] = llaisys::utils::_f32_to_bf16(0.0f); + // } + + float* acc = new float[(size_t)dv]; + for (int64_t i = 0; i < dv; ++i) acc[i] = 0.0f; + + if (valid_len == 0) { + for (int64_t i = 0; i < dv; ++i) out_head[i] = llaisys::utils::_f32_to_bf16(0.0f); + delete[] acc; + continue; + } + + float* scores = new float[(size_t)valid_len]; + + float max_score = -INFINITY; + // //size_t valid_len = (sq < total_len) ? sq + 1 : total_len; + // size_t valid_len = std::min(sq + 1, total_len); + // //float* attn_scores = new float[total_len]; + // float* attn_scores = new float[valid_len]; + + // // 计算 Q·K^T + // //for (size_t tk = 0; tk <= sq; tk++) { + // for (size_t tk = 0; tk < valid_len; tk++) { + // const llaisys::bf16_t* k_head = k_ptr + (tk * nkvhead + kv_head_idx) * d; + + // float score = 0.0f; + // for (size_t idx = 0; idx < d; idx++) { + // float q_val = llaisys::utils::_bf16_to_f32(q_head[idx]); + // float k_val = llaisys::utils::_bf16_to_f32(k_head[idx]); + // score += q_val * k_val; + // } + // score *= scale; + + // // if (tk > sq) { + // // score = -INFINITY; + // // } + + // attn_scores[tk] = score; + // if (score > max_score) { + // max_score = score; + // } + // } + + // // 计算 softmax + // float exp_sum = 0.0f; + // //for (size_t tk = 0; tk <= sq; tk++) { + // for (size_t tk = 0; tk < valid_len; tk++) { + // float exp_val = std::exp(attn_scores[tk] - max_score); + // exp_sum += exp_val; + // attn_scores[tk] = exp_val; + // } + + // // 加权求和 + // //for (size_t tk = 0; tk <= sq; tk++) { + // for (size_t tk = 0; tk < valid_len; tk++) { + // float weight = attn_scores[tk] / exp_sum; + // const llaisys::bf16_t* v_head = v_ptr + (tk * nkvhead + kv_head_idx) * dv; + + // for (size_t dv_idx = 0; dv_idx < dv; dv_idx++) { + // float attn_val_f = llaisys::utils::_bf16_to_f32(attn_head[dv_idx]); + // float v_val = llaisys::utils::_bf16_to_f32(v_head[dv_idx]); + // attn_val_f += weight * v_val; + // attn_head[dv_idx] = llaisys::utils::_f32_to_bf16(attn_val_f); + // } + // } + + // delete[] attn_scores; + // } + for (int64_t tk = 0; tk < valid_len; ++tk) { + const llaisys::bf16_t* k_head = k_ptr + (tk * nkvhead + kv_head_idx) * d; + + float dot = 0.0f; + for (int64_t i = 0; i < d; ++i) { + float qv = llaisys::utils::_bf16_to_f32(q_head[i]); + float kv = llaisys::utils::_bf16_to_f32(k_head[i]); + dot += qv * kv; + } + + float s = dot * scale; + scores[tk] = s; + if (s > max_score) max_score = s; + } + + float exp_sum = 0.0f; + for (int64_t tk = 0; tk < valid_len; ++tk) { + float e = std::exp(scores[tk] - max_score); + scores[tk] = e; + exp_sum += e; + } + + const float inv_sum = 1.0f / exp_sum; + for (int64_t tk = 0; tk < valid_len; ++tk) { + const float w = scores[tk] * inv_sum; + const llaisys::bf16_t* v_head = v_ptr + (tk * nkvhead + kv_head_idx) * dv; + for (int64_t i = 0; i < dv; ++i) { + float vv = llaisys::utils::_bf16_to_f32(v_head[i]); + acc[i] += w * vv; + } + } + + for (int64_t i = 0; i < dv; ++i) out_head[i] = llaisys::utils::_f32_to_bf16(acc[i]); + + delete[] scores; + delete[] acc; + } + } + break; + } + + default: + EXCEPTION_UNSUPPORTED_DATATYPE(static_cast(dtype)); + } + } } // namespace llaisys::ops diff --git a/src/ops/swiglu/op.cpp b/src/ops/swiglu/op.cpp index 47edbcc97..66f85ab86 100644 --- a/src/ops/swiglu/op.cpp +++ b/src/ops/swiglu/op.cpp @@ -1,7 +1,105 @@ #include "op.hpp" +#include + namespace llaisys::ops { void swiglu(tensor_t out, tensor_t gate, tensor_t up) { - TO_BE_IMPLEMENTED(); + //TO_BE_IMPLEMENTED(); + // 获取张量数据指针 + auto out_data = out->data(); + auto gate_data = gate->data(); + auto up_data = up->data(); + + // 获取维度信息 + auto gate_shape = gate->shape(); + size_t seqlen = gate_shape[0]; + size_t intermediate_size = gate_shape[1]; + auto dtype = gate->dtype(); + + // 根据数据类型进行处理 + switch (static_cast(dtype)) { + case LLAISYS_DTYPE_F32: { + auto* gate_ptr = reinterpret_cast(gate_data); + auto* up_ptr = reinterpret_cast(up_data); + auto* out_ptr = reinterpret_cast(out_data); + + for (size_t i = 0; i < seqlen; i++) { + const float* gate_row = gate_ptr + i * intermediate_size; + const float* up_row = up_ptr + i * intermediate_size; + float* out_row = out_ptr + i * intermediate_size; + + for (size_t j = 0; j < intermediate_size; j++) { + // // SwiGLU: out = up * sigmoid(gate) + // float gate_val = gate_row[j]; + // float sigmoid_gate = 1.0f / (1.0f + std::exp(-gate_val)); + // out_row[j] = up_row[j] * sigmoid_gate; + // 正确的 SwiGLU 公式:out = up * SiLU(gate) = up * (gate * sigmoid(gate)) + float gate_val = gate_row[j]; + float sigmoid_gate = 1.0f / (1.0f + std::exp(-gate_val)); + float silu_gate = gate_val * sigmoid_gate; // SiLU 激活函数 + out_row[j] = up_row[j] * silu_gate; + } + } + break; + } + + case LLAISYS_DTYPE_F16: { + auto* gate_ptr = reinterpret_cast(gate_data); + auto* up_ptr = reinterpret_cast(up_data); + auto* out_ptr = reinterpret_cast(out_data); + + for (size_t i = 0; i < seqlen; i++) { + const llaisys::fp16_t* gate_row = gate_ptr + i * intermediate_size; + const llaisys::fp16_t* up_row = up_ptr + i * intermediate_size; + llaisys::fp16_t* out_row = out_ptr + i * intermediate_size; + + for (size_t j = 0; j < intermediate_size; j++) { + // float gate_val = llaisys::utils::_f16_to_f32(gate_row[j]); + // float sigmoid_gate = 1.0f / (1.0f + std::exp(-gate_val)); + // float up_val = llaisys::utils::_f16_to_f32(up_row[j]); + // float out_val = up_val * sigmoid_gate; + // out_row[j] = llaisys::utils::_f32_to_f16(out_val); + float gate_val = llaisys::utils::_f16_to_f32(gate_row[j]); + float sigmoid_gate = 1.0f / (1.0f + std::exp(-gate_val)); + float silu_gate = gate_val * sigmoid_gate; // SiLU 激活函数 + float up_val = llaisys::utils::_f16_to_f32(up_row[j]); + float out_val = up_val * silu_gate; + out_row[j] = llaisys::utils::_f32_to_f16(out_val); + } + } + break; + } + + case LLAISYS_DTYPE_BF16: { + auto* gate_ptr = reinterpret_cast(gate_data); + auto* up_ptr = reinterpret_cast(up_data); + auto* out_ptr = reinterpret_cast(out_data); + + for (size_t i = 0; i < seqlen; i++) { + const llaisys::bf16_t* gate_row = gate_ptr + i * intermediate_size; + const llaisys::bf16_t* up_row = up_ptr + i * intermediate_size; + llaisys::bf16_t* out_row = out_ptr + i * intermediate_size; + + for (size_t j = 0; j < intermediate_size; j++) { + // float gate_val = llaisys::utils::_bf16_to_f32(gate_row[j]); + // float sigmoid_gate = 1.0f / (1.0f + std::exp(-gate_val)); + // float up_val = llaisys::utils::_bf16_to_f32(up_row[j]); + // float out_val = up_val * sigmoid_gate; + // out_row[j] = llaisys::utils::_f32_to_bf16(out_val); + float gate_val = llaisys::utils::_bf16_to_f32(gate_row[j]); + float sigmoid_gate = 1.0f / (1.0f + std::exp(-gate_val)); + float silu_gate = gate_val * sigmoid_gate; // SiLU 激活函数 + float up_val = llaisys::utils::_bf16_to_f32(up_row[j]); + float out_val = up_val * silu_gate; + out_row[j] = llaisys::utils::_f32_to_bf16(out_val); + } + } + break; + } + + default: + EXCEPTION_UNSUPPORTED_DATATYPE(static_cast(dtype)); + } + } } // namespace llaisys::ops diff --git a/src/tensor/tensor.cpp b/src/tensor/tensor.cpp index 2f594bb65..c56884fae 100644 --- a/src/tensor/tensor.cpp +++ b/src/tensor/tensor.cpp @@ -164,37 +164,137 @@ void Tensor::debug() const { } bool Tensor::isContiguous() const { - TO_BE_IMPLEMENTED(); + + // 检查 tensor 是否连续存储 + // 连续存储意味着相邻元素在内存中也是相邻的 + const auto &shape = _meta.shape; + const auto &strides = _meta.strides; + + // 空 tensor 视为连续 + if (shape.empty()) { + return true; + } + + // 从最后一维开始,期望的 stride 初始为 1 + ptrdiff_t expected = 1; + for (int i = static_cast(shape.size()) - 1; i >= 0; --i) { + // 跳过 size=1 的维度,因为它们不影响连续性 + if (shape[i] != 1) { + // 检查当前维度的 stride 是否等于期望值 + if (strides[i] != expected) { + return false; + } + // 更新下一维的期望 stride + expected *= static_cast(shape[i]); + } + } return true; } tensor_t Tensor::permute(const std::vector &order) const { - TO_BE_IMPLEMENTED(); - return std::shared_ptr(new Tensor(_meta, _storage)); + // return std::shared_ptr(new Tensor(_meta, _storage)); + const size_t ndims = _meta.shape.size(); + + CHECK_ARGUMENT(order.size() == ndims, "permute: order size mismatch"); + + // valid order + std::vector seen(ndims, false); + for (size_t idx:order) { + CHECK_ARGUMENT(idx < ndims, "permute: index out of range"); + CHECK_ARGUMENT(!seen[idx], "permute: duplicate index in order"); + seen[idx] = true; + } + + // build new meta + TensorMeta new_meta; + new_meta.dtype = _meta.dtype; + new_meta.shape.resize(ndims); + new_meta.strides.resize(ndims); + + for (size_t i = 0;i < ndims; ++i) { + new_meta.shape[i] = _meta.shape[order[i]]; + new_meta.strides[i] = _meta.strides[order[i]]; + } + + return std::shared_ptr(new Tensor(new_meta, _storage, _offset)); + } tensor_t Tensor::view(const std::vector &shape) const { - TO_BE_IMPLEMENTED(); - return std::shared_ptr(new Tensor(_meta, _storage)); + // return std::shared_ptr(new Tensor(_meta, _storage)); + // calculate total of new elements + size_t new_numel = 1; + for (size_t s:shape) { + new_numel *= s; + } + + CHECK_ARGUMENT(new_numel == numel(), "view: element count mismatch"); + CHECK_ARGUMENT(isContiguous(), "view: tensor must be contiguous"); + + // build new meta + TensorMeta new_meta; + new_meta.dtype = _meta.dtype; + new_meta.shape = shape; + + // calculate strides (Tensor::create) + size_t ndim_ = shape.size(); + new_meta.strides.resize(ndim_); + size_t stride = 1; + for (size_t i = 1; i <= ndim_; i++) { + new_meta.strides[ndim_ - i] = stride; + stride *= shape[ndim_ - i]; + } + + return std::shared_ptr(new Tensor(new_meta, _storage, _offset)); } tensor_t Tensor::slice(size_t dim, size_t start, size_t end) const { - TO_BE_IMPLEMENTED(); - return std::shared_ptr(new Tensor(_meta, _storage)); + // return std::shared_ptr(new Tensor(_meta, _storage)); + // valid dim + CHECK_ARGUMENT(dim < _meta.shape.size(), "slice: dim out of range"); + CHECK_ARGUMENT(start < end, "slice: start must be less than end"); + CHECK_ARGUMENT(end <= _meta.shape[dim], "slice: end exceeds dimension size"); + + // copy meta and revise shape + TensorMeta new_meta = _meta; + new_meta.shape[dim] = end - start; + + // calculate the new offset + // start * stride[dim] * each_elements_bytes + size_t new_offset = _offset + + start * static_cast(_meta.strides[dim])* elementSize(); + + return std::shared_ptr(new Tensor(new_meta, _storage, new_offset)); } void Tensor::load(const void *src_) { - TO_BE_IMPLEMENTED(); + //std::cerr << "[DBG] enter Tensor::load, src=" << src_ << std::endl; + CHECK_ARGUMENT(src_ != nullptr, "load: src is null"); + //std::cerr << "[DBG] after CHECK_ARGUMENT" << std::endl; + + void *dst = data(); + size_t size = numel() * elementSize(); + // std::cout << "memcpy_sync: " << size << " bytes" << std::endl; + // std::cout << "memcpy_sync: " << src_ << " -> " << dst << std::endl; + // std::cout << "memcpy_sync: " << deviceType() << std::endl; + device::getRuntimeAPI(deviceType())->memcpy_sync(dst, src_, size, LLAISYS_MEMCPY_H2D); + auto api = device::getRuntimeAPI(deviceType()); + //std::cerr << "[DBG] api=" << api << " deviceType=" << deviceType() << std::endl; + api->memcpy_sync(dst, src_, size, LLAISYS_MEMCPY_H2D); + core::context().runtime().api()->device_synchronize(); } tensor_t Tensor::contiguous() const { TO_BE_IMPLEMENTED(); return std::shared_ptr(new Tensor(_meta, _storage)); + } tensor_t Tensor::reshape(const std::vector &shape) const { TO_BE_IMPLEMENTED(); return std::shared_ptr(new Tensor(_meta, _storage)); + + // return contiguous()->view(shape); } tensor_t Tensor::to(llaisysDeviceType_t device_type, int device) const { @@ -202,4 +302,4 @@ tensor_t Tensor::to(llaisysDeviceType_t device_type, int device) const { return std::shared_ptr(new Tensor(_meta, _storage)); } -} // namespace llaisys +} // namespace llaisys \ No newline at end of file diff --git a/test/ops/argmax.py b/test/ops/argmax.py index d0f7ee298..f1fc26fdd 100644 --- a/test/ops/argmax.py +++ b/test/ops/argmax.py @@ -41,6 +41,10 @@ def test_op_argmax( if __name__ == "__main__": import argparse + print(llaisys.__file__) + + print("llaisys pkg:", os.path.dirname(llaisys.__file__)) + #print("so candidates:", glob.glob(os.path.join(os.path.dirname(llaisys.__file__), "**/*.so"), recursive=True)) parser = argparse.ArgumentParser() parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"], type=str) diff --git a/test/ops/rope.py b/test/ops/rope.py index fe59dd11c..74c70da69 100644 --- a/test/ops/rope.py +++ b/test/ops/rope.py @@ -7,6 +7,22 @@ import torch from test_utils import arrange_tensor, random_tensor, check_equal, benchmark +# def check_equal(y_, y, atol=1e-5, rtol=1e-5): +# # 确保两个张量都是 PyTorch 张量 +# if not isinstance(y_, torch.Tensor): +# y_ = torch.tensor(y_) +# if not isinstance(y, torch.Tensor): +# y = torch.tensor(y) + +# # 确保它们在同一个设备上 +# if y_.device != y.device: +# y = y.to(y_.device) + +# diff = torch.abs(y_ - y) +# max_diff = diff.max() +# mean_diff = diff.mean() +# print(f"Max diff: {max_diff}, Mean diff: {mean_diff}") + def torch_rope(y: torch.Tensor, x: torch.Tensor, pos_ids: torch.Tensor, theta: float): assert y.dim() == 3 @@ -49,6 +65,8 @@ def test_op_rope( torch_rope(y, x, pos_ids, theta) llaisys.Ops.rope(y_, x_, pos_ids_, theta) + #check_equal(y_, y, atol=atol, rtol=rtol) + assert check_equal(y_, y, atol=atol, rtol=rtol) if profile: diff --git a/test/ops/self_attention.py b/test/ops/self_attention.py index a042b51be..1bb2b1295 100644 --- a/test/ops/self_attention.py +++ b/test/ops/self_attention.py @@ -8,25 +8,69 @@ from test_utils import random_tensor, check_equal, benchmark +# def torch_self_attention(attn_val, query, key, value, scale): +# query = query.transpose(-2, -3) +# key = key.transpose(-2, -3) +# value = value.transpose(-2, -3) +# L, S = query.size(-2), key.size(-2) +# attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) + +# temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=S-L) +# attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) +# attn_bias.to(query.dtype) + +# key = key.repeat_interleave(query.size(-3) // key.size(-3), -3) +# value = value.repeat_interleave(query.size(-3) // value.size(-3), -3) + +# attn_weight = query @ key.transpose(-2, -1) * scale +# attn_weight += attn_bias +# attn_weight = torch.softmax(attn_weight, dim=-1) +# attn_val.copy_((attn_weight @ value).transpose(-2, -3)) + def torch_self_attention(attn_val, query, key, value, scale): - query = query.transpose(-2, -3) - key = key.transpose(-2, -3) - value = value.transpose(-2, -3) - L, S = query.size(-2), key.size(-2) - attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) + print("=== PyTorch 调试 ===") + print(f"query shape: {query.shape}") + print(f"key shape: {key.shape}") + print(f"value shape: {value.shape}") + print(f"scale: {scale}") + + q_transposed = query.transpose(-2, -3) # [nh, qlen, d] + k_transposed = key.transpose(-2, -3) # [nkvh, kvlen, d] + v_transposed = value.transpose(-2, -3) # [nkvh, kvlen, dv] + print(f"q_transposed shape: {q_transposed.shape}") + print(f"k_transposed shape: {k_transposed.shape}") + + k_repeated = k_transposed.repeat_interleave(q_transposed.size(-3) // k_transposed.size(-3), -3) # [nh, kvlen, d] + v_repeated = v_transposed.repeat_interleave(q_transposed.size(-3) // v_transposed.size(-3), -3) # [nh, kvlen, dv] + + print(f"k_repeated shape: {k_repeated.shape}") + + # 注意力计算: + attn_weight = q_transposed @ k_repeated.transpose(-2, -1) * scale # [nh, qlen, kvlen] + print(f"attn_weight shape: {attn_weight.shape}") + print(f"attn_weight[0,0,:]: {attn_weight[0,0,:]}") + + # 应用因果掩码 + L, S = q_transposed.size(-2), k_repeated.size(-2) + print(f"L={L}, S={S}") + attn_bias = torch.zeros(L, S, dtype=q_transposed.dtype, device=q_transposed.device) temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=S-L) attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) - attn_bias.to(query.dtype) - - key = key.repeat_interleave(query.size(-3) // key.size(-3), -3) - value = value.repeat_interleave(query.size(-3) // value.size(-3), -3) - - attn_weight = query @ key.transpose(-2, -1) * scale attn_weight += attn_bias - attn_weight = torch.softmax(attn_weight, dim=-1) - attn_val.copy_((attn_weight @ value).transpose(-2, -3)) - + + print(f"attn_weight after mask[0,0,:]: {attn_weight[0,0,:]}") + + # softmax和加权求和 + attn_weight = torch.softmax(attn_weight, dim=-1) # [nh, qlen, kvlen] + print(f"attn_weight after softmax[0,0,:]: {attn_weight[0,0,:]}") + + result = (attn_weight @ v_repeated).transpose(-2, -3) # [qlen, nh, dv] + print(f"result shape: {result.shape}") + print(f"result[0,0,:]: {result[0,0,:]}") + + attn_val.copy_(result) + print("=== 调试结束 ===") def test_op_self_attention( qlen, @@ -68,6 +112,33 @@ def test_op_self_attention( parser.add_argument("--device", default="cpu", choices=["cpu", "nvidia"], type=str) parser.add_argument("--profile", action="store_true") args = parser.parse_args() + + # 原始维度: + # query: [qlen, nh, d] + # key: [kvlen, nkvh, d] + # value: [kvlen, nkvh, dv] + + # testShapes = [ + # # qlen, kvlen, nh, nkvh, hd + # (1, 1, 2, 1, 2), + # ] + # testDtypePrec = [ + # # type, atol, rtol + # ("f32", 1e-5, 1e-5), + # ] + + + # query = query.transpose(-2, -3) # [nh, qlen, d] + # key = key.transpose(-2, -3) # [nkvh, kvlen, d] + # value = value.transpose(-2, -3) # [nkvh, kvlen, dv] + + # # repeat_interleave 后: + # key = key.repeat_interleave(query.size(-3) // key.size(-3), -3) # [nh, kvlen, d] + # value = value.repeat_interleave(query.size(-3) // value.size(-3), -3) # [nh, kvlen, dv] + + # # 注意力计算: + # attn_weight = query @ key.transpose(-2, -1) * scale # [nh, qlen, d] @ [nh, d, kvlen] = [nh, qlen, kvlen] + testShapes = [ # qlen, kvlen, nh, nkvh, hd (2, 2, 1, 1, 4), @@ -79,6 +150,7 @@ def test_op_self_attention( ("f16", 1e-3, 1e-3), ("bf16", 1e-2, 1e-2), ] + print(f"Testing Ops.self_attention on {args.device}") for shape in testShapes: for dtype_name, atol, rtol in testDtypePrec: diff --git a/test/test_tensor.py b/test/test_tensor.py index 9d2e9a075..380e11b56 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -4,19 +4,35 @@ from test_utils import * import argparse - +import llaisys, inspect, os +import llaisys, os, glob, inspect def test_tensor(): + print(llaisys.__file__) + + print("llaisys pkg:", os.path.dirname(llaisys.__file__)) + print("so candidates:", glob.glob(os.path.join(os.path.dirname(llaisys.__file__), "**/*.so"), recursive=True)) + + #print("enter funciton1\n") torch_tensor = torch.arange(60, dtype=torch_dtype("i64")).reshape(3, 4, 5) + #print("enter funciton2\n") llaisys_tensor = llaisys.Tensor( (3, 4, 5), dtype=llaisys_dtype("i64"), device=llaisys_device("cpu") ) + #print("enter funciton3\n") # Test load print("===Test load===") + #print(f"enter funciton4 {torch_tensor.data_ptr()}\n") + + llaisys_tensor.load(torch_tensor.data_ptr()) + #print("enter funciton4\n") llaisys_tensor.debug() + #print("enter funciton5\n") assert llaisys_tensor.is_contiguous() == torch_tensor.is_contiguous() + #print("enter funciton6\n") assert check_equal(llaisys_tensor, torch_tensor) + #print("enter funciton7\n") # Test view print("===Test view===") diff --git a/xmake.lua b/xmake.lua index 1f65f7a95..43f00c570 100644 --- a/xmake.lua +++ b/xmake.lua @@ -106,6 +106,7 @@ target("llaisys") set_languages("cxx17") set_warnings("all", "error") add_files("src/llaisys/*.cc") + add_files("src/llaisys/models/*.cc") set_installdir(".")