From b4459dada3bd288fd108eac15b678e1ab372abf8 Mon Sep 17 00:00:00 2001 From: Chin-Chang Yang <2770271+ChinChangYang@users.noreply.github.com> Date: Mon, 1 Jun 2026 13:40:19 +0800 Subject: [PATCH 01/10] Add Metal GPU + CoreML/ANE transformer support for b10c384h6nbttflrs (v15) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implement the LLaMA-style transformer-hybrid forward pass (RMSNorm, multi-head attention with learnable 2D RoPE, SwiGLU FFN) plus ACTIVATION_SILU across the Metal GPU (MPSGraph) and CoreML/ANE (MIL) backends, so the v15 b10c384h6nbttflrs model runs end-to-end. Metal GPU (MPSGraph) — verified via testgpuerror vs Eigen reference at sizes 9/13/19 (winrate error ~0.0001%, well under threshold): - metallayers.swift: TransformerRMSNormLayer, TrunkRMSNormLayer, TransformerAttentionBlock, TransformerFFNBlock, silu() activation, SWTransformer*/SWRMSNorm descriptors; Trunk branches on trunkNormKind - metalbackend.cpp: SILU bridge + transformer/RMSNorm desc bridges, wired into residualBlocksToSwift and trunkDescToSwift CoreML/ANE (katagocoreml MIL) — implemented end-to-end; fp32 model logically correct and consistent across CPU/ANE/GPU. fp16 ANE path is numerically precision-limited (~5%) due to fp16 matmul accumulation in the deep attention stack: - types/parser: ActivationType::Silu, trunk_norm_kind, transformer block kinds 4/5, RMSNorm/attention/FFN descriptors - MILBuilder: addSiluOps, RMSNorm ops, transformer attention/FFN blocks. Fixes 4 CoreML bugs: reshape-after-transpose, fp16 mask overflow, fp16 RMSNorm reduce_sum overflow (reduce_mean) Co-Authored-By: Claude Opus 4.8 (1M context) --- .../katagocoreml/src/builder/MILBuilder.cpp | 595 +++++++++++++++++- .../katagocoreml/src/builder/MILBuilder.hpp | 38 ++ .../katagocoreml/src/parser/KataGoParser.cpp | 136 +++- .../katagocoreml/src/parser/KataGoParser.hpp | 4 + .../katagocoreml/src/types/KataGoTypes.hpp | 68 +- cpp/neuralnet/metalbackend.cpp | 79 ++- cpp/neuralnet/metallayers.swift | 578 ++++++++++++++++- 7 files changed, 1467 insertions(+), 31 deletions(-) diff --git a/cpp/external/katagocoreml/src/builder/MILBuilder.cpp b/cpp/external/katagocoreml/src/builder/MILBuilder.cpp index db0c6c4b1..2a0bbf44a 100644 --- a/cpp/external/katagocoreml/src/builder/MILBuilder.cpp +++ b/cpp/external/katagocoreml/src/builder/MILBuilder.cpp @@ -4,6 +4,7 @@ #include "MILBuilder.hpp" #include "MILBlob/Fp16.hpp" #include +#include // Include generated protobuf headers #include "MIL.pb.h" @@ -732,6 +733,63 @@ void MILBuilder::addBatchNormActivationOps(CoreML::Specification::MILSpec::Block setTensorOutput4D(op, output, bn.num_channels, m_board_y_size, m_board_x_size); } else if (act.activation_type == ActivationType::Mish) { addMishOps(block, bn_output, output, 4, bn.num_channels); + } else if (act.activation_type == ActivationType::Silu) { + addSiluOps(block, bn_output, output, 4, bn.num_channels); + } +} + +void MILBuilder::addSiluOps(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const std::string& output, + int rank, + int channels) { + // SiLU / Swish: x * sigmoid(x) + auto setOutputType = [this, rank, channels](CoreML::Specification::MILSpec::Operation* op, const std::string& name) { + auto* out = op->add_outputs(); + out->set_name(name); + auto* out_type = out->mutable_type()->mutable_tensortype(); + out_type->set_datatype(m_weight_dtype); + out_type->set_rank(rank); + setBatchDimension(out_type); + out_type->add_dimensions()->mutable_constant()->set_size(channels); + if (rank == 4) { + out_type->add_dimensions()->mutable_constant()->set_size(m_board_y_size); + out_type->add_dimensions()->mutable_constant()->set_size(m_board_x_size); + } + }; + + std::string sig = output + "_sigmoid"; + { + auto* op = block->add_operations(); + op->set_type("sigmoid"); + auto& inputs = *op->mutable_inputs(); + inputs["x"].add_arguments()->set_name(input); + setOutputType(op, sig); + } + { + auto* op = block->add_operations(); + op->set_type("mul"); + auto& inputs = *op->mutable_inputs(); + inputs["x"].add_arguments()->set_name(input); + inputs["y"].add_arguments()->set_name(sig); + setOutputType(op, output); + } +} + +void MILBuilder::setShape(CoreML::Specification::MILSpec::Operation* op, + const std::string& name, + const std::vector& dims) { + auto* out = op->add_outputs(); + out->set_name(name); + auto* t = out->mutable_type()->mutable_tensortype(); + t->set_datatype(m_weight_dtype); + t->set_rank(static_cast(dims.size())); + for (int64_t d : dims) { + auto* dim = t->add_dimensions(); + if (d < 0) + dim->mutable_unknown()->set_variadic(false); + else + dim->mutable_constant()->set_size(d); } } @@ -1637,6 +1695,522 @@ void MILBuilder::addGlobalPoolingValueOps(CoreML::Specification::MILSpec::Block* // Network Component Builders // ============================================================================ +// --------------------------------------------------------------------------- +// Transformer blocks (MIL). Layout is NCHW [B, C, H, W]; spatial positions +// (H*W, ordered y*W+x) are treated as the attention sequence. RoPE is applied +// via a fixed pair-rotation matmul plus host-precomputed cos/sin tables, which +// keeps every tensor rank <= 4 (ANE-friendly). +// --------------------------------------------------------------------------- + +std::string MILBuilder::addTransformerRMSNorm(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const TransformerRMSNormDesc& desc, + const std::string& mask, + const std::string& prefix) { + const int C = desc.num_channels; + const int H = m_board_y_size, W = m_board_x_size; + auto emit2 = [&](const std::string& type, const std::string& x, const std::string& y, + const std::string& out, const std::vector& dims) { + auto* op = block->add_operations(); + op->set_type(type); + (*op->mutable_inputs())["x"].add_arguments()->set_name(x); + (*op->mutable_inputs())["y"].add_arguments()->set_name(y); + setShape(op, out, dims); + }; + + std::string sq = genVarName(prefix + "_sq"); + emit2("mul", input, input, sq, {-1, C, H, W}); + // meanSq = reduce_mean(sq, axes=[1]) over channels. reduce_mean (not reduce_sum) is used so + // the accumulator stays ~O(activation^2) instead of summing hundreds of channels, which can + // overflow FP16 (and the FP16 accumulation on ANE) for large activations. + std::string meanSq = genVarName(prefix + "_meansq"); + { + std::string axesName = meanSq + "_axes"; + std::string keepName = meanSq + "_keep"; + addIntArrayConstOp(block, axesName, {1}); + addBoolScalarConstOp(block, keepName, true); + auto* op = block->add_operations(); + op->set_type("reduce_mean"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(sq); + (*op->mutable_inputs())["axes"].add_arguments()->set_name(axesName); + (*op->mutable_inputs())["keep_dims"].add_arguments()->set_name(keepName); + setShape(op, meanSq, {-1, 1, H, W}); + } + // MIL rsqrt computes 1/sqrt(x + epsilon); supply epsilon directly. + std::string epsName = prefix + "_eps"; + addFloatScalarConstOp(block, epsName, desc.epsilon); + std::string inv = genVarName(prefix + "_inv"); + { + auto* op = block->add_operations(); + op->set_type("rsqrt"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(meanSq); + (*op->mutable_inputs())["epsilon"].add_arguments()->set_name(epsName); + setShape(op, inv, {-1, 1, H, W}); + } + std::string normalized = genVarName(prefix + "_norm"); + emit2("mul", input, inv, normalized, {-1, C, H, W}); + std::string weightName = prefix + "_weight"; + addConstOp(block, weightName, desc.weight, {1, static_cast(C), 1, 1}); + std::string scaled = genVarName(prefix + "_scaled"); + emit2("mul", normalized, weightName, scaled, {-1, C, H, W}); + std::string out = genVarName(prefix + "_out"); + emit2("mul", scaled, mask, out, {-1, C, H, W}); + return out; +} + +std::string MILBuilder::addTrunkRMSNorm(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const RMSNormLayerDesc& desc, + const ActivationLayerDesc& act, + const std::string& mask, + const std::string& prefix) { + const int C = desc.num_channels; + const int H = m_board_y_size, W = m_board_x_size; + auto emit2 = [&](const std::string& type, const std::string& x, const std::string& y, + const std::string& out, const std::vector& dims) { + auto* op = block->add_operations(); + op->set_type(type); + (*op->mutable_inputs())["x"].add_arguments()->set_name(x); + (*op->mutable_inputs())["y"].add_arguments()->set_name(y); + setShape(op, out, dims); + }; + auto reduceSum = [&](const std::string& x, const std::string& out, const std::vector& axes, + const std::vector& dims) { + std::string axesName = out + "_axes"; + std::string keepName = out + "_keep"; + addIntArrayConstOp(block, axesName, axes); + addBoolScalarConstOp(block, keepName, true); + auto* op = block->add_operations(); + op->set_type("reduce_sum"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(x); + (*op->mutable_inputs())["axes"].add_arguments()->set_name(axesName); + (*op->mutable_inputs())["keep_dims"].add_arguments()->set_name(keepName); + setShape(op, out, dims); + }; + + std::string masked = genVarName(prefix + "_premask"); + emit2("mul", input, mask, masked, {-1, C, H, W}); + std::string sq = genVarName(prefix + "_sq"); + emit2("mul", masked, masked, sq, {-1, C, H, W}); + + std::string meanSq; + std::vector denomDims; + if (desc.spatial) { + // Mean of squares over valid positions and channels. A reduce_sum over C*H*W elements + // overflows FP16 (e.g. trunk tip with large activations on ANE -> inf -> rsqrt 0 -> + // collapse). Instead take reduce_mean over all of C,H,W (masked positions are zero) and + // rescale by totalPositions/validCount to restrict the mean to valid positions. + std::string meanAll = genVarName(prefix + "_meanall"); + { + std::string axesName = meanAll + "_axes", keepName = meanAll + "_keep"; + addIntArrayConstOp(block, axesName, {1, 2, 3}); + addBoolScalarConstOp(block, keepName, true); + auto* op = block->add_operations(); + op->set_type("reduce_mean"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(sq); + (*op->mutable_inputs())["axes"].add_arguments()->set_name(axesName); + (*op->mutable_inputs())["keep_dims"].add_arguments()->set_name(keepName); + setShape(op, meanAll, {-1, 1, 1, 1}); + } + std::string count = genVarName(prefix + "_count"); + reduceSum(mask, count, {1, 2, 3}, {-1, 1, 1, 1}); // valid positions (<= H*W, no overflow) + std::string totalPosName = prefix + "_totalpos"; + addFloatScalarConstOp(block, totalPosName, static_cast(H * W)); + std::string scaleF = genVarName(prefix + "_scalef"); + emit2("real_div", totalPosName, count, scaleF, {-1, 1, 1, 1}); // totalPos / validCount + meanSq = genVarName(prefix + "_meansq"); + emit2("mul", meanAll, scaleF, meanSq, {-1, 1, 1, 1}); + denomDims = {-1, 1, 1, 1}; + } else { + meanSq = genVarName(prefix + "_meansq"); + std::string axesName = meanSq + "_axes"; + std::string keepName = meanSq + "_keep"; + addIntArrayConstOp(block, axesName, {1}); + addBoolScalarConstOp(block, keepName, true); + auto* op = block->add_operations(); + op->set_type("reduce_mean"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(sq); + (*op->mutable_inputs())["axes"].add_arguments()->set_name(axesName); + (*op->mutable_inputs())["keep_dims"].add_arguments()->set_name(keepName); + setShape(op, meanSq, {-1, 1, H, W}); + denomDims = {-1, 1, H, W}; + } + + // MIL rsqrt computes 1/sqrt(x + epsilon); supply epsilon directly. + std::string epsName = prefix + "_eps"; + addFloatScalarConstOp(block, epsName, desc.epsilon); + std::string inv = genVarName(prefix + "_inv"); + { + auto* op = block->add_operations(); + op->set_type("rsqrt"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(meanSq); + (*op->mutable_inputs())["epsilon"].add_arguments()->set_name(epsName); + setShape(op, inv, denomDims); + } + std::string normalized = genVarName(prefix + "_norm"); + emit2("mul", input, inv, normalized, {-1, C, H, W}); + std::string gammaName = prefix + "_gamma"; + std::string betaName = prefix + "_beta"; + addConstOp(block, gammaName, desc.gamma, {1, static_cast(C), 1, 1}); + addConstOp(block, betaName, desc.beta, {1, static_cast(C), 1, 1}); + std::string scaled = genVarName(prefix + "_scaled"); + emit2("mul", normalized, gammaName, scaled, {-1, C, H, W}); + std::string biased = genVarName(prefix + "_biased"); + emit2("add", scaled, betaName, biased, {-1, C, H, W}); + + std::string activated; + if (act.activation_type == ActivationType::Silu) { + activated = genVarName(prefix + "_act"); + addSiluOps(block, biased, activated, 4, C); + } else if (act.activation_type == ActivationType::Mish) { + activated = genVarName(prefix + "_act"); + addMishOps(block, biased, activated, 4, C); + } else if (act.activation_type == ActivationType::ReLU) { + activated = genVarName(prefix + "_act"); + auto* op = block->add_operations(); + op->set_type("relu"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(biased); + setShape(op, activated, {-1, C, H, W}); + } else { + activated = biased; + } + std::string out = genVarName(prefix + "_out"); + emit2("mul", activated, mask, out, {-1, C, H, W}); + return out; +} + +std::string MILBuilder::buildTransformerAttentionBlock(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const TransformerAttentionBlockDesc& desc, + const std::string& mask, + const std::string& prefix) { + const int C = desc.q_proj.in_channels; + const int H = m_board_y_size, W = m_board_x_size; + const int seq = H * W; + const int numHeads = desc.num_heads, numKVHeads = desc.num_kv_heads; + const int qHeadDim = desc.q_head_dim, vHeadDim = desc.v_head_dim; + const int qTotal = numHeads * qHeadDim, kTotal = numKVHeads * qHeadDim, vTotal = numKVHeads * vHeadDim; + + if (numKVHeads != numHeads) { + throw std::runtime_error(desc.name + ": GQA (numKVHeads != numHeads) not supported in CoreML backend"); + } + + auto reshape = [&](const std::string& in, const std::string& out, const std::vector& shapeVals, + const std::vector& dims) { + std::string shapeName = out + "_shape"; + addIntArrayConstOp(block, shapeName, shapeVals); + auto* op = block->add_operations(); + op->set_type("reshape"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(in); + (*op->mutable_inputs())["shape"].add_arguments()->set_name(shapeName); + setShape(op, out, dims); + }; + auto transpose = [&](const std::string& in, const std::string& out, const std::vector& perm, + const std::vector& dims) { + std::string permName = out + "_perm"; + addIntArrayConstOp(block, permName, perm); + auto* op = block->add_operations(); + op->set_type("transpose"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(in); + (*op->mutable_inputs())["perm"].add_arguments()->set_name(permName); + setShape(op, out, dims); + }; + auto matmul = [&](const std::string& x, const std::string& y, const std::string& out, + const std::vector& dims, bool transX, bool transY) { + std::string txName = out + "_tx", tyName = out + "_ty"; + addBoolScalarConstOp(block, txName, transX); + addBoolScalarConstOp(block, tyName, transY); + auto* op = block->add_operations(); + op->set_type("matmul"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(x); + (*op->mutable_inputs())["y"].add_arguments()->set_name(y); + (*op->mutable_inputs())["transpose_x"].add_arguments()->set_name(txName); + (*op->mutable_inputs())["transpose_y"].add_arguments()->set_name(tyName); + setShape(op, out, dims); + }; + auto binary = [&](const std::string& type, const std::string& x, const std::string& y, + const std::string& out, const std::vector& dims) { + auto* op = block->add_operations(); + op->set_type(type); + (*op->mutable_inputs())["x"].add_arguments()->set_name(x); + (*op->mutable_inputs())["y"].add_arguments()->set_name(y); + setShape(op, out, dims); + }; + + std::string normed = addTransformerRMSNorm(block, input, desc.pre_ln, mask, prefix + "_ln"); + std::string nhwc = genVarName(prefix + "_nhwc"); + transpose(normed, nhwc, {0, 2, 3, 1}, {-1, H, W, C}); + std::string x2d = genVarName(prefix + "_x2d"); + reshape(nhwc, x2d, {-1, C}, {-1, C}); + auto proj = [&](const MatMulLayerDesc& w, const std::string& nm, int total) { + std::string wName = nm + "_w"; + addConstOp(block, wName, w.weights, w.getWeightShape()); + std::string out = genVarName(nm); + matmul(x2d, wName, out, {-1, total}, false, false); + return out; + }; + std::string q2d = proj(desc.q_proj, prefix + "_q", qTotal); + std::string k2d = proj(desc.k_proj, prefix + "_k", kTotal); + std::string v2d = proj(desc.v_proj, prefix + "_v", vTotal); + auto toHeads = [&](const std::string& in2d, const std::string& nm, int nh, int hd) { + std::string r = genVarName(nm + "_r"); + reshape(in2d, r, {-1, seq, nh, hd}, {-1, seq, nh, hd}); + std::string t = genVarName(nm + "_t"); + transpose(r, t, {0, 2, 1, 3}, {-1, nh, seq, hd}); + return t; + }; + std::string qh = toHeads(q2d, prefix + "_qh", numHeads, qHeadDim); + std::string kh = toHeads(k2d, prefix + "_kh", numKVHeads, qHeadDim); + std::string vh = toHeads(v2d, prefix + "_vh", numKVHeads, vHeadDim); + + if (desc.use_rope) { + const int numPairs = qHeadDim / 2; + const int numPairsPerDim = numPairs / 2; + const int dimHalf = qHeadDim / 2; + auto applyRope = [&](const std::string& x, int nh, const std::string& tag) { + std::vector cosFull(static_cast(nh) * seq * qHeadDim, 0.0f); + std::vector sinFull(static_cast(nh) * seq * qHeadDim, 0.0f); + for (int h = 0; h < nh; h++) { + int kvh = (h * numKVHeads) / nh; + for (int xy = 0; xy < seq; xy++) { + int y = xy / W; + int x = xy % W; + for (int p = 0; p < numPairs; p++) { + float angle = 0.0f; + if (desc.learnable_rope) { + float fx = desc.rope_freqs[(kvh * numPairs + p) * 2 + 0]; + float fy = desc.rope_freqs[(kvh * numPairs + p) * 2 + 1]; + angle = static_cast(x) * fx + static_cast(y) * fy; + } else { + if (p < numPairsPerDim) { + float freq = 1.0f / std::pow(desc.rope_theta, static_cast(2 * p) / dimHalf); + angle = static_cast(y) * freq; + } else { + int pAdj = p - numPairsPerDim; + float freq = 1.0f / std::pow(desc.rope_theta, static_cast(2 * pAdj) / dimHalf); + angle = static_cast(x) * freq; + } + } + float c = std::cos(angle), s = std::sin(angle); + size_t base = (static_cast(h) * seq + xy) * qHeadDim + 2 * p; + cosFull[base] = c; cosFull[base + 1] = c; + sinFull[base] = s; sinFull[base + 1] = s; + } + } + } + std::vector R(static_cast(qHeadDim) * qHeadDim, 0.0f); + for (int p = 0; p < numPairs; p++) { + R[(2 * p) * qHeadDim + (2 * p + 1)] = 1.0f; + R[(2 * p + 1) * qHeadDim + (2 * p)] = -1.0f; + } + std::string cosName = prefix + "_" + tag + "_cos"; + std::string sinName = prefix + "_" + tag + "_sin"; + std::string rName = prefix + "_" + tag + "_R"; + addConstOp(block, cosName, cosFull, {1, nh, seq, qHeadDim}); + addConstOp(block, sinName, sinFull, {1, nh, seq, qHeadDim}); + // Rank-4 [1,1,qd,qd] so matmul batch dims broadcast cleanly against [B,nh,seq,qd]. + addConstOp(block, rName, R, {1, 1, qHeadDim, qHeadDim}); + std::string rotated = genVarName(prefix + "_" + tag + "_rot"); + matmul(x, rName, rotated, {-1, nh, seq, qHeadDim}, false, false); + std::string xc = genVarName(prefix + "_" + tag + "_xc"); + binary("mul", x, cosName, xc, {-1, nh, seq, qHeadDim}); + std::string rs = genVarName(prefix + "_" + tag + "_rs"); + binary("mul", rotated, sinName, rs, {-1, nh, seq, qHeadDim}); + std::string out = genVarName(prefix + "_" + tag + "_rope"); + binary("add", xc, rs, out, {-1, nh, seq, qHeadDim}); + return out; + }; + qh = applyRope(qh, numHeads, "q"); + kh = applyRope(kh, numKVHeads, "k"); + } + + std::string scores = genVarName(prefix + "_scores"); + matmul(qh, kh, scores, {-1, numHeads, seq, seq}, false, true); + std::string scaleName = prefix + "_scale"; + addFloatScalarConstOp(block, scaleName, 1.0f / std::sqrt(static_cast(qHeadDim))); + std::string scaled = genVarName(prefix + "_sc"); + binary("mul", scores, scaleName, scaled, {-1, numHeads, seq, seq}); + + // mask [B,1,H,W] -> [B,1,1,seq] directly (contiguous reshape; H,W already trailing so the + // row-major flatten gives seq index xy=y*W+x). No transpose -> avoids the reshape-after- + // transpose issue, and is also correct for non-full boards. + std::string maskSeq = genVarName(prefix + "_mseq"); + reshape(mask, maskSeq, {-1, 1, 1, seq}, {-1, 1, 1, seq}); + std::string oneName = prefix + "_one"; + addFloatScalarConstOp(block, oneName, 1.0f); + std::string mm1 = genVarName(prefix + "_mm1"); + binary("sub", maskSeq, oneName, mm1, {-1, 1, 1, seq}); + // Use an FP16-safe magnitude: 1e9 overflows FP16 to +inf, and for valid keys + // (maskSeq-1 == 0) the product 0 * inf becomes NaN, poisoning the whole softmax. + // 1e4 is well within FP16 range and exp(score - 1e4) still underflows to 0. + std::string bigName = prefix + "_big"; + addFloatScalarConstOp(block, bigName, 1.0e4f); + std::string keyBias = genVarName(prefix + "_kb"); + binary("mul", mm1, bigName, keyBias, {-1, 1, 1, seq}); + std::string scoresMasked = genVarName(prefix + "_scm"); + binary("add", scaled, keyBias, scoresMasked, {-1, numHeads, seq, seq}); + + std::string attn = genVarName(prefix + "_attn"); + { + std::string axisName = attn + "_axis"; + addIntScalarConstOp(block, axisName, 3); + auto* op = block->add_operations(); + op->set_type("softmax"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(scoresMasked); + (*op->mutable_inputs())["axis"].add_arguments()->set_name(axisName); + setShape(op, attn, {-1, numHeads, seq, seq}); + } + + std::string attnOut = genVarName(prefix + "_ao"); + matmul(attn, vh, attnOut, {-1, numHeads, seq, vHeadDim}, false, false); + + // Output projection, done per-head to avoid reshape-after-transpose: CoreML's reshape + // ignores an immediately-preceding transpose, so merging [head,dim]->channels after a + // transpose scrambles the data. Instead slice each head from attnOut (head is the + // contiguous axis 1), reshape (leading-merge only), matmul its weight slice, and sum. + // out[b,s,c] = sum_h sum_d attnOut[b,h,s,d] * outProj.weights[(h*vHeadDim+d)*outC + c] + const int outC = desc.out_proj.out_channels; + std::string proj2d; + for (int h = 0; h < numHeads; h++) { + std::string aoh = genVarName(prefix + "_aoh"); + { + std::string beginName = aoh + "_begin", sizeName = aoh + "_size"; + addIntArrayConstOp(block, beginName, {0, h, 0, 0}); + addIntArrayConstOp(block, sizeName, {-1, 1, seq, vHeadDim}); + auto* op = block->add_operations(); + op->set_type("slice_by_size"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(attnOut); + (*op->mutable_inputs())["begin"].add_arguments()->set_name(beginName); + (*op->mutable_inputs())["size"].add_arguments()->set_name(sizeName); + setShape(op, aoh, {-1, 1, seq, vHeadDim}); + } + std::string aoh2d = genVarName(prefix + "_aoh2d"); + reshape(aoh, aoh2d, {-1, vHeadDim}, {-1, vHeadDim}); // [B*seq, vHeadDim] + std::string wh = prefix + "_ow" + std::to_string(h); + std::vector whData(static_cast(vHeadDim) * outC); + for (int d = 0; d < vHeadDim; d++) + for (int c = 0; c < outC; c++) + whData[d * outC + c] = desc.out_proj.weights[static_cast(h * vHeadDim + d) * outC + c]; + addConstOp(block, wh, whData, {vHeadDim, outC}); + std::string contrib = genVarName(prefix + "_contrib"); + matmul(aoh2d, wh, contrib, {-1, outC}, false, false); + if (h == 0) { + proj2d = contrib; + } else { + std::string acc = genVarName(prefix + "_acc"); + binary("add", proj2d, contrib, acc, {-1, outC}); + proj2d = acc; + } + } + std::string projNHWC = genVarName(prefix + "_pnhwc"); + reshape(proj2d, projNHWC, {-1, H, W, C}, {-1, H, W, C}); + std::string projNCHW = genVarName(prefix + "_pnchw"); + transpose(projNHWC, projNCHW, {0, 3, 1, 2}, {-1, C, H, W}); + std::string maskedOut = genVarName(prefix + "_masked"); + binary("mul", projNCHW, mask, maskedOut, {-1, C, H, W}); + std::string out = genVarName(prefix + "_out"); + binary("add", input, maskedOut, out, {-1, C, H, W}); + return out; +} + +std::string MILBuilder::buildTransformerFFNBlock(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const TransformerFFNBlockDesc& desc, + const std::string& mask, + const std::string& prefix) { + const int C = desc.num_channels; + const int ffn = desc.ffn_channels; + const int H = m_board_y_size, W = m_board_x_size; + + if (!desc.use_swiglu) { + throw std::runtime_error(desc.name + ": non-SwiGLU transformer FFN not supported in CoreML backend"); + } + + auto reshape = [&](const std::string& in, const std::string& out, const std::vector& shapeVals, + const std::vector& dims) { + std::string shapeName = out + "_shape"; + addIntArrayConstOp(block, shapeName, shapeVals); + auto* op = block->add_operations(); + op->set_type("reshape"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(in); + (*op->mutable_inputs())["shape"].add_arguments()->set_name(shapeName); + setShape(op, out, dims); + }; + auto transpose = [&](const std::string& in, const std::string& out, const std::vector& perm, + const std::vector& dims) { + std::string permName = out + "_perm"; + addIntArrayConstOp(block, permName, perm); + auto* op = block->add_operations(); + op->set_type("transpose"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(in); + (*op->mutable_inputs())["perm"].add_arguments()->set_name(permName); + setShape(op, out, dims); + }; + auto matmul = [&](const std::string& x, const std::string& y, const std::string& out, + const std::vector& dims) { + std::string txName = out + "_tx", tyName = out + "_ty"; + addBoolScalarConstOp(block, txName, false); + addBoolScalarConstOp(block, tyName, false); + auto* op = block->add_operations(); + op->set_type("matmul"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(x); + (*op->mutable_inputs())["y"].add_arguments()->set_name(y); + (*op->mutable_inputs())["transpose_x"].add_arguments()->set_name(txName); + (*op->mutable_inputs())["transpose_y"].add_arguments()->set_name(tyName); + setShape(op, out, dims); + }; + auto binary = [&](const std::string& type, const std::string& x, const std::string& y, + const std::string& out, const std::vector& dims) { + auto* op = block->add_operations(); + op->set_type(type); + (*op->mutable_inputs())["x"].add_arguments()->set_name(x); + (*op->mutable_inputs())["y"].add_arguments()->set_name(y); + setShape(op, out, dims); + }; + + std::string normed = addTransformerRMSNorm(block, input, desc.pre_ln, mask, prefix + "_ln"); + std::string nhwc = genVarName(prefix + "_nhwc"); + transpose(normed, nhwc, {0, 2, 3, 1}, {-1, H, W, C}); + std::string x2d = genVarName(prefix + "_x2d"); + reshape(nhwc, x2d, {-1, C}, {-1, C}); + + std::string w1 = prefix + "_w1"; + addConstOp(block, w1, desc.linear1.weights, desc.linear1.getWeightShape()); + std::string a = genVarName(prefix + "_a"); + matmul(x2d, w1, a, {-1, ffn}); + std::string wg = prefix + "_wg"; + addConstOp(block, wg, desc.linear_gate.weights, desc.linear_gate.getWeightShape()); + std::string g = genVarName(prefix + "_g"); + matmul(x2d, wg, g, {-1, ffn}); + + std::string sig = genVarName(prefix + "_sig"); + { + auto* op = block->add_operations(); + op->set_type("sigmoid"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(a); + setShape(op, sig, {-1, ffn}); + } + std::string siluA = genVarName(prefix + "_silu"); + binary("mul", a, sig, siluA, {-1, ffn}); + std::string h = genVarName(prefix + "_h"); + binary("mul", siluA, g, h, {-1, ffn}); + + std::string w2 = prefix + "_w2"; + addConstOp(block, w2, desc.linear2.weights, desc.linear2.getWeightShape()); + std::string o = genVarName(prefix + "_o"); + matmul(h, w2, o, {-1, C}); + + std::string oNHWC = genVarName(prefix + "_onhwc"); + reshape(o, oNHWC, {-1, H, W, C}, {-1, H, W, C}); + std::string oNCHW = genVarName(prefix + "_onchw"); + transpose(oNHWC, oNCHW, {0, 3, 1, 2}, {-1, C, H, W}); + std::string maskedOut = genVarName(prefix + "_masked"); + binary("mul", oNCHW, mask, maskedOut, {-1, C, H, W}); + std::string out = genVarName(prefix + "_out"); + binary("add", input, maskedOut, out, {-1, C, H, W}); + return out; +} + std::string MILBuilder::buildTrunk(CoreML::Specification::MILSpec::Block* block, const std::string& spatial_input, const std::string& global_input, @@ -1747,12 +2321,23 @@ std::string MILBuilder::buildTrunk(CoreML::Specification::MILSpec::Block* block, } else if (entry.block_kind == NESTED_BOTTLENECK_BLOCK_KIND) { const auto& block_desc = std::get(*entry.block); x = buildNestedBottleneckBlock(block, x, block_desc, mask, prefix); + } else if (entry.block_kind == TRANSFORMER_ATTENTION_BLOCK_KIND) { + const auto& block_desc = std::get(*entry.block); + x = buildTransformerAttentionBlock(block, x, block_desc, mask, prefix); + } else if (entry.block_kind == TRANSFORMER_FFN_BLOCK_KIND) { + const auto& block_desc = std::get(*entry.block); + x = buildTransformerFFNBlock(block, x, block_desc, mask, prefix); } } // Trunk tip - std::string trunk_out = genVarName("trunk_tip"); - addBatchNormActivationOps(block, x, trunk.trunk_tip_bn, trunk.trunk_tip_activation, mask, trunk_out); + std::string trunk_out; + if (trunk.trunk_norm_kind == TRUNK_NORM_KIND_STANDARD) { + trunk_out = genVarName("trunk_tip"); + addBatchNormActivationOps(block, x, trunk.trunk_tip_bn, trunk.trunk_tip_activation, mask, trunk_out); + } else { + trunk_out = addTrunkRMSNorm(block, x, trunk.trunk_tip_rms_norm, trunk.trunk_tip_activation, mask, "trunk_tip_rms"); + } return trunk_out; } @@ -1898,6 +2483,12 @@ std::string MILBuilder::buildNestedBottleneckBlock(CoreML::Specification::MILSpe } else if (entry.block_kind == GLOBAL_POOLING_BLOCK_KIND) { const auto& nested = std::get(*entry.block); x = buildGlobalPoolingResidualBlock(block, x, nested, mask, nested_prefix); + } else if (entry.block_kind == TRANSFORMER_ATTENTION_BLOCK_KIND) { + const auto& nested = std::get(*entry.block); + x = buildTransformerAttentionBlock(block, x, nested, mask, nested_prefix); + } else if (entry.block_kind == TRANSFORMER_FFN_BLOCK_KIND) { + const auto& nested = std::get(*entry.block); + x = buildTransformerFFNBlock(block, x, nested, mask, nested_prefix); } } diff --git a/cpp/external/katagocoreml/src/builder/MILBuilder.hpp b/cpp/external/katagocoreml/src/builder/MILBuilder.hpp index 042f9fc16..5d25b963a 100644 --- a/cpp/external/katagocoreml/src/builder/MILBuilder.hpp +++ b/cpp/external/katagocoreml/src/builder/MILBuilder.hpp @@ -120,6 +120,44 @@ class MILBuilder { int rank, int channels); + void addSiluOps(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const std::string& output, + int rank, + int channels); + + // Generic output-shape setter: dims with -1 entries become unknown/dynamic dimensions. + void setShape(CoreML::Specification::MILSpec::Operation* op, + const std::string& name, + const std::vector& dims); + + // Lightweight transformer RMSNorm (weight only, per-position over channels). NCHW in/out. + std::string addTransformerRMSNorm(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const TransformerRMSNormDesc& desc, + const std::string& mask, + const std::string& prefix); + + // Full RMSNorm at trunk tip: gamma/beta, spatial or per-position, fused activation. NCHW in/out. + std::string addTrunkRMSNorm(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const RMSNormLayerDesc& desc, + const ActivationLayerDesc& act, + const std::string& mask, + const std::string& prefix); + + std::string buildTransformerAttentionBlock(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const TransformerAttentionBlockDesc& block_desc, + const std::string& mask, + const std::string& prefix); + + std::string buildTransformerFFNBlock(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const TransformerFFNBlockDesc& block_desc, + const std::string& mask, + const std::string& prefix); + void addGlobalPoolingOps(CoreML::Specification::MILSpec::Block* block, const std::string& input, const std::string& mask, diff --git a/cpp/external/katagocoreml/src/parser/KataGoParser.cpp b/cpp/external/katagocoreml/src/parser/KataGoParser.cpp index 2d06c27e5..5dcb80f5d 100644 --- a/cpp/external/katagocoreml/src/parser/KataGoParser.cpp +++ b/cpp/external/katagocoreml/src/parser/KataGoParser.cpp @@ -315,6 +315,8 @@ ActivationLayerDesc KataGoParser::parseActivationLayer(int model_version) { layer.activation_type = ActivationType::ReLU; } else if (activation_str == "ACTIVATION_MISH") { layer.activation_type = ActivationType::Mish; + } else if (activation_str == "ACTIVATION_SILU") { + layer.activation_type = ActivationType::Silu; } else { throw std::runtime_error("Unknown activation type: " + activation_str); } @@ -420,6 +422,98 @@ static void checkBlockChannels(const std::string& block_name, const std::string& } } +TransformerRMSNormDesc KataGoParser::parseTransformerRMSNorm() { + TransformerRMSNormDesc layer; + layer.name = readString(); + layer.num_channels = readInt(); + layer.epsilon = readFloat(); + if (layer.num_channels < 1) { + throw std::runtime_error(layer.name + ": transformer rmsnorm numChannels must be >= 1"); + } + layer.weight = readFloats(layer.num_channels, layer.name + "/weight"); + return layer; +} + +RMSNormLayerDesc KataGoParser::parseRMSNormLayer() { + RMSNormLayerDesc layer; + layer.name = readString(); + layer.num_channels = readInt(); + layer.epsilon = readFloat(); + layer.spatial = (readInt() != 0); + layer.cgroup_size = readInt(); + if (layer.num_channels < 1) { + throw std::runtime_error(layer.name + ": rmsnorm numChannels must be >= 1"); + } + if (layer.cgroup_size != 0) { + throw std::runtime_error(layer.name + ": grouped spatial RMSNorm is not supported"); + } + layer.gamma = readFloats(layer.num_channels, layer.name + "/gamma"); + layer.beta = readFloats(layer.num_channels, layer.name + "/beta"); + return layer; +} + +TransformerAttentionBlockDesc KataGoParser::parseTransformerAttentionBlock(int model_version) { + TransformerAttentionBlockDesc block; + block.name = readString(); + block.num_heads = readInt(); + block.num_kv_heads = readInt(); + block.q_head_dim = readInt(); + block.v_head_dim = readInt(); + block.use_rope = (readInt() != 0); + block.learnable_rope = (readInt() != 0); + + if (block.num_heads < 1 || block.num_kv_heads < 1 || (block.num_heads % block.num_kv_heads != 0)) { + throw std::runtime_error(block.name + ": invalid numHeads/numKVHeads"); + } + if (block.use_rope && (block.q_head_dim % 2 != 0)) { + throw std::runtime_error(block.name + ": qHeadDim must be even when RoPE is used"); + } + + block.pre_ln = parseTransformerRMSNorm(); + block.q_proj = parseMatMulLayer(); + block.k_proj = parseMatMulLayer(); + block.v_proj = parseMatMulLayer(); + block.out_proj = parseMatMulLayer(); + + if (block.use_rope) { + if (block.learnable_rope) { + readString(); // ropeFreqs name + block.rope_num_kv_heads = readInt(); + block.rope_num_pairs = readInt(); + int rope_dim2 = readInt(); + if (block.rope_num_kv_heads != block.num_kv_heads || + block.rope_num_pairs != block.q_head_dim / 2 || rope_dim2 != 2) { + throw std::runtime_error(block.name + ": invalid learnable rope header"); + } + block.rope_freqs = readFloats( + static_cast(block.rope_num_kv_heads) * block.rope_num_pairs * 2, + block.name + "/rope_freqs"); + } else { + readString(); // ropeTheta name + block.rope_theta = readFloat(); + } + } + return block; +} + +TransformerFFNBlockDesc KataGoParser::parseTransformerFFNBlock(int model_version) { + TransformerFFNBlockDesc block; + block.name = readString(); + block.num_channels = readInt(); + block.ffn_channels = readInt(); + block.use_swiglu = (readInt() != 0); + if (block.num_channels < 1 || block.ffn_channels < 1) { + throw std::runtime_error(block.name + ": transformer ffn channels must be positive"); + } + block.pre_ln = parseTransformerRMSNorm(); + block.linear1 = parseMatMulLayer(); + if (block.use_swiglu) { + block.linear_gate = parseMatMulLayer(); + } + block.linear2 = parseMatMulLayer(); + return block; +} + std::vector KataGoParser::parseBlockStack(int model_version, int num_blocks, int trunk_num_channels) { std::vector blocks; blocks.reserve(num_blocks); @@ -449,6 +543,14 @@ std::vector KataGoParser::parseBlockStack(int model_version, int num desc.pre_bn.num_channels, desc.post_conv.out_channels, trunk_num_channels); entry.block = std::make_shared(std::move(desc)); + } else if (block_kind_name == "transformer_attention_block") { + entry.block_kind = TRANSFORMER_ATTENTION_BLOCK_KIND; + auto desc = parseTransformerAttentionBlock(model_version); + entry.block = std::make_shared(std::move(desc)); + } else if (block_kind_name == "transformer_ffn_block") { + entry.block_kind = TRANSFORMER_FFN_BLOCK_KIND; + auto desc = parseTransformerFFNBlock(model_version); + entry.block = std::make_shared(std::move(desc)); } else { throw std::runtime_error("Unknown block kind: " + block_kind_name); } @@ -505,11 +607,17 @@ TrunkDesc KataGoParser::parseTrunk(int model_version, int meta_encoder_version) std::to_string(trunk.gpool_num_channels) + ")"); } - // Version >= 15 has 6 unused int parameters + // Version >= 15: first int is trunkNormKind, followed by 5 unused ints. if (model_version >= 15) { - for (int i = 0; i < 6; i++) { + trunk.trunk_norm_kind = readInt(); + for (int i = 0; i < 5; i++) { readInt(); } + if (trunk.trunk_norm_kind != TRUNK_NORM_KIND_STANDARD && + trunk.trunk_norm_kind != TRUNK_NORM_KIND_RMSNORM) { + throw std::runtime_error(trunk.name + ": unknown trunk norm kind: " + + std::to_string(trunk.trunk_norm_kind)); + } } trunk.initial_conv = parseConvLayer(); @@ -548,14 +656,24 @@ TrunkDesc KataGoParser::parseTrunk(int model_version, int meta_encoder_version) // Parse residual blocks trunk.blocks = parseBlockStack(model_version, trunk.num_blocks, trunk.trunk_num_channels); - trunk.trunk_tip_bn = parseBatchNormLayer(); - trunk.trunk_tip_activation = parseActivationLayer(model_version); - if (trunk.trunk_tip_bn.num_channels != trunk.trunk_num_channels) { - throw std::runtime_error(trunk.name + ": trunkTipBN.numChannels (" + - std::to_string(trunk.trunk_tip_bn.num_channels) + - ") != trunkNumChannels (" + - std::to_string(trunk.trunk_num_channels) + ")"); + if (trunk.trunk_norm_kind == TRUNK_NORM_KIND_STANDARD) { + trunk.trunk_tip_bn = parseBatchNormLayer(); + if (trunk.trunk_tip_bn.num_channels != trunk.trunk_num_channels) { + throw std::runtime_error(trunk.name + ": trunkTipBN.numChannels (" + + std::to_string(trunk.trunk_tip_bn.num_channels) + + ") != trunkNumChannels (" + + std::to_string(trunk.trunk_num_channels) + ")"); + } + } else { + trunk.trunk_tip_rms_norm = parseRMSNormLayer(); + if (trunk.trunk_tip_rms_norm.num_channels != trunk.trunk_num_channels) { + throw std::runtime_error(trunk.name + ": trunkTipRMSNorm.numChannels (" + + std::to_string(trunk.trunk_tip_rms_norm.num_channels) + + ") != trunkNumChannels (" + + std::to_string(trunk.trunk_num_channels) + ")"); + } } + trunk.trunk_tip_activation = parseActivationLayer(model_version); return trunk; } diff --git a/cpp/external/katagocoreml/src/parser/KataGoParser.hpp b/cpp/external/katagocoreml/src/parser/KataGoParser.hpp index cbcfdefa8..9c9935b57 100644 --- a/cpp/external/katagocoreml/src/parser/KataGoParser.hpp +++ b/cpp/external/katagocoreml/src/parser/KataGoParser.hpp @@ -50,11 +50,15 @@ class KataGoParser { ActivationLayerDesc parseActivationLayer(int model_version); MatMulLayerDesc parseMatMulLayer(); MatBiasLayerDesc parseMatBiasLayer(); + TransformerRMSNormDesc parseTransformerRMSNorm(); + RMSNormLayerDesc parseRMSNormLayer(); // Block parsing functions ResidualBlockDesc parseResidualBlock(int model_version); GlobalPoolingResidualBlockDesc parseGlobalPoolingResidualBlock(int model_version); NestedBottleneckResidualBlockDesc parseNestedBottleneckBlock(int model_version, int trunk_num_channels); + TransformerAttentionBlockDesc parseTransformerAttentionBlock(int model_version); + TransformerFFNBlockDesc parseTransformerFFNBlock(int model_version); std::vector parseBlockStack(int model_version, int num_blocks, int trunk_num_channels); // Component parsing functions diff --git a/cpp/external/katagocoreml/src/types/KataGoTypes.hpp b/cpp/external/katagocoreml/src/types/KataGoTypes.hpp index 284b26cd3..1e0ae5f12 100644 --- a/cpp/external/katagocoreml/src/types/KataGoTypes.hpp +++ b/cpp/external/katagocoreml/src/types/KataGoTypes.hpp @@ -19,10 +19,15 @@ namespace katagocoreml { enum class ActivationType : int { Identity = 0, ReLU = 1, - Mish = 2 + Mish = 2, + Silu = 3 // MISH_SCALE8 = 12 is internal optimization, treated as Mish }; +/// Trunk normalization kind (matching KataGo's desc.h) +constexpr int TRUNK_NORM_KIND_STANDARD = 0; +constexpr int TRUNK_NORM_KIND_RMSNORM = 1; + // ============================================================================ // Block Kind Constants // ============================================================================ @@ -31,6 +36,8 @@ enum class ActivationType : int { constexpr int ORDINARY_BLOCK_KIND = 0; constexpr int GLOBAL_POOLING_BLOCK_KIND = 2; constexpr int NESTED_BOTTLENECK_BLOCK_KIND = 3; +constexpr int TRANSFORMER_ATTENTION_BLOCK_KIND = 4; +constexpr int TRANSFORMER_FFN_BLOCK_KIND = 5; // ============================================================================ // Layer Descriptors @@ -98,6 +105,25 @@ struct MatBiasLayerDesc { std::vector weights; // Shape: [num_channels] }; +/// Lightweight RMSNorm used inside transformer blocks (weight only, no bias). +struct TransformerRMSNormDesc { + std::string name; + int num_channels = 0; + float epsilon = 1e-6f; + std::vector weight; // Shape: [num_channels] +}; + +/// Full-featured RMSNorm (gamma/beta, spatial mode) used at the trunk tip. +struct RMSNormLayerDesc { + std::string name; + int num_channels = 0; + float epsilon = 1e-6f; + bool spatial = false; + int cgroup_size = 0; + std::vector gamma; // Shape: [num_channels] + std::vector beta; // Shape: [num_channels] +}; + // ============================================================================ // Block Descriptors // ============================================================================ @@ -106,12 +132,16 @@ struct MatBiasLayerDesc { struct ResidualBlockDesc; struct GlobalPoolingResidualBlockDesc; struct NestedBottleneckResidualBlockDesc; +struct TransformerAttentionBlockDesc; +struct TransformerFFNBlockDesc; /// Block descriptor variant using BlockDesc = std::variant< ResidualBlockDesc, GlobalPoolingResidualBlockDesc, - NestedBottleneckResidualBlockDesc + NestedBottleneckResidualBlockDesc, + TransformerAttentionBlockDesc, + TransformerFFNBlockDesc >; /// Block with its kind @@ -165,6 +195,38 @@ struct NestedBottleneckResidualBlockDesc { ConvLayerDesc post_conv; }; +/// Transformer self-attention block descriptor (pre-norm, multi-head, optional 2D RoPE, GQA). +struct TransformerAttentionBlockDesc { + std::string name; + int num_heads = 0; + int num_kv_heads = 0; + int q_head_dim = 0; + int v_head_dim = 0; + bool use_rope = false; + bool learnable_rope = false; + TransformerRMSNormDesc pre_ln; + MatMulLayerDesc q_proj; + MatMulLayerDesc k_proj; + MatMulLayerDesc v_proj; + MatMulLayerDesc out_proj; + int rope_num_kv_heads = 0; + int rope_num_pairs = 0; + std::vector rope_freqs; // learnable: (num_kv_heads, num_pairs, 2) flattened + float rope_theta = 0.0f; +}; + +/// Transformer feed-forward (SwiGLU) block descriptor. +struct TransformerFFNBlockDesc { + std::string name; + int num_channels = 0; + int ffn_channels = 0; + bool use_swiglu = false; + TransformerRMSNormDesc pre_ln; + MatMulLayerDesc linear1; + MatMulLayerDesc linear_gate; // only used when use_swiglu + MatMulLayerDesc linear2; +}; + // ============================================================================ // SGF Metadata Encoder (v15+) // ============================================================================ @@ -202,7 +264,9 @@ struct TrunkDesc { MatMulLayerDesc initial_matmul; std::optional sgf_metadata_encoder; std::vector blocks; + int trunk_norm_kind = TRUNK_NORM_KIND_STANDARD; BatchNormLayerDesc trunk_tip_bn; + RMSNormLayerDesc trunk_tip_rms_norm; ActivationLayerDesc trunk_tip_activation; }; diff --git a/cpp/neuralnet/metalbackend.cpp b/cpp/neuralnet/metalbackend.cpp index 77a2d45c9..6a1cecabc 100644 --- a/cpp/neuralnet/metalbackend.cpp +++ b/cpp/neuralnet/metalbackend.cpp @@ -130,6 +130,8 @@ ActivationKind activationLayerDescToSwift(const ActivationLayerDesc* desc) { return ActivationKind::mish(); case ACTIVATION_MISH_SCALE8: return ActivationKind::identity(); // Metal/CoreML does not use scaled mish + case ACTIVATION_SILU: + return ActivationKind::silu(); case ACTIVATION_IDENTITY: return ActivationKind::identity(); default: @@ -217,6 +219,58 @@ SWNestedBottleneckResidualBlockDesc nestedBottleneckResidualBlockDescToSwift(con postConv); } +/// Convert a transformer RMSNorm description from C++ to Swift +SWTransformerRMSNormDesc transformerRMSNormDescToSwift(const TransformerRMSNormDesc* desc) { + return createSWTransformerRMSNormDesc( + desc->numChannels, + desc->epsilon, + (float*)desc->weight.data()); +} + +/// Convert a transformer attention block description from C++ to Swift +SWTransformerAttentionBlockDesc transformerAttentionBlockDescToSwift(const TransformerAttentionDesc* desc) { + SWTransformerRMSNormDesc preLN = transformerRMSNormDescToSwift(&desc->preLN); + SWMatMulLayerDesc qProj = matMulLayerDescToSwift(&desc->qProj); + SWMatMulLayerDesc kProj = matMulLayerDescToSwift(&desc->kProj); + SWMatMulLayerDesc vProj = matMulLayerDescToSwift(&desc->vProj); + SWMatMulLayerDesc outProj = matMulLayerDescToSwift(&desc->outProj); + float* ropeFreqs = desc->ropeFreqs.empty() ? nullptr : (float*)desc->ropeFreqs.data(); + + return createSWTransformerAttentionBlockDesc( + desc->numHeads, + desc->numKVHeads, + desc->qHeadDim, + desc->vHeadDim, + desc->useRope, + desc->learnableRope, + preLN, + qProj, + kProj, + vProj, + outProj, + desc->ropeNumKVHeads, + desc->ropeNumPairs, + ropeFreqs, + desc->ropeTheta); +} + +/// Convert a transformer FFN block description from C++ to Swift +SWTransformerFFNBlockDesc transformerFFNBlockDescToSwift(const TransformerFFNDesc* desc) { + SWTransformerRMSNormDesc preLN = transformerRMSNormDescToSwift(&desc->preLN); + SWMatMulLayerDesc linear1 = matMulLayerDescToSwift(&desc->linear1); + SWMatMulLayerDesc linearGate = matMulLayerDescToSwift(&desc->linearGate); + SWMatMulLayerDesc linear2 = matMulLayerDescToSwift(&desc->linear2); + + return createSWTransformerFFNBlockDesc( + desc->numChannels, + desc->ffnChannels, + desc->useSwiGLU, + preLN, + linear1, + linearGate, + linear2); +} + /// Convert residual blocks from C++ to Swift swift::Array residualBlocksToSwift(const vector>& blocks) { auto builder = createBlockDescriptorBuilder(); @@ -230,9 +284,12 @@ swift::Array residualBlocksToSwift(const vector sGFMetadataEncoderDescToSwift(const SG } /// Convert a trunk description from C++ to Swift +SWRMSNormLayerDesc rmsNormLayerDescToSwift(const RMSNormLayerDesc* desc) { + float* gamma = desc->gamma.empty() ? nullptr : (float*)desc->gamma.data(); + float* beta = desc->beta.empty() ? nullptr : (float*)desc->beta.data(); + return createSWRMSNormLayerDesc( + desc->numChannels, + desc->epsilon, + desc->spatial, + gamma, + beta); +} + SWTrunkDesc trunkDescToSwift(const TrunkDesc* trunk) { SWConvLayerDesc initialConv = convLayerDescToSwift(&trunk->initialConv); SWMatMulLayerDesc initialMatMul = matMulLayerDescToSwift(&trunk->initialMatMul); auto sgfMetadataEncoder = sGFMetadataEncoderDescToSwift(&trunk->sgfMetadataEncoder); auto swBlocks = residualBlocksToSwift(trunk->blocks); - if(trunk->trunkNormKind != TRUNK_NORM_KIND_STANDARD) - throw StringError("Trunk RMSNorm is not yet supported by the Metal backend"); SWBatchNormLayerDesc trunkTipBN = batchNormLayerDescToSwift(&trunk->trunkTipBN); + SWRMSNormLayerDesc trunkTipRMSNorm = rmsNormLayerDescToSwift(&trunk->trunkTipRMSNorm); ActivationKind trunkTipActivation = activationLayerDescToSwift(&trunk->trunkTipActivation); return createSWTrunkDesc( @@ -285,7 +352,9 @@ SWTrunkDesc trunkDescToSwift(const TrunkDesc* trunk) { initialMatMul, sgfMetadataEncoder, swBlocks, + trunk->trunkNormKind, trunkTipBN, + trunkTipRMSNorm, trunkTipActivation); } diff --git a/cpp/neuralnet/metallayers.swift b/cpp/neuralnet/metallayers.swift index bbd2255bc..f275c7925 100644 --- a/cpp/neuralnet/metallayers.swift +++ b/cpp/neuralnet/metallayers.swift @@ -76,6 +76,11 @@ extension MPSGraph { return mulTensor } + + /// SiLU / Swish activation: x * sigmoid(x). Numerically stable across FP16/FP32. + func silu(tensor: MPSGraphTensor) -> MPSGraphTensor { + return multiplication(tensor, sigmoid(with: tensor, name: nil), name: nil) + } } // MARK: - Input Shape Utilities @@ -358,6 +363,7 @@ public enum ActivationKind { case identity case relu case mish + case silu } /// A struct that represents a description of convolutional layer. @@ -487,6 +493,63 @@ public func createSWMatBiasLayerDesc( weights: weights) } +/// A lightweight RMSNorm description used inside transformer blocks (weight only, no bias). +public struct SWTransformerRMSNormDesc { + let numChannels: NSNumber + let epsilon: Float + let weight: UnsafeMutablePointer + + init(numChannels: NSNumber, epsilon: Float, weight: UnsafeMutablePointer) { + self.numChannels = numChannels + self.epsilon = epsilon + self.weight = weight + } +} + +public func createSWTransformerRMSNormDesc( + numChannels: Int32, + epsilon: Float, + weight: UnsafeMutablePointer +) -> SWTransformerRMSNormDesc { + return SWTransformerRMSNormDesc( + numChannels: numChannels as NSNumber, + epsilon: epsilon, + weight: weight) +} + +/// A full-featured RMSNorm description (gamma/beta, spatial mode), used at the trunk tip. +public struct SWRMSNormLayerDesc { + let numChannels: NSNumber + let epsilon: Float + let spatial: Bool + let gamma: UnsafeMutablePointer? + let beta: UnsafeMutablePointer? + + init(numChannels: NSNumber, epsilon: Float, spatial: Bool, + gamma: UnsafeMutablePointer?, beta: UnsafeMutablePointer?) { + self.numChannels = numChannels + self.epsilon = epsilon + self.spatial = spatial + self.gamma = gamma + self.beta = beta + } +} + +public func createSWRMSNormLayerDesc( + numChannels: Int32, + epsilon: Float, + spatial: Bool, + gamma: UnsafeMutablePointer?, + beta: UnsafeMutablePointer? +) -> SWRMSNormLayerDesc { + return SWRMSNormLayerDesc( + numChannels: numChannels as NSNumber, + epsilon: epsilon, + spatial: spatial, + gamma: gamma, + beta: beta) +} + // MARK: - Core Layers /// A class that represents a convolutional layer using MPSGraph @@ -612,6 +675,8 @@ struct ActivationLayer { resultTensor = graph.reLU(with: sourceTensor, name: nil) case .mish: resultTensor = graph.mish(tensor: sourceTensor) + case .silu: + resultTensor = graph.silu(tensor: sourceTensor) default: resultTensor = sourceTensor } @@ -987,6 +1052,140 @@ public func createSWNestedBottleneckResidualBlockDesc( postConv: postConv) } +public class SWTransformerAttentionBlockDesc: BlockDescriptor { + let numHeads: Int + let numKVHeads: Int + let qHeadDim: Int + let vHeadDim: Int + let useRope: Bool + let learnableRope: Bool + let preLN: SWTransformerRMSNormDesc + let qProj: SWMatMulLayerDesc + let kProj: SWMatMulLayerDesc + let vProj: SWMatMulLayerDesc + let outProj: SWMatMulLayerDesc + let ropeNumKVHeads: Int + let ropeNumPairs: Int + let ropeFreqs: UnsafeMutablePointer? // learnable: (numKVHeads, numPairs, 2) flattened + let ropeTheta: Float + + init( + numHeads: Int, + numKVHeads: Int, + qHeadDim: Int, + vHeadDim: Int, + useRope: Bool, + learnableRope: Bool, + preLN: SWTransformerRMSNormDesc, + qProj: SWMatMulLayerDesc, + kProj: SWMatMulLayerDesc, + vProj: SWMatMulLayerDesc, + outProj: SWMatMulLayerDesc, + ropeNumKVHeads: Int, + ropeNumPairs: Int, + ropeFreqs: UnsafeMutablePointer?, + ropeTheta: Float + ) { + self.numHeads = numHeads + self.numKVHeads = numKVHeads + self.qHeadDim = qHeadDim + self.vHeadDim = vHeadDim + self.useRope = useRope + self.learnableRope = learnableRope + self.preLN = preLN + self.qProj = qProj + self.kProj = kProj + self.vProj = vProj + self.outProj = outProj + self.ropeNumKVHeads = ropeNumKVHeads + self.ropeNumPairs = ropeNumPairs + self.ropeFreqs = ropeFreqs + self.ropeTheta = ropeTheta + } +} + +public func createSWTransformerAttentionBlockDesc( + numHeads: Int32, + numKVHeads: Int32, + qHeadDim: Int32, + vHeadDim: Int32, + useRope: Bool, + learnableRope: Bool, + preLN: SWTransformerRMSNormDesc, + qProj: SWMatMulLayerDesc, + kProj: SWMatMulLayerDesc, + vProj: SWMatMulLayerDesc, + outProj: SWMatMulLayerDesc, + ropeNumKVHeads: Int32, + ropeNumPairs: Int32, + ropeFreqs: UnsafeMutablePointer?, + ropeTheta: Float +) -> SWTransformerAttentionBlockDesc { + return SWTransformerAttentionBlockDesc( + numHeads: Int(numHeads), + numKVHeads: Int(numKVHeads), + qHeadDim: Int(qHeadDim), + vHeadDim: Int(vHeadDim), + useRope: useRope, + learnableRope: learnableRope, + preLN: preLN, + qProj: qProj, + kProj: kProj, + vProj: vProj, + outProj: outProj, + ropeNumKVHeads: Int(ropeNumKVHeads), + ropeNumPairs: Int(ropeNumPairs), + ropeFreqs: ropeFreqs, + ropeTheta: ropeTheta) +} + +public class SWTransformerFFNBlockDesc: BlockDescriptor { + let numChannels: Int + let ffnChannels: Int + let useSwiGLU: Bool + let preLN: SWTransformerRMSNormDesc + let linear1: SWMatMulLayerDesc + let linearGate: SWMatMulLayerDesc + let linear2: SWMatMulLayerDesc + + init( + numChannels: Int, + ffnChannels: Int, + useSwiGLU: Bool, + preLN: SWTransformerRMSNormDesc, + linear1: SWMatMulLayerDesc, + linearGate: SWMatMulLayerDesc, + linear2: SWMatMulLayerDesc + ) { + self.numChannels = numChannels + self.ffnChannels = ffnChannels + self.useSwiGLU = useSwiGLU + self.preLN = preLN + self.linear1 = linear1 + self.linearGate = linearGate + self.linear2 = linear2 + } +} + +public func createSWTransformerFFNBlockDesc( + numChannels: Int32, + ffnChannels: Int32, + useSwiGLU: Bool, + preLN: SWTransformerRMSNormDesc, + linear1: SWMatMulLayerDesc, + linearGate: SWMatMulLayerDesc, + linear2: SWMatMulLayerDesc +) -> SWTransformerFFNBlockDesc { + return SWTransformerFFNBlockDesc( + numChannels: Int(numChannels), + ffnChannels: Int(ffnChannels), + useSwiGLU: useSwiGLU, + preLN: preLN, + linear1: linear1, + linearGate: linearGate, + linear2: linear2) +} + public class BlockDescriptorBuilder { public var blockDescriptors: [BlockDescriptor] = [] @@ -1001,6 +1200,316 @@ public func createBlockDescriptorBuilder() -> BlockDescriptorBuilder { return BlockDescriptorBuilder() } +// MARK: - Transformer Layers + +/// Lightweight RMSNorm used inside transformer blocks (weight only, no bias). +/// Input/output are NCHW [B, C, H, W]. Normalizes across channels per spatial position, +/// scales by per-channel weight, and masks the output. +struct TransformerRMSNormLayer { + let resultTensor: MPSGraphTensor + + init( + graph: MPSGraph, + sourceTensor: MPSGraphTensor, + maskTensor: MPSGraphTensor, + descriptor: SWTransformerRMSNormDesc + ) { + let numChannels = descriptor.numChannels + let dataType = sourceTensor.dataType + + // meanSq over channel axis (1): [B,1,H,W] + let sq = graph.square(with: sourceTensor, name: nil) + let sumSq = graph.reductionSum(with: sq, axis: 1, name: nil) + let invC = graph.constant(1.0 / numChannels.doubleValue, dataType: dataType) + let meanSq = graph.multiplication(sumSq, invC, name: nil) + let epsTensor = graph.constant(Double(descriptor.epsilon), dataType: dataType) + let denom = graph.squareRoot(with: graph.addition(meanSq, epsTensor, name: nil), name: nil) + let normalized = graph.division(sourceTensor, denom, name: nil) + + // scale by per-channel weight [1, C, 1, 1] + let weightShape: [NSNumber] = [1, numChannels, 1, 1] + let weightData = Data(floatsNoCopy: descriptor.weight, shape: weightShape) + let weightTensor = graph.constant(weightData, shape: weightShape, dataType: dataType) + let scaled = graph.multiplication(normalized, weightTensor, name: nil) + + resultTensor = graph.multiplication(scaled, maskTensor, name: nil) + } +} + +/// Full-featured RMSNorm for the trunk tip: gamma/beta, spatial or per-position mode, and a +/// fused activation. Input/output are NCHW [B, C, H, W]. Mirrors the Eigen RMSNormLayer. +struct TrunkRMSNormLayer { + let resultTensor: MPSGraphTensor + + init( + graph: MPSGraph, + sourceTensor: MPSGraphTensor, + maskTensor: MPSGraphTensor, + descriptor: SWRMSNormLayerDesc, + activationKind: ActivationKind + ) { + let dataType = sourceTensor.dataType + let numChannels = descriptor.numChannels + + // Zero invalid positions before accumulating sum of squares. + let masked = graph.multiplication(sourceTensor, maskTensor, name: nil) + let sq = graph.square(with: masked, name: nil) + + let meanSq: MPSGraphTensor + if descriptor.spatial { + // Normalize over channels AND valid spatial positions per batch element. + let sumSq = graph.reductionSum(with: sq, axes: [1, 2, 3], name: nil) // [B,1,1,1] + let count = graph.reductionSum(with: maskTensor, axes: [1, 2, 3], name: nil) // valid positions + let cTensor = graph.constant(numChannels.doubleValue, dataType: dataType) + let totalElts = graph.multiplication(count, cTensor, name: nil) + meanSq = graph.division(sumSq, totalElts, name: nil) + } else { + // Per-position normalization across channels. + let sumSq = graph.reductionSum(with: sq, axes: [1], name: nil) // [B,1,H,W] + let invC = graph.constant(1.0 / numChannels.doubleValue, dataType: dataType) + meanSq = graph.multiplication(sumSq, invC, name: nil) + } + + let epsTensor = graph.constant(Double(descriptor.epsilon), dataType: dataType) + let denom = graph.squareRoot(with: graph.addition(meanSq, epsTensor, name: nil), name: nil) + let normalized = graph.division(sourceTensor, denom, name: nil) + + let gammaShape: [NSNumber] = [1, numChannels, 1, 1] + let gammaTensor = graph.constant(Data(floatsNoCopy: descriptor.gamma!, shape: gammaShape), shape: gammaShape, dataType: dataType) + let betaTensor = graph.constant(Data(floatsNoCopy: descriptor.beta!, shape: gammaShape), shape: gammaShape, dataType: dataType) + let scaled = graph.addition(graph.multiplication(normalized, gammaTensor, name: nil), betaTensor, name: nil) + + let activated = ActivationLayer(graph: graph, sourceTensor: scaled, activationKind: activationKind).resultTensor + resultTensor = graph.multiplication(activated, maskTensor, name: nil) + } +} + +/// A transformer self-attention block (pre-norm, multi-head, optional 2D RoPE, GQA). +/// Mirrors the Eigen reference: RMSNorm -> Q/K/V projections -> RoPE -> scaled dot-product +/// attention with masked softmax -> output projection -> masked residual. +/// Tensors are NCHW [B, C, H, W]; spatial positions (H*W, ordered y*W+x) are the sequence. +struct TransformerAttentionBlock { + let resultTensor: MPSGraphTensor + + init( + graph: MPSGraph, + sourceTensor: MPSGraphTensor, + maskTensor: MPSGraphTensor, + descriptor: SWTransformerAttentionBlockDesc, + nnXLen: NSNumber, + nnYLen: NSNumber + ) { + let dataType = sourceTensor.dataType + let numHeads = descriptor.numHeads + let numKVHeads = descriptor.numKVHeads + let qHeadDim = descriptor.qHeadDim + let vHeadDim = descriptor.vHeadDim + let nnX = nnXLen.intValue + let nnY = nnYLen.intValue + let seq = nnX * nnY + + // 1. RMSNorm (NCHW) + let normed = TransformerRMSNormLayer( + graph: graph, + sourceTensor: sourceTensor, + maskTensor: maskTensor, + descriptor: descriptor.preLN).resultTensor + + // To NHWC [B,H,W,C] so that reshape [-1, C] groups channels per position. + let normedNHWC = graph.transpose(normed, permutation: [0, 2, 3, 1], name: nil) + + // 2. Q/K/V projections via matmul over channels -> [B*seq, heads*dim] + let q = MatMulLayer(graph: graph, descriptor: descriptor.qProj, sourceTensor: normedNHWC).resultTensor + let k = MatMulLayer(graph: graph, descriptor: descriptor.kProj, sourceTensor: normedNHWC).resultTensor + let v = MatMulLayer(graph: graph, descriptor: descriptor.vProj, sourceTensor: normedNHWC).resultTensor + + // 3. reshape to [B, heads, seq, dim] + var qh = TransformerAttentionBlock.toHeads(graph, q, seq: seq, numHeads: numHeads, headDim: qHeadDim) + var kh = TransformerAttentionBlock.toHeads(graph, k, seq: seq, numHeads: numKVHeads, headDim: qHeadDim) + let vh = TransformerAttentionBlock.toHeads(graph, v, seq: seq, numHeads: numKVHeads, headDim: vHeadDim) + + // 4. RoPE on Q and K + if descriptor.useRope { + let numPairs = qHeadDim / 2 + // Q heads map to KV heads via kvh = h * numKVHeads / numHeads (matches Eigen). + let (qCos, qSin) = TransformerAttentionBlock.makeRopeTables( + graph, descriptor: descriptor, nHeads: numHeads, seq: seq, numPairs: numPairs, + nnX: nnX, nnY: nnY, qHeadDim: qHeadDim, dataType: dataType, + kvIndexForHead: { h in (h * numKVHeads) / numHeads }) + let (kCos, kSin) = TransformerAttentionBlock.makeRopeTables( + graph, descriptor: descriptor, nHeads: numKVHeads, seq: seq, numPairs: numPairs, + nnX: nnX, nnY: nnY, qHeadDim: qHeadDim, dataType: dataType, + kvIndexForHead: { h in h }) + qh = TransformerAttentionBlock.applyRope(graph, qh, cosT: qCos, sinT: qSin, + numHeads: numHeads, seq: seq, numPairs: numPairs) + kh = TransformerAttentionBlock.applyRope(graph, kh, cosT: kCos, sinT: kSin, + numHeads: numKVHeads, seq: seq, numPairs: numPairs) + } + + // GQA: if numKVHeads < numHeads, repeat KV heads so they align with query heads. + var khExp = kh + var vhExp = vh + if numKVHeads != numHeads { + let groupSize = numHeads / numKVHeads + khExp = TransformerAttentionBlock.repeatKVHeads(graph, kh, numKVHeads: numKVHeads, groupSize: groupSize, seq: seq, headDim: qHeadDim) + vhExp = TransformerAttentionBlock.repeatKVHeads(graph, vh, numKVHeads: numKVHeads, groupSize: groupSize, seq: seq, headDim: vHeadDim) + } + + // 5. scores = scale * Q @ K^T -> [B, heads, seq, seq] + let khT = graph.transpose(khExp, permutation: [0, 1, 3, 2], name: nil) + var scores = graph.matrixMultiplication(primary: qh, secondary: khT, name: nil) + let scale = graph.constant(1.0 / Double(qHeadDim).squareRoot(), dataType: dataType) + scores = graph.multiplication(scores, scale, name: nil) + + // Mask keys: add (maskKey - 1) * BIG so masked key columns get ~ -inf before softmax. + // maskTensor [B,1,H,W] -> [B,1,1,seq] + let maskNHWC = graph.transpose(maskTensor, permutation: [0, 2, 3, 1], name: nil) // [B,H,W,1] + let maskSeq = graph.reshape(maskNHWC, shape: [-1, 1, 1, seq as NSNumber], name: nil) + let one = graph.constant(1.0, dataType: dataType) + let big = graph.constant(1.0e9, dataType: dataType) + let keyBias = graph.multiplication(graph.subtraction(maskSeq, one, name: nil), big, name: nil) + scores = graph.addition(scores, keyBias, name: nil) + + // 6. softmax over key axis (last) + let attn = graph.softMax(with: scores, axis: 3, name: nil) + + // 7. out = attn @ V -> [B, heads, seq, vHeadDim] + let attnOut = graph.matrixMultiplication(primary: attn, secondary: vhExp, name: nil) + + // 8. back to [B*seq, heads*vHeadDim] + let outHeadsLast = graph.transpose(attnOut, permutation: [0, 2, 1, 3], name: nil) // [B,seq,heads,vHeadDim] + let outFlat = graph.reshape(outHeadsLast, shape: [-1, (numHeads * vHeadDim) as NSNumber], name: nil) + + // 9. output projection -> [B*seq, C] + let proj = MatMulLayer(graph: graph, descriptor: descriptor.outProj, sourceTensor: outFlat).resultTensor + + // 10. reshape to NHWC then NCHW + let outChannels = descriptor.outProj.outChannels + let projNHWC = graph.reshape(proj, shape: [-1, nnYLen, nnXLen, outChannels], name: nil) + let projNCHW = graph.transpose(projNHWC, permutation: [0, 3, 1, 2], name: nil) + + // 11. masked residual + let masked = graph.multiplication(projNCHW, maskTensor, name: nil) + resultTensor = graph.addition(sourceTensor, masked, name: nil) + } + + /// Reshape [B*seq, numHeads*headDim] -> [B, numHeads, seq, headDim]. + static func toHeads(_ graph: MPSGraph, _ x: MPSGraphTensor, seq: Int, numHeads: Int, headDim: Int) -> MPSGraphTensor { + let reshaped = graph.reshape(x, shape: [-1, seq as NSNumber, numHeads as NSNumber, headDim as NSNumber], name: nil) + return graph.transpose(reshaped, permutation: [0, 2, 1, 3], name: nil) + } + + /// Repeat each KV head groupSize times along the head axis: [B,numKVHeads,seq,dim] -> [B,numKVHeads*groupSize,seq,dim]. + static func repeatKVHeads(_ graph: MPSGraph, _ x: MPSGraphTensor, numKVHeads: Int, groupSize: Int, seq: Int, headDim: Int) -> MPSGraphTensor { + // Insert a group axis then broadcast: [B,kv,1,seq,dim] -> [B,kv,group,seq,dim] -> [B,kv*group,seq,dim] + let expanded = graph.reshape(x, shape: [-1, numKVHeads as NSNumber, 1, seq as NSNumber, headDim as NSNumber], name: nil) + let targetShape: [NSNumber] = [-1, numKVHeads as NSNumber, groupSize as NSNumber, seq as NSNumber, headDim as NSNumber] + let broadcast = graph.broadcast(expanded, shape: targetShape, name: nil) + return graph.reshape(broadcast, shape: [-1, (numKVHeads * groupSize) as NSNumber, seq as NSNumber, headDim as NSNumber], name: nil) + } + + /// Apply interleaved-pair RoPE to [B, nHeads, seq, headDim] using cos/sin tables [1,nHeads,seq,numPairs]. + static func applyRope(_ graph: MPSGraph, _ x: MPSGraphTensor, cosT: MPSGraphTensor, sinT: MPSGraphTensor, + numHeads: Int, seq: Int, numPairs: Int) -> MPSGraphTensor { + let pairsShape: [NSNumber] = [-1, numHeads as NSNumber, seq as NSNumber, numPairs as NSNumber, 2] + let xPairs = graph.reshape(x, shape: pairsShape, name: nil) + let evenShape: [NSNumber] = [-1, numHeads as NSNumber, seq as NSNumber, numPairs as NSNumber] + let xEven = graph.reshape(graph.sliceTensor(xPairs, dimension: 4, start: 0, length: 1, name: nil), shape: evenShape, name: nil) + let xOdd = graph.reshape(graph.sliceTensor(xPairs, dimension: 4, start: 1, length: 1, name: nil), shape: evenShape, name: nil) + let outEven = graph.subtraction(graph.multiplication(xEven, cosT, name: nil), graph.multiplication(xOdd, sinT, name: nil), name: nil) + let outOdd = graph.addition(graph.multiplication(xEven, sinT, name: nil), graph.multiplication(xOdd, cosT, name: nil), name: nil) + let pairShape5: [NSNumber] = [-1, numHeads as NSNumber, seq as NSNumber, numPairs as NSNumber, 1] + let outEvenE = graph.reshape(outEven, shape: pairShape5, name: nil) + let outOddE = graph.reshape(outOdd, shape: pairShape5, name: nil) + let stacked = graph.concatTensors([outEvenE, outOddE], dimension: 4, name: nil) + return graph.reshape(stacked, shape: [-1, numHeads as NSNumber, seq as NSNumber, (numPairs * 2) as NSNumber], name: nil) + } + + /// Build RoPE cos/sin constant tensors of shape [1, nHeads, seq, numPairs]. + static func makeRopeTables(_ graph: MPSGraph, descriptor: SWTransformerAttentionBlockDesc, + nHeads: Int, seq: Int, numPairs: Int, nnX: Int, nnY: Int, qHeadDim: Int, + dataType: MPSDataType, kvIndexForHead: (Int) -> Int) -> (MPSGraphTensor, MPSGraphTensor) { + let count = nHeads * seq * numPairs + let cosBuf = UnsafeMutablePointer.allocate(capacity: count) + let sinBuf = UnsafeMutablePointer.allocate(capacity: count) + let numPairsPerDim = numPairs / 2 + let dimHalf = qHeadDim / 2 + for h in 0.. SiLU(linear1)*gate -> linear2 -> masked residual. +struct TransformerFFNBlock { + let resultTensor: MPSGraphTensor + + init( + graph: MPSGraph, + sourceTensor: MPSGraphTensor, + maskTensor: MPSGraphTensor, + descriptor: SWTransformerFFNBlockDesc, + nnXLen: NSNumber, + nnYLen: NSNumber + ) { + let numChannels = descriptor.numChannels + + // 1. RMSNorm + let normed = TransformerRMSNormLayer( + graph: graph, + sourceTensor: sourceTensor, + maskTensor: maskTensor, + descriptor: descriptor.preLN).resultTensor + let normedNHWC = graph.transpose(normed, permutation: [0, 2, 3, 1], name: nil) + + // 2. linear1 + gate, both [B*seq, ffnChannels] + let a = MatMulLayer(graph: graph, descriptor: descriptor.linear1, sourceTensor: normedNHWC).resultTensor + let gate = MatMulLayer(graph: graph, descriptor: descriptor.linearGate, sourceTensor: normedNHWC).resultTensor + + // 3. SwiGLU: SiLU(a) * gate, SiLU(a) = a * sigmoid(a) + let siluA = graph.multiplication(a, graph.sigmoid(with: a, name: nil), name: nil) + let h = graph.multiplication(siluA, gate, name: nil) + + // 4. linear2 -> [B*seq, numChannels] + let out = MatMulLayer(graph: graph, descriptor: descriptor.linear2, sourceTensor: h).resultTensor + + // 5. reshape to NHWC then NCHW, masked residual + let outNHWC = graph.reshape(out, shape: [-1, nnYLen, nnXLen, numChannels as NSNumber], name: nil) + let outNCHW = graph.transpose(outNHWC, permutation: [0, 3, 1, 2], name: nil) + let masked = graph.multiplication(outNCHW, maskTensor, name: nil) + resultTensor = graph.addition(sourceTensor, masked, name: nil) + } +} + // MARK: - Block Implementations /// A class that represents a Residual Block layer @@ -1241,6 +1750,26 @@ struct BlockStack { optimizeIdentityMask: optimizeIdentityMask) blockInput = ordinary.resultTensor + case let attnDescriptor as SWTransformerAttentionBlockDesc: + let attn = TransformerAttentionBlock( + graph: graph, + sourceTensor: sourceTensor, + maskTensor: maskTensor, + descriptor: attnDescriptor, + nnXLen: nnXLen, + nnYLen: nnYLen) + + blockInput = attn.resultTensor + case let ffnDescriptor as SWTransformerFFNBlockDesc: + let ffn = TransformerFFNBlock( + graph: graph, + sourceTensor: sourceTensor, + maskTensor: maskTensor, + descriptor: ffnDescriptor, + nnXLen: nnXLen, + nnYLen: nnYLen) + + blockInput = ffn.resultTensor default: blockInput = sourceTensor } @@ -1483,7 +2012,9 @@ public class SWTrunkDesc { let initialMatMul: SWMatMulLayerDesc let sgfMetadataEncoder: SWSGFMetadataEncoderDesc? let blockDescriptors: [BlockDescriptor] + let trunkNormKind: Int let trunkTipBN: SWBatchNormLayerDesc + let trunkTipRMSNorm: SWRMSNormLayerDesc let trunkTipActivation: ActivationKind init( @@ -1496,7 +2027,9 @@ public class SWTrunkDesc { initialMatMul: SWMatMulLayerDesc, sgfMetadataEncoder: SWSGFMetadataEncoderDesc?, blockDescriptors: [BlockDescriptor], + trunkNormKind: Int, trunkTipBN: SWBatchNormLayerDesc, + trunkTipRMSNorm: SWRMSNormLayerDesc, trunkTipActivation: ActivationKind ) { self.version = version @@ -1508,7 +2041,9 @@ public class SWTrunkDesc { self.initialMatMul = initialMatMul self.sgfMetadataEncoder = sgfMetadataEncoder self.blockDescriptors = blockDescriptors + self.trunkNormKind = trunkNormKind self.trunkTipBN = trunkTipBN + self.trunkTipRMSNorm = trunkTipRMSNorm self.trunkTipActivation = trunkTipActivation } } @@ -1523,7 +2058,9 @@ public func createSWTrunkDesc( initialMatMul: SWMatMulLayerDesc, sgfMetadataEncoder: SWSGFMetadataEncoderDesc?, blockDescriptors: [BlockDescriptor], + trunkNormKind: Int32, trunkTipBN: SWBatchNormLayerDesc, + trunkTipRMSNorm: SWRMSNormLayerDesc, trunkTipActivation: ActivationKind ) -> SWTrunkDesc { return SWTrunkDesc( @@ -1536,7 +2073,9 @@ public func createSWTrunkDesc( initialMatMul: initialMatMul, sgfMetadataEncoder: sgfMetadataEncoder, blockDescriptors: blockDescriptors, + trunkNormKind: Int(trunkNormKind), trunkTipBN: trunkTipBN, + trunkTipRMSNorm: trunkTipRMSNorm, trunkTipActivation: trunkTipActivation) } @@ -1632,21 +2171,34 @@ struct Trunk { nnYLen: nnYLen, optimizeIdentityMask: optimizeIdentityMask) - let trunkTipBN = BatchNormLayer( - graph: graph, - sourceTensor: blocks.resultTensor, - maskTensor: maskTensor, - descriptor: descriptor.trunkTipBN, - nnXLen: nnXLen, - nnYLen: nnYLen, - optimizeIdentityMask: optimizeIdentityMask) + // TRUNK_NORM_KIND_RMSNORM == 1: trunk tip uses RMSNorm with a fused activation. + // Otherwise (standard): BatchNorm followed by a separate activation. + if descriptor.trunkNormKind == 1 { + let trunkTipRMSNorm = TrunkRMSNormLayer( + graph: graph, + sourceTensor: blocks.resultTensor, + maskTensor: maskTensor, + descriptor: descriptor.trunkTipRMSNorm, + activationKind: descriptor.trunkTipActivation) - let trunkTipActivation = ActivationLayer( - graph: graph, - sourceTensor: trunkTipBN.resultTensor, - activationKind: descriptor.trunkTipActivation) + resultTensor = trunkTipRMSNorm.resultTensor + } else { + let trunkTipBN = BatchNormLayer( + graph: graph, + sourceTensor: blocks.resultTensor, + maskTensor: maskTensor, + descriptor: descriptor.trunkTipBN, + nnXLen: nnXLen, + nnYLen: nnYLen, + optimizeIdentityMask: optimizeIdentityMask) + + let trunkTipActivation = ActivationLayer( + graph: graph, + sourceTensor: trunkTipBN.resultTensor, + activationKind: descriptor.trunkTipActivation) - resultTensor = trunkTipActivation.resultTensor + resultTensor = trunkTipActivation.resultTensor + } assert(resultTensor.shape?.count == 4) } From 6f8314b7f6634afc0356a2ad5a1a061a48e2aaa9 Mon Sep 17 00:00:00 2001 From: Chin-Chang Yang <2770271+ChinChangYang@users.noreply.github.com> Date: Mon, 1 Jun 2026 15:43:20 +0800 Subject: [PATCH 02/10] Fix Metal GPU crash on GQA transformer attention (NDArray INT_MAX) GQA models (numKVHeads != numHeads, e.g. b7c96h6kv3qk32v16tflrs) crashed on the Metal GPU path with: MPSNDArray.mm: NDArray dimension length > INT_MAX repeatKVHeads expanded the KV heads via reshape -> broadcast -> reshape, passing -1 for the batch dim in the broadcast target shape. Unlike reshape, MPSGraph.broadcast(_:shape:) does not infer -1 and treats it as a literal (near-INT_MAX) dimension, tripping the NDArray assertion. Replace the broadcast with a shape-safe slice + concat: slice each KV head (dim 1) and concatenate groupSize copies consecutively, so query head h uses kv = h / groupSize, matching the Eigen reference (kvh = h / kvGroupSize). No -1 broadcast. Verified: testgpuerror GPU vs Eigen reference at 9/13/19 now passes (~0.00003% winrate); non-GQA models (incl. b10c384h6) unaffected since the GQA branch is gated on numKVHeads != numHeads. Co-Authored-By: Claude Opus 4.8 (1M context) --- cpp/neuralnet/metallayers.swift | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/cpp/neuralnet/metallayers.swift b/cpp/neuralnet/metallayers.swift index f275c7925..dc4d22d53 100644 --- a/cpp/neuralnet/metallayers.swift +++ b/cpp/neuralnet/metallayers.swift @@ -1401,11 +1401,19 @@ struct TransformerAttentionBlock { /// Repeat each KV head groupSize times along the head axis: [B,numKVHeads,seq,dim] -> [B,numKVHeads*groupSize,seq,dim]. static func repeatKVHeads(_ graph: MPSGraph, _ x: MPSGraphTensor, numKVHeads: Int, groupSize: Int, seq: Int, headDim: Int) -> MPSGraphTensor { - // Insert a group axis then broadcast: [B,kv,1,seq,dim] -> [B,kv,group,seq,dim] -> [B,kv*group,seq,dim] - let expanded = graph.reshape(x, shape: [-1, numKVHeads as NSNumber, 1, seq as NSNumber, headDim as NSNumber], name: nil) - let targetShape: [NSNumber] = [-1, numKVHeads as NSNumber, groupSize as NSNumber, seq as NSNumber, headDim as NSNumber] - let broadcast = graph.broadcast(expanded, shape: targetShape, name: nil) - return graph.reshape(broadcast, shape: [-1, (numKVHeads * groupSize) as NSNumber, seq as NSNumber, headDim as NSNumber], name: nil) + // Repeat each KV head groupSize times consecutively so query head h uses kv = h / groupSize, + // matching the Eigen reference (kvh = h / kvGroupSize). We slice each KV head and concat the + // copies along the head axis. Note: MPSGraph.broadcast(_:shape:) does NOT infer -1, so a + // reshape+broadcast approach with a dynamic batch dim triggers an NDArray INT_MAX assertion; + // slice+concat is shape-safe with no -1 broadcast. + var heads: [MPSGraphTensor] = [] + heads.reserveCapacity(numKVHeads * groupSize) + for kv in 0.. Date: Mon, 1 Jun 2026 15:43:33 +0800 Subject: [PATCH 03/10] Fix dropped SiLU activation in CoreML value/policy/meta heads The MIL builder's inline activation dispatch (buildValueHead v2, policy-head pass activation, and both SGF metadata encoder layers) handled only ReLU and Mish; SiLU silently fell through to the else branch and applied NO activation at all. This corrupted the value-head pool -> v2 -> v3 scalar path for every SiLU model, producing large errors in winrate/score/lead while ownership (which branches off v1, before v2) stayed correct. Add an ActivationType::Silu branch (addSiluOps) at all four sites. The generic conv/BN activation path already handled SiLU, which is why the trunk and v1/ownership were fine. Root-caused via systematic debugging: CoreML-CPU(fp32) error was identical to ANE (-> logical bug, not fp16), and perfect ownership with wrong scalars localized it to the value-head post-pooling path. This corrects the earlier "ANE is fp16-precision-limited (~5%)" conclusion -- that 5.66% on b10c384h6 was this bug. After the fix, testgpuerror ANE vs Eigen drops to GPU-level accuracy for all models: b10c384h6 5.66% -> ~0.00005-0.0002% cnorm 11-13% -> ~0.00007% rsnh 22-29% -> ~0.00004-0.0001% Co-Authored-By: Claude Opus 4.8 (1M context) --- cpp/external/katagocoreml/src/builder/MILBuilder.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/cpp/external/katagocoreml/src/builder/MILBuilder.cpp b/cpp/external/katagocoreml/src/builder/MILBuilder.cpp index 2a0bbf44a..3305737cb 100644 --- a/cpp/external/katagocoreml/src/builder/MILBuilder.cpp +++ b/cpp/external/katagocoreml/src/builder/MILBuilder.cpp @@ -2593,6 +2593,8 @@ void MILBuilder::buildPolicyHead(CoreML::Specification::MILSpec::Block* block, setTensorOutput2D(op, pass_activated, ph.gpool_to_pass_mul.out_channels); } else if (ph.pass_activation->activation_type == ActivationType::Mish) { addMishOps(block, pass_biased, pass_activated, 2, ph.gpool_to_pass_mul.out_channels); + } else if (ph.pass_activation->activation_type == ActivationType::Silu) { + addSiluOps(block, pass_biased, pass_activated, 2, ph.gpool_to_pass_mul.out_channels); } else { pass_activated = pass_biased; } @@ -2640,6 +2642,8 @@ void MILBuilder::buildValueHead(CoreML::Specification::MILSpec::Block* block, setTensorOutput2D(op, v2, vh.v2_mul.out_channels); } else if (vh.v2_activation.activation_type == ActivationType::Mish) { addMishOps(block, v2_bias, v2, 2, vh.v2_mul.out_channels); + } else if (vh.v2_activation.activation_type == ActivationType::Silu) { + addSiluOps(block, v2_bias, v2, 2, vh.v2_mul.out_channels); } else { v2 = v2_bias; } @@ -2676,6 +2680,8 @@ std::string MILBuilder::buildSGFMetadataEncoder(CoreML::Specification::MILSpec:: setTensorOutput2D(op, act1, encoder.mul1.out_channels); } else if (encoder.act1.activation_type == ActivationType::Mish) { addMishOps(block, bias1, act1, 2, encoder.mul1.out_channels); + } else if (encoder.act1.activation_type == ActivationType::Silu) { + addSiluOps(block, bias1, act1, 2, encoder.mul1.out_channels); } else { // Identity activation - create identity op to preserve type information auto* op = block->add_operations(); @@ -2698,6 +2704,8 @@ std::string MILBuilder::buildSGFMetadataEncoder(CoreML::Specification::MILSpec:: setTensorOutput2D(op, act2, encoder.mul2.out_channels); } else if (encoder.act2.activation_type == ActivationType::Mish) { addMishOps(block, bias2, act2, 2, encoder.mul2.out_channels); + } else if (encoder.act2.activation_type == ActivationType::Silu) { + addSiluOps(block, bias2, act2, 2, encoder.mul2.out_channels); } else { // Identity activation - create identity op to preserve type information auto* op = block->add_operations(); From 792c4760628dbed29aa1d6d120494ff7643d7c00 Mon Sep 17 00:00:00 2001 From: Chin-Chang Yang <2770271+ChinChangYang@users.noreply.github.com> Date: Mon, 1 Jun 2026 16:23:17 +0800 Subject: [PATCH 04/10] Implement GQA support in CoreML/ANE MIL attention builder The CoreML MIL builder threw "GQA (numKVHeads != numHeads) not supported" for grouped-query-attention transformer models, while the Metal GPU (MPSGraph) path already handled GQA. Port that support to the MIL builder. In buildTransformerAttentionBlock, remove the throw guard and, after the RoPE block and before the scores matmul, repeat each KV head groupSize (= numHeads/numKVHeads) times along the head axis via slice_by_size + concat (interleave=false), so query head h consumes kv head h/groupSize. This matches the Eigen reference (kvh = h/kvGroupSize) and the GPU repeatKVHeads ordering. RoPE stays before the repeat (its cos/sin tables are numKVHeads-shaped). The block is gated by numKVHeads != numHeads, so the standard MHA path is unchanged. Verified on b7c96h6kv3qk32v16tflrs-fson-bnh (6 query / 3 KV heads, qk32/v16) vs Eigen reference: ANE testgpuerror 9/13/19 = 0.00002-0.00003% winrate (previously a hard throw); GPU unchanged; non-GQA model ANE error identical to pre-change; runtests and runnnlayertests pass. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../katagocoreml/src/builder/MILBuilder.cpp | 52 +++++++++++++++++-- 1 file changed, 48 insertions(+), 4 deletions(-) diff --git a/cpp/external/katagocoreml/src/builder/MILBuilder.cpp b/cpp/external/katagocoreml/src/builder/MILBuilder.cpp index 3305737cb..b90da55b6 100644 --- a/cpp/external/katagocoreml/src/builder/MILBuilder.cpp +++ b/cpp/external/katagocoreml/src/builder/MILBuilder.cpp @@ -1891,10 +1891,6 @@ std::string MILBuilder::buildTransformerAttentionBlock(CoreML::Specification::MI const int qHeadDim = desc.q_head_dim, vHeadDim = desc.v_head_dim; const int qTotal = numHeads * qHeadDim, kTotal = numKVHeads * qHeadDim, vTotal = numKVHeads * vHeadDim; - if (numKVHeads != numHeads) { - throw std::runtime_error(desc.name + ": GQA (numKVHeads != numHeads) not supported in CoreML backend"); - } - auto reshape = [&](const std::string& in, const std::string& out, const std::vector& shapeVals, const std::vector& dims) { std::string shapeName = out + "_shape"; @@ -2024,6 +2020,54 @@ std::string MILBuilder::buildTransformerAttentionBlock(CoreML::Specification::MI kh = applyRope(kh, numKVHeads, "k"); } + // GQA: when numKVHeads < numHeads, repeat each KV head groupSize times along the head + // axis (axis 1) so query head h consumes kv head (h / groupSize). RoPE has already been + // applied above to the unexpanded kh (kh = applyRope(kh, numKVHeads, "k")), mirroring the + // GPU path (metallayers.swift repeatKVHeads runs AFTER applyRope). We slice each KV head + // and concat its copies consecutively, so the resulting head index is kv*groupSize + g; + // query head h then maps to kv = h/groupSize == (h*numKVHeads)/numHeads (exact divisor, + // the same formula the qh RoPE table uses) == Eigen's kvh = h/kvGroupSize. slice_by_size + + // concat (not reshape+broadcast) avoids the dynamic -1 batch broadcast pitfall, same as the + // GPU code. The repeat is required so the scores (qh@kh^T) and attnOut (attn@vh) matmuls see + // matching [B,numHeads,...] batch dims instead of numHeads vs numKVHeads (no broadcast). + if (numKVHeads != numHeads) { + const int groupSize = numHeads / numKVHeads; + auto repeatKVHeads = [&](const std::string& x, const std::string& tag, int headDim) { + std::vector parts; + parts.reserve(static_cast(numKVHeads) * groupSize); + for (int kv = 0; kv < numKVHeads; kv++) { + for (int g = 0; g < groupSize; g++) { + std::string part = genVarName(prefix + "_" + tag + "_slc"); + std::string beginName = part + "_begin", sizeName = part + "_size"; + addIntArrayConstOp(block, beginName, {0, kv, 0, 0}); + addIntArrayConstOp(block, sizeName, {-1, 1, seq, headDim}); + auto* sop = block->add_operations(); + sop->set_type("slice_by_size"); + (*sop->mutable_inputs())["x"].add_arguments()->set_name(x); + (*sop->mutable_inputs())["begin"].add_arguments()->set_name(beginName); + (*sop->mutable_inputs())["size"].add_arguments()->set_name(sizeName); + setShape(sop, part, {-1, 1, seq, headDim}); + parts.push_back(part); + } + } + std::string out = genVarName(prefix + "_" + tag + "_exp"); + std::string axisName = out + "_axis", interleaveName = out + "_interleave"; + addIntScalarConstOp(block, axisName, 1); + addBoolScalarConstOp(block, interleaveName, false); + auto* cop = block->add_operations(); + cop->set_type("concat"); + auto& cin = *cop->mutable_inputs(); + for (const std::string& part : parts) + cin["values"].add_arguments()->set_name(part); + cin["axis"].add_arguments()->set_name(axisName); + cin["interleave"].add_arguments()->set_name(interleaveName); + setShape(cop, out, {-1, numHeads, seq, headDim}); + return out; + }; + kh = repeatKVHeads(kh, "khrep", qHeadDim); + vh = repeatKVHeads(vh, "vhrep", vHeadDim); + } + std::string scores = genVarName(prefix + "_scores"); matmul(qh, kh, scores, {-1, numHeads, seq, seq}, false, true); std::string scaleName = prefix + "_scale"; From 3839e529160e103d4120d879395c126dcb57ae0c Mon Sep 17 00:00:00 2001 From: Chin-Chang Yang <2770271+ChinChangYang@users.noreply.github.com> Date: Tue, 2 Jun 2026 07:07:55 +0800 Subject: [PATCH 05/10] Fix CoreML/ANE FP16 transformer accuracy via precision tiers Transformer models failed testgpuerror on the CoreML/ANE FP16 path: the ANE accumulates FP16 matmuls AND convs in FP16 (unlike OpenCL/CUDA/TRT, which accumulate in FP32), so wide/deep transformers lose too much precision and miss the thresholds at larger board sizes. BF16 is not an option (no compute path in CoreML: cast op, ArrayFeatureType and MLMultiArray all lack bf16; coremltools confirms FLOAT16/FLOAT32 only). Follow KataGo's FP16 convention (spatial convs FP16, non-spatial FP32), channel-gated for the ANE since every FP32 op runs off the FP16-only ANE: - RMSNorm reduction cores: FP32 in FP16 mode (always). - Non-spatial (FFN/Q-K-V proj/pooling/matmul): FP32 (always). MIL `linear` needs const weight/bias so it can't runtime-cast; only `matmul` is wrapped. - Convs: FP32 only for wide trunks (>= 320ch); narrower keep convs on-ANE. - Narrow trunks (< 256ch) sit on the testgpuerror thresholds and no partial FP32 config passes all board sizes (islands cast back to FP16 leave a noisy FP16 spatial stream); build them fully FP32 (off-ANE, cheap since small). Weights stay FP16-stored via runtime up-casts, except full-FP32 models. Add per-weight FP32 serialization (WeightEntry.is_fp32) so a const declared FP32 inside an otherwise-FP16 model is stored FP32 (fixes the load-time "storage and type have different number of elements" abort and enables the full-FP32 tier). Also fixes addFloatScalarConstOp keying storage off m_use_fp16 instead of the declared m_weight_dtype. Result: all 4 transformer test models (b10c384h6/b4c256h4/b7c96h3/ b7c96h6kv3-GQA) pass testgpuerror on ANE FP16 at sizes 9/13/19; runtests and runnnlayertests pass. All changes gated on m_use_fp16; FP32 mode unchanged. The 256/320 channel thresholds are width heuristics validated on these models. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../katagocoreml/src/builder/MILBuilder.cpp | 271 +++++++++++++++--- .../katagocoreml/src/builder/MILBuilder.hpp | 16 ++ .../katagocoreml/src/builder/Operations.cpp | 4 +- .../katagocoreml/src/builder/Operations.hpp | 7 +- .../src/serializer/WeightSerializer.cpp | 6 +- 5 files changed, 268 insertions(+), 36 deletions(-) diff --git a/cpp/external/katagocoreml/src/builder/MILBuilder.cpp b/cpp/external/katagocoreml/src/builder/MILBuilder.cpp index b90da55b6..bef0fea73 100644 --- a/cpp/external/katagocoreml/src/builder/MILBuilder.cpp +++ b/cpp/external/katagocoreml/src/builder/MILBuilder.cpp @@ -31,7 +31,24 @@ MILBuilder::MILBuilder(const KataGoModelDesc& model, ? CoreML::Specification::MILSpec::DataType::FLOAT16 : CoreML::Specification::MILSpec::DataType::FLOAT32) , m_ops(board_x_size, board_y_size, optimize_identity_mask) - , m_var_counter(0) {} + , m_var_counter(0) { + // Precision tiers in FP16 mode (the ANE accumulates FP16 in FP16; FP32 ops run off the FP16-only + // ANE). NARROW transformer trunks are unreliable on the FP16 ANE: their policy/value metrics sit + // right on the testgpuerror thresholds and no partial-FP32 config passes all board sizes (partial + // FP32 leaves a noisy FP16 spatial stream). So build narrow trunks FULLY in FP32 (off-ANE, but + // cheap since narrow models are small; correct because it equals the FP32 reference). Weights are + // stored FP32 via per-weight serialization. Wider trunks use partial FP32: non-spatial (matmuls + + // pooling) always FP32; convs FP32 only for very wide trunks (kept on the ANE for narrower ones). + const int trunkChannels = model.trunk.trunk_num_channels; + const bool full_fp32 = use_fp16 && trunkChannels < FULL_FP32_MAX_TRUNK_CHANNELS; + if (full_fp32) { + m_use_fp16 = false; + m_use_fp16_io = false; + m_weight_dtype = CoreML::Specification::MILSpec::DataType::FLOAT32; + } + m_nonspatial_fp32 = m_use_fp16; + m_conv_fp32 = m_use_fp16 && trunkChannels >= CONV_FP32_MIN_TRUNK_CHANNELS; +} void MILBuilder::setBatchDimension(CoreML::Specification::MILSpec::TensorType* tensor_type) { auto* dim = tensor_type->add_dimensions(); @@ -213,8 +230,10 @@ void MILBuilder::addConstOp(CoreML::Specification::MILSpec::Block* block, const std::string& name, const std::vector& data, const std::vector& shape) { - // Register weight for blob storage - m_ops.registerWeight(name, data, shape); + // Register weight for blob storage. Mark FP32 storage when this const is declared FP32 (e.g. + // inside an FP32 sub-region of an otherwise-FP16 model) so storage matches the declared type. + m_ops.registerWeight(name, data, shape, + m_weight_dtype == CoreML::Specification::MILSpec::DataType::FLOAT32); // Add const operation auto* op = block->add_operations(); @@ -329,7 +348,11 @@ void MILBuilder::addFloatScalarConstOp(CoreML::Specification::MILSpec::Block* bl val_type->set_datatype(m_weight_dtype); val_type->set_rank(0); - if (m_use_fp16) { + // Key the storage format off the DECLARED dtype (m_weight_dtype), not the global m_use_fp16: + // a temporarily-flipped FP32 sub-region (m_weight_dtype=FLOAT32 while m_use_fp16 stays true) + // must store FP32 floats, or CoreML rejects the model ("storage and type have different number + // of elements"). For all non-flipped calls m_weight_dtype tracks m_use_fp16, so this is a no-op. + if (m_weight_dtype == CoreML::Specification::MILSpec::DataType::FLOAT16) { // For FP16, use bytes storage with FP16 representation MILBlob::Fp16 fp16_val = MILBlob::Fp16::FromFloat(value); std::string bytes_data(reinterpret_cast(&fp16_val.bytes), sizeof(fp16_val.bytes)); @@ -427,6 +450,42 @@ void MILBuilder::addCastOp(CoreML::Specification::MILSpec::Block* block, } } +std::string MILBuilder::castFixed(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const std::string& dtype, + const std::vector& dims) { + std::string out = genVarName(input + "_cast"); + std::string dtName = out + "_dt"; + { + auto* op = block->add_operations(); + op->set_type("const"); + auto& na = (*op->mutable_attributes())["name"]; + na.mutable_type()->mutable_tensortype()->set_datatype(CoreML::Specification::MILSpec::DataType::STRING); + na.mutable_immediatevalue()->mutable_tensor()->mutable_strings()->add_values(dtName); + auto& va = (*op->mutable_attributes())["val"]; + va.mutable_type()->mutable_tensortype()->set_datatype(CoreML::Specification::MILSpec::DataType::STRING); + va.mutable_immediatevalue()->mutable_tensor()->mutable_strings()->add_values(dtype); + auto* o = op->add_outputs(); + o->set_name(dtName); + o->mutable_type()->mutable_tensortype()->set_datatype(CoreML::Specification::MILSpec::DataType::STRING); + } + auto* op = block->add_operations(); + op->set_type("cast"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(input); + (*op->mutable_inputs())["dtype"].add_arguments()->set_name(dtName); + auto* o = op->add_outputs(); + o->set_name(out); + auto* tt = o->mutable_type()->mutable_tensortype(); + tt->set_datatype(dtype == "fp32" ? CoreML::Specification::MILSpec::DataType::FLOAT32 + : CoreML::Specification::MILSpec::DataType::FLOAT16); + tt->set_rank(static_cast(dims.size())); + for (int64_t d : dims) { + if (d < 0) tt->add_dimensions()->mutable_unknown()->set_variadic(false); + else tt->add_dimensions()->mutable_constant()->set_size(d); + } + return out; +} + void MILBuilder::addConvOp(CoreML::Specification::MILSpec::Block* block, const std::string& input, const ConvLayerDesc& layer, @@ -567,6 +626,21 @@ void MILBuilder::addConvOp(CoreML::Specification::MILSpec::Block* block, tt->add_dimensions()->mutable_constant()->set_size(4); } + // Channel-gated FP32 convs. The ANE accumulates FP16 convs in FP16, which loses too much + // precision for WIDE trunks and fails testgpuerror at large board sizes (validated: 384ch + // fails, <=256ch is fine FP16-on-ANE). For wide trunks (>= threshold) run convs in FP32 (weights + // cast up at runtime, stored fp16). FP32 convs can't run on the fp16-only ANE, so only the wide + // models that actually need it pay that off-ANE cost; narrow models keep convs on the ANE. + const bool convFp32 = m_conv_fp32; + std::string convX = input, convW = weight_name, convOut = output; + auto savedConvDtype = m_weight_dtype; + if (convFp32) { + convX = castFixed(block, input, "fp32", {-1, layer.in_channels, m_board_y_size, m_board_x_size}); + convW = castFixed(block, weight_name, "fp32", layer.getWeightShape()); + m_weight_dtype = CoreML::Specification::MILSpec::DataType::FLOAT32; + convOut = output + "_cf32"; + } + // Add conv operation referencing all const parameters auto* op = block->add_operations(); op->set_type("conv"); @@ -578,12 +652,12 @@ void MILBuilder::addConvOp(CoreML::Specification::MILSpec::Block* block, inputs["pad"].add_arguments()->set_name(pad_name); inputs["pad_type"].add_arguments()->set_name(pad_type_name); inputs["strides"].add_arguments()->set_name(strides_name); - inputs["weight"].add_arguments()->set_name(weight_name); - inputs["x"].add_arguments()->set_name(input); + inputs["weight"].add_arguments()->set_name(convW); + inputs["x"].add_arguments()->set_name(convX); // Output with dimensions [batch, out_channels, height, width] auto* out = op->add_outputs(); - out->set_name(output); + out->set_name(convOut); auto* out_type = out->mutable_type()->mutable_tensortype(); out_type->set_datatype(m_weight_dtype); out_type->set_rank(4); @@ -591,6 +665,11 @@ void MILBuilder::addConvOp(CoreML::Specification::MILSpec::Block* block, out_type->add_dimensions()->mutable_constant()->set_size(layer.out_channels); out_type->add_dimensions()->mutable_constant()->set_size(m_board_y_size); out_type->add_dimensions()->mutable_constant()->set_size(m_board_x_size); + + if (convFp32) { + m_weight_dtype = savedConvDtype; + addCastOp(block, convOut, output, "fp16", {-1, layer.out_channels, m_board_y_size, m_board_x_size}); + } } // Helper: Set output tensor type with 4D shape [batch, C, H, W] @@ -945,23 +1024,38 @@ void MILBuilder::addMatMulOp(CoreML::Specification::MILSpec::Block* block, CoreML::Specification::MILSpec::DataType::BOOL); } + // Non-spatial matmul in FP32 (KataGo FP16 convention; weights cast up at runtime, stored fp16). + std::string mmIn = input, mmW = weight_name, mmOut = output; + auto savedMmDtype = m_weight_dtype; + if (m_nonspatial_fp32) { + mmIn = castFixed(block, input, "fp32", {-1, layer.in_channels}); + mmW = castFixed(block, weight_name, "fp32", layer.getWeightShape()); + m_weight_dtype = CoreML::Specification::MILSpec::DataType::FLOAT32; + mmOut = output + "_mmf32"; + } + // Add matmul operation auto* op = block->add_operations(); op->set_type("matmul"); auto& inputs = *op->mutable_inputs(); inputs["transpose_x"].add_arguments()->set_name(transpose_x_name); inputs["transpose_y"].add_arguments()->set_name(transpose_y_name); - inputs["x"].add_arguments()->set_name(input); - inputs["y"].add_arguments()->set_name(weight_name); + inputs["x"].add_arguments()->set_name(mmIn); + inputs["y"].add_arguments()->set_name(mmW); // Output with 2D shape [batch, out_channels] auto* out = op->add_outputs(); - out->set_name(output); + out->set_name(mmOut); auto* out_type = out->mutable_type()->mutable_tensortype(); out_type->set_datatype(m_weight_dtype); out_type->set_rank(2); setBatchDimension(out_type); out_type->add_dimensions()->mutable_constant()->set_size(layer.out_channels); + + if (m_nonspatial_fp32) { + m_weight_dtype = savedMmDtype; + addCastOp(block, mmOut, output, "fp16", {-1, layer.out_channels}); + } } void MILBuilder::addMatBiasOp(CoreML::Specification::MILSpec::Block* block, @@ -1022,7 +1116,9 @@ void MILBuilder::addLinearOp(CoreML::Specification::MILSpec::Block* block, std::vector bias_shape = {static_cast(bias.num_channels)}; addConstOp(block, bias_name, bias.weights, bias_shape); - // Add linear operation + // NOTE: the MIL `linear` op requires const weight/bias, so the runtime-cast-to-FP32 trick can't + // be applied here (unlike `matmul`). Value-head linear stays FP16; if a model ever needs it in + // FP32, rewrite as matmul+add (matmul accepts cast inputs). auto* op = block->add_operations(); op->set_type("linear"); auto& inputs = *op->mutable_inputs(); @@ -1718,8 +1814,20 @@ std::string MILBuilder::addTransformerRMSNorm(CoreML::Specification::MILSpec::Bl setShape(op, out, dims); }; + // RMSNorm reduction core: square -> mean over channels -> rsqrt. In FP16 mode compute this + // core in FP32 (cast input up, flip the working dtype so the core's op outputs + eps scalar are + // FP32, then cast 1/rms back down). The FP16 channel reduction loses too much precision on the + // ANE; only this core is FP32 - the scaling/weight/mask below stay FP16. No addConstOp lives in + // the flipped window, so weight serialization is unaffected. + auto savedDtype = m_weight_dtype; + std::string sqSrc = input; + if (m_use_fp16) { + sqSrc = genVarName(prefix + "_in32"); + addCastOp(block, input, sqSrc, "fp32", {-1, C, H, W}); + m_weight_dtype = CoreML::Specification::MILSpec::DataType::FLOAT32; + } std::string sq = genVarName(prefix + "_sq"); - emit2("mul", input, input, sq, {-1, C, H, W}); + emit2("mul", sqSrc, sqSrc, sq, {-1, C, H, W}); // meanSq = reduce_mean(sq, axes=[1]) over channels. reduce_mean (not reduce_sum) is used so // the accumulator stays ~O(activation^2) instead of summing hundreds of channels, which can // overflow FP16 (and the FP16 accumulation on ANE) for large activations. @@ -1739,13 +1847,19 @@ std::string MILBuilder::addTransformerRMSNorm(CoreML::Specification::MILSpec::Bl // MIL rsqrt computes 1/sqrt(x + epsilon); supply epsilon directly. std::string epsName = prefix + "_eps"; addFloatScalarConstOp(block, epsName, desc.epsilon); - std::string inv = genVarName(prefix + "_inv"); + std::string invCore = genVarName(prefix + "_inv"); { auto* op = block->add_operations(); op->set_type("rsqrt"); (*op->mutable_inputs())["x"].add_arguments()->set_name(meanSq); (*op->mutable_inputs())["epsilon"].add_arguments()->set_name(epsName); - setShape(op, inv, {-1, 1, H, W}); + setShape(op, invCore, {-1, 1, H, W}); + } + std::string inv = invCore; + if (m_use_fp16) { + m_weight_dtype = savedDtype; + inv = genVarName(prefix + "_inv16"); + addCastOp(block, invCore, inv, "fp16", {-1, 1, H, W}); } std::string normalized = genVarName(prefix + "_norm"); emit2("mul", input, inv, normalized, {-1, C, H, W}); @@ -1788,8 +1902,22 @@ std::string MILBuilder::addTrunkRMSNorm(CoreML::Specification::MILSpec::Block* b setShape(op, out, dims); }; + // Variance core (mask -> square -> reduce -> rsqrt) in FP32 when in FP16 mode. The trunk-tip + // norm in particular reduces over many elements and loses too much precision in FP16 on the + // ANE; compute the core in FP32 and cast 1/rms back to FP16. Only the core is FP32 - gamma/beta, + // the activation and the final mask below stay FP16. No addConstOp lives in the flipped window. + auto savedDtype = m_weight_dtype; + std::string tinput = input; + std::string tmask = mask; + if (m_use_fp16) { + tinput = genVarName(prefix + "_in32"); + addCastOp(block, input, tinput, "fp32", {-1, C, H, W}); + tmask = genVarName(prefix + "_mask32"); + addCastOp(block, mask, tmask, "fp32", {-1, 1, H, W}); + m_weight_dtype = CoreML::Specification::MILSpec::DataType::FLOAT32; + } std::string masked = genVarName(prefix + "_premask"); - emit2("mul", input, mask, masked, {-1, C, H, W}); + emit2("mul", tinput, tmask, masked, {-1, C, H, W}); std::string sq = genVarName(prefix + "_sq"); emit2("mul", masked, masked, sq, {-1, C, H, W}); @@ -1813,7 +1941,7 @@ std::string MILBuilder::addTrunkRMSNorm(CoreML::Specification::MILSpec::Block* b setShape(op, meanAll, {-1, 1, 1, 1}); } std::string count = genVarName(prefix + "_count"); - reduceSum(mask, count, {1, 2, 3}, {-1, 1, 1, 1}); // valid positions (<= H*W, no overflow) + reduceSum(tmask, count, {1, 2, 3}, {-1, 1, 1, 1}); // valid positions (<= H*W, no overflow) std::string totalPosName = prefix + "_totalpos"; addFloatScalarConstOp(block, totalPosName, static_cast(H * W)); std::string scaleF = genVarName(prefix + "_scalef"); @@ -1839,13 +1967,19 @@ std::string MILBuilder::addTrunkRMSNorm(CoreML::Specification::MILSpec::Block* b // MIL rsqrt computes 1/sqrt(x + epsilon); supply epsilon directly. std::string epsName = prefix + "_eps"; addFloatScalarConstOp(block, epsName, desc.epsilon); - std::string inv = genVarName(prefix + "_inv"); + std::string invCore = genVarName(prefix + "_inv"); { auto* op = block->add_operations(); op->set_type("rsqrt"); (*op->mutable_inputs())["x"].add_arguments()->set_name(meanSq); (*op->mutable_inputs())["epsilon"].add_arguments()->set_name(epsName); - setShape(op, inv, denomDims); + setShape(op, invCore, denomDims); + } + std::string inv = invCore; + if (m_use_fp16) { + m_weight_dtype = savedDtype; + inv = genVarName(prefix + "_inv16"); + addCastOp(block, invCore, inv, "fp16", denomDims); } std::string normalized = genVarName(prefix + "_norm"); emit2("mul", input, inv, normalized, {-1, C, H, W}); @@ -1938,11 +2072,25 @@ std::string MILBuilder::buildTransformerAttentionBlock(CoreML::Specification::MI transpose(normed, nhwc, {0, 2, 3, 1}, {-1, H, W, C}); std::string x2d = genVarName(prefix + "_x2d"); reshape(nhwc, x2d, {-1, C}, {-1, C}); + // Q/K/V projection matmuls in FP32 (non-spatial, per KataGo's FP16 convention): they reduce over + // C channels and the ANE's FP16 accumulation loses too much precision for wide models. Weights + // stay fp16-stored (cast up at runtime); output cast back to FP16 for the FP16 head reshapes. auto proj = [&](const MatMulLayerDesc& w, const std::string& nm, int total) { std::string wName = nm + "_w"; addConstOp(block, wName, w.weights, w.getWeightShape()); std::string out = genVarName(nm); - matmul(x2d, wName, out, {-1, total}, false, false); + if (m_nonspatial_fp32) { + std::string x32 = castFixed(block, x2d, "fp32", {-1, C}); + std::string w32 = castFixed(block, wName, "fp32", w.getWeightShape()); + auto sd = m_weight_dtype; + m_weight_dtype = CoreML::Specification::MILSpec::DataType::FLOAT32; + std::string o32 = genVarName(nm + "_f32"); + matmul(x32, w32, o32, {-1, total}, false, false); + m_weight_dtype = sd; + out = castFixed(block, o32, "fp16", {-1, total}); + } else { + matmul(x2d, wName, out, {-1, total}, false, false); + } return out; }; std::string q2d = proj(desc.q_proj, prefix + "_q", qTotal); @@ -2218,14 +2366,29 @@ std::string MILBuilder::buildTransformerFFNBlock(CoreML::Specification::MILSpec: std::string x2d = genVarName(prefix + "_x2d"); reshape(nhwc, x2d, {-1, C}, {-1, C}); + // FFN matmuls in FP32 (weights cast up at runtime, stored fp16) — KataGo's FP16 convention is + // spatial(convs)=FP16, non-spatial(matmuls)=FP32 (see openclbackend.cpp). The ANE accumulates + // FP16 matmuls in FP16, which loses too much precision over C/ffn; run them in FP32 instead. std::string w1 = prefix + "_w1"; addConstOp(block, w1, desc.linear1.weights, desc.linear1.getWeightShape()); - std::string a = genVarName(prefix + "_a"); - matmul(x2d, w1, a, {-1, ffn}); std::string wg = prefix + "_wg"; addConstOp(block, wg, desc.linear_gate.weights, desc.linear_gate.getWeightShape()); + std::string w2 = prefix + "_w2"; + addConstOp(block, w2, desc.linear2.weights, desc.linear2.getWeightShape()); + + auto savedDtype = m_weight_dtype; + std::string mx2d = x2d, mw1 = w1, mwg = wg, mw2 = w2; + if (m_nonspatial_fp32) { + mx2d = castFixed(block, x2d, "fp32", {-1, C}); + mw1 = castFixed(block, w1, "fp32", desc.linear1.getWeightShape()); + mwg = castFixed(block, wg, "fp32", desc.linear_gate.getWeightShape()); + mw2 = castFixed(block, w2, "fp32", desc.linear2.getWeightShape()); + m_weight_dtype = CoreML::Specification::MILSpec::DataType::FLOAT32; + } + std::string a = genVarName(prefix + "_a"); + matmul(mx2d, mw1, a, {-1, ffn}); std::string g = genVarName(prefix + "_g"); - matmul(x2d, wg, g, {-1, ffn}); + matmul(mx2d, mwg, g, {-1, ffn}); std::string sig = genVarName(prefix + "_sig"); { @@ -2239,10 +2402,13 @@ std::string MILBuilder::buildTransformerFFNBlock(CoreML::Specification::MILSpec: std::string h = genVarName(prefix + "_h"); binary("mul", siluA, g, h, {-1, ffn}); - std::string w2 = prefix + "_w2"; - addConstOp(block, w2, desc.linear2.weights, desc.linear2.getWeightShape()); - std::string o = genVarName(prefix + "_o"); - matmul(h, w2, o, {-1, C}); + std::string oCore = genVarName(prefix + "_o"); + matmul(h, mw2, oCore, {-1, C}); + std::string o = oCore; + if (m_nonspatial_fp32) { + m_weight_dtype = savedDtype; + o = castFixed(block, oCore, "fp16", {-1, C}); + } std::string oNHWC = genVarName(prefix + "_onhwc"); reshape(o, oNHWC, {-1, H, W, C}, {-1, H, W, C}); @@ -2443,9 +2609,25 @@ std::string MILBuilder::buildGlobalPoolingResidualBlock(CoreML::Specification::M std::string gpool_bn_out = genVarName(prefix + "_gpool_bn"); addBatchNormActivationOps(block, gpool_conv_out, block_desc.gpool_bn, block_desc.gpool_activation, mask, gpool_bn_out); - // Global pooling + // Global pooling. Non-spatial per KataGo's FP16 convention -> FP32 (openclbackend.cpp: pooling + // an FP16 tensor produces FP32 pooled values). The spatial sum over H*W loses too much precision + // in FP16 at larger board sizes, corrupting the bias fed back into the whole trunk. No + // addConstOp in the pooling -> flipping m_weight_dtype is safe. std::string gpool_features = genVarName(prefix + "_gpool_features"); - addGlobalPoolingOps(block, gpool_bn_out, mask, block_desc.gpool_conv.out_channels, gpool_features); + if (m_nonspatial_fp32) { + auto savedDtype = m_weight_dtype; + std::string gpIn32 = castFixed(block, gpool_bn_out, "fp32", {-1, block_desc.gpool_conv.out_channels, m_board_y_size, m_board_x_size}); + std::string gpMask = mask; + if (!m_optimize_identity_mask) + gpMask = castFixed(block, mask, "fp32", {-1, 1, m_board_y_size, m_board_x_size}); + m_weight_dtype = CoreML::Specification::MILSpec::DataType::FLOAT32; + std::string gpOut32 = genVarName(prefix + "_gpool_features_f32"); + addGlobalPoolingOps(block, gpIn32, gpMask, block_desc.gpool_conv.out_channels, gpOut32); + m_weight_dtype = savedDtype; + gpool_features = castFixed(block, gpOut32, "fp16", {-1, block_desc.gpool_conv.out_channels * 3}); + } else { + addGlobalPoolingOps(block, gpool_bn_out, mask, block_desc.gpool_conv.out_channels, gpool_features); + } // Project to bias std::string gpool_bias = genVarName(prefix + "_gpool_bias"); @@ -2577,9 +2759,22 @@ void MILBuilder::buildPolicyHead(CoreML::Specification::MILSpec::Block* block, std::string g1 = genVarName("policy_g1"); addBatchNormActivationOps(block, g1_conv, ph.g1_bn, ph.g1_activation, mask, g1); - // Global pooling on G1 + // Global pooling on G1 — non-spatial per KataGo's FP16 convention -> FP32 (the FP16 spatial sum + // loses precision; feeds the policy bias, affecting policyKLDiv). No addConstOp in pooling. std::string g1_pooled = genVarName("policy_g1_pool"); - addGlobalPoolingOps(block, g1, mask, ph.g1_conv.out_channels, g1_pooled); + if (m_nonspatial_fp32) { + auto savedDtype = m_weight_dtype; + std::string gpIn32 = castFixed(block, g1, "fp32", {-1, ph.g1_conv.out_channels, m_board_y_size, m_board_x_size}); + std::string gpMask = m_optimize_identity_mask ? mask + : castFixed(block, mask, "fp32", {-1, 1, m_board_y_size, m_board_x_size}); + m_weight_dtype = CoreML::Specification::MILSpec::DataType::FLOAT32; + std::string gpOut32 = genVarName("policy_g1_pool_f32"); + addGlobalPoolingOps(block, gpIn32, gpMask, ph.g1_conv.out_channels, gpOut32); + m_weight_dtype = savedDtype; + addCastOp(block, gpOut32, g1_pooled, "fp16", {-1, ph.g1_conv.out_channels * 3}); + } else { + addGlobalPoolingOps(block, g1, mask, ph.g1_conv.out_channels, g1_pooled); + } // Project to spatial bias std::string gpool_bias = genVarName("policy_gpool_bias"); @@ -2669,9 +2864,21 @@ void MILBuilder::buildValueHead(CoreML::Specification::MILSpec::Block* block, std::string v1 = genVarName("value_v1"); addBatchNormActivationOps(block, v1_conv, vh.v1_bn, vh.v1_activation, mask, v1); - // Global pooling (value head version) + // Global pooling (value head version) — non-spatial -> FP32 (KataGo FP16 convention). std::string v1_pooled = genVarName("value_v1_pool"); - addGlobalPoolingValueOps(block, v1, mask, vh.v1_conv.out_channels, v1_pooled); + if (m_nonspatial_fp32) { + auto savedDtype = m_weight_dtype; + std::string vpIn32 = castFixed(block, v1, "fp32", {-1, vh.v1_conv.out_channels, m_board_y_size, m_board_x_size}); + std::string vpMask = m_optimize_identity_mask ? mask + : castFixed(block, mask, "fp32", {-1, 1, m_board_y_size, m_board_x_size}); + m_weight_dtype = CoreML::Specification::MILSpec::DataType::FLOAT32; + std::string vpOut32 = genVarName("value_v1_pool_f32"); + addGlobalPoolingValueOps(block, vpIn32, vpMask, vh.v1_conv.out_channels, vpOut32); + m_weight_dtype = savedDtype; + addCastOp(block, vpOut32, v1_pooled, "fp16", {-1, vh.v1_conv.out_channels * 3}); + } else { + addGlobalPoolingValueOps(block, v1, mask, vh.v1_conv.out_channels, v1_pooled); + } // V2: linear + activation (fused matmul+bias -> linear) std::string v2_bias = genVarName("value_v2_bias"); diff --git a/cpp/external/katagocoreml/src/builder/MILBuilder.hpp b/cpp/external/katagocoreml/src/builder/MILBuilder.hpp index 5d25b963a..ad67f150f 100644 --- a/cpp/external/katagocoreml/src/builder/MILBuilder.hpp +++ b/cpp/external/katagocoreml/src/builder/MILBuilder.hpp @@ -43,6 +43,14 @@ class MILBuilder { bool m_optimize_identity_mask; bool m_use_fp16; bool m_use_fp16_io; + // FP32 in FP16 mode follows KataGo's FP16 convention (spatial convs FP16, non-spatial FP32), but + // FP32 ops run off the FP16-only ANE, so convs are channel-gated to only the wide trunks that + // need it. RMSNorm reductions: always FP32 (cheap, needed by all). Non-spatial matmuls+pooling: + // always FP32 (every width needs it at some board size). Convs: FP32 only for wide trunks. + static constexpr int CONV_FP32_MIN_TRUNK_CHANNELS = 320; // convs run FP32 at/above this width + static constexpr int FULL_FP32_MAX_TRUNK_CHANNELS = 256; // trunks below this build fully FP32 + bool m_nonspatial_fp32 = false; // = m_use_fp16 (matmuls + global pooling) + bool m_conv_fp32 = false; // = m_use_fp16 && trunk_channels >= CONV_FP32_MIN_... int m_min_batch_size; int m_max_batch_size; CoreML::Specification::MILSpec::DataType m_weight_dtype; @@ -102,6 +110,14 @@ class MILBuilder { const std::string& dtype, const std::vector& shape); + // Cast to a tensor with FULLY-specified dims (no forced batch dim like addCastOp). Use for + // weight tensors (fixed [in,out] dims) when running an otherwise-FP16 op in FP32. Returns the + // new tensor name. dims use -1 for an unknown/batch dim, >=0 for a constant dim. + std::string castFixed(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const std::string& dtype, + const std::vector& dims); + void addConvOp(CoreML::Specification::MILSpec::Block* block, const std::string& input, const ConvLayerDesc& layer, diff --git a/cpp/external/katagocoreml/src/builder/Operations.cpp b/cpp/external/katagocoreml/src/builder/Operations.cpp index c0c036292..1c625acdd 100644 --- a/cpp/external/katagocoreml/src/builder/Operations.cpp +++ b/cpp/external/katagocoreml/src/builder/Operations.cpp @@ -14,12 +14,14 @@ KataGoOps::KataGoOps(int board_x_size, int board_y_size, bool optimize_identity_ std::string KataGoOps::registerWeight(const std::string& name, const std::vector& data, - const std::vector& shape) { + const std::vector& shape, + bool is_fp32) { WeightEntry entry; entry.name = name; entry.data = data; entry.shape = shape; entry.blob_offset = 0; // Will be set during serialization + entry.is_fp32 = is_fp32; m_weights.push_back(std::move(entry)); return name; } diff --git a/cpp/external/katagocoreml/src/builder/Operations.hpp b/cpp/external/katagocoreml/src/builder/Operations.hpp index 3fc72ad88..a9d2a1466 100644 --- a/cpp/external/katagocoreml/src/builder/Operations.hpp +++ b/cpp/external/katagocoreml/src/builder/Operations.hpp @@ -16,6 +16,8 @@ struct WeightEntry { std::vector data; std::vector shape; uint64_t blob_offset = 0; // Set during serialization + bool is_fp32 = false; // Store as FP32 (set when the const was declared FP32, e.g. inside an + // FP32 sub-region of an otherwise-FP16 model). Else stored per global mode. }; /// Precomputed constants for identity mask optimization @@ -51,10 +53,11 @@ class KataGoOps { /// Get precomputed mask constants const MaskConstants& getMaskConstants() const { return m_mask_constants; } - /// Register a weight tensor and return its reference name + /// Register a weight tensor and return its reference name. is_fp32 marks it for FP32 storage. std::string registerWeight(const std::string& name, const std::vector& data, - const std::vector& shape); + const std::vector& shape, + bool is_fp32 = false); /// Get all registered weights const std::vector& getWeights() const { return m_weights; } diff --git a/cpp/external/katagocoreml/src/serializer/WeightSerializer.cpp b/cpp/external/katagocoreml/src/serializer/WeightSerializer.cpp index 2ac23a3da..e8fe861c8 100644 --- a/cpp/external/katagocoreml/src/serializer/WeightSerializer.cpp +++ b/cpp/external/katagocoreml/src/serializer/WeightSerializer.cpp @@ -15,7 +15,11 @@ size_t WeightSerializer::serialize(std::vector& weights, size_t total_bytes = 0; for (auto& entry : weights) { - if (use_fp16) { + // Per-weight precision: store FP16 only when the global mode is FP16 AND this weight was not + // declared FP32 (entry.is_fp32 marks consts inside an FP32 sub-region of an FP16 model), so + // stored bytes stay consistent with each const's declared dtype. + const bool store_fp16 = use_fp16 && !entry.is_fp32; + if (store_fp16) { // Convert FP32 weights to FP16 std::vector fp16_data(entry.data.size()); for (size_t i = 0; i < entry.data.size(); ++i) { From 3eb81ce66a410e0515ba2d37787789c4e8539dd3 Mon Sep 17 00:00:00 2001 From: Chin-Chang Yang <2770271+ChinChangYang@users.noreply.github.com> Date: Tue, 2 Jun 2026 07:40:26 +0800 Subject: [PATCH 06/10] Refactor: dedupe CoreML global-pooling FP32 wrap Three near-identical blocks wrapped global pooling in FP32 (policy head, value head, gpool residual block): cast input/mask up to FP32, flip m_weight_dtype, pool, restore, cast pooled features back to FP16 - with inconsistent save-variable names and one site using castFixed vs addCastOp for the output cast. Extract a single addGlobalPoolingFp32(input, mask, channels, output, valueVariant) helper and a small RAII ScopedFp32 guard for the temporary m_weight_dtype flip. The three call sites become one-liners. Behavior-preserving: same emitted op sequence; testgpuerror output is byte-identical across all precision tiers (partial-FP32 b10c384h6, full-FP32 b7c96h3, non-spatial-FP32 b4c256h4), all 12 transformer gate runs pass, runtests and runnnlayertests pass. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../katagocoreml/src/builder/MILBuilder.cpp | 95 ++++++++++--------- .../katagocoreml/src/builder/MILBuilder.hpp | 9 ++ 2 files changed, 57 insertions(+), 47 deletions(-) diff --git a/cpp/external/katagocoreml/src/builder/MILBuilder.cpp b/cpp/external/katagocoreml/src/builder/MILBuilder.cpp index bef0fea73..b4974aa39 100644 --- a/cpp/external/katagocoreml/src/builder/MILBuilder.cpp +++ b/cpp/external/katagocoreml/src/builder/MILBuilder.cpp @@ -11,6 +11,20 @@ namespace katagocoreml { +namespace { +// RAII: set a dtype slot to FLOAT32 for the current scope and restore it on exit. Used to emit a +// sub-region of ops in FP32 inside an otherwise-FP16 model. +struct ScopedFp32 { + CoreML::Specification::MILSpec::DataType& slot; + CoreML::Specification::MILSpec::DataType saved; + explicit ScopedFp32(CoreML::Specification::MILSpec::DataType& s) + : slot(s), saved(s) { s = CoreML::Specification::MILSpec::DataType::FLOAT32; } + ~ScopedFp32() { slot = saved; } + ScopedFp32(const ScopedFp32&) = delete; + ScopedFp32& operator=(const ScopedFp32&) = delete; +}; +} // namespace + MILBuilder::MILBuilder(const KataGoModelDesc& model, int board_x_size, int board_y_size, @@ -486,6 +500,32 @@ std::string MILBuilder::castFixed(CoreML::Specification::MILSpec::Block* block, return out; } +void MILBuilder::addGlobalPoolingFp32(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const std::string& mask, + int channels, + const std::string& output, + bool valueVariant) { + auto pool = [&](const std::string& in, const std::string& msk, const std::string& out) { + if (valueVariant) addGlobalPoolingValueOps(block, in, msk, channels, out); + else addGlobalPoolingOps(block, in, msk, channels, out); + }; + // Non-spatial per KataGo's FP16 convention -> FP32 (the FP16 spatial sum over H*W loses too much + // precision at larger board sizes). No addConstOp in the pooling, so flipping m_weight_dtype is + // safe. Cast input/mask up, pool in FP32, cast the [N, channels*3] features back to FP16. + if (!m_nonspatial_fp32) { + pool(input, mask, output); + return; + } + std::string in32 = castFixed(block, input, "fp32", {-1, channels, m_board_y_size, m_board_x_size}); + std::string mask32 = m_optimize_identity_mask + ? mask + : castFixed(block, mask, "fp32", {-1, 1, m_board_y_size, m_board_x_size}); + std::string out32 = genVarName(output + "_f32"); + { ScopedFp32 g(m_weight_dtype); pool(in32, mask32, out32); } + addCastOp(block, out32, output, "fp16", {-1, channels * 3}); +} + void MILBuilder::addConvOp(CoreML::Specification::MILSpec::Block* block, const std::string& input, const ConvLayerDesc& layer, @@ -2609,25 +2649,11 @@ std::string MILBuilder::buildGlobalPoolingResidualBlock(CoreML::Specification::M std::string gpool_bn_out = genVarName(prefix + "_gpool_bn"); addBatchNormActivationOps(block, gpool_conv_out, block_desc.gpool_bn, block_desc.gpool_activation, mask, gpool_bn_out); - // Global pooling. Non-spatial per KataGo's FP16 convention -> FP32 (openclbackend.cpp: pooling - // an FP16 tensor produces FP32 pooled values). The spatial sum over H*W loses too much precision - // in FP16 at larger board sizes, corrupting the bias fed back into the whole trunk. No - // addConstOp in the pooling -> flipping m_weight_dtype is safe. + // Global pooling (FP32 when m_nonspatial_fp32 -- see addGlobalPoolingFp32). Feeds a bias back + // into the whole trunk, so the FP16 spatial sum must not lose precision for wide trunks. std::string gpool_features = genVarName(prefix + "_gpool_features"); - if (m_nonspatial_fp32) { - auto savedDtype = m_weight_dtype; - std::string gpIn32 = castFixed(block, gpool_bn_out, "fp32", {-1, block_desc.gpool_conv.out_channels, m_board_y_size, m_board_x_size}); - std::string gpMask = mask; - if (!m_optimize_identity_mask) - gpMask = castFixed(block, mask, "fp32", {-1, 1, m_board_y_size, m_board_x_size}); - m_weight_dtype = CoreML::Specification::MILSpec::DataType::FLOAT32; - std::string gpOut32 = genVarName(prefix + "_gpool_features_f32"); - addGlobalPoolingOps(block, gpIn32, gpMask, block_desc.gpool_conv.out_channels, gpOut32); - m_weight_dtype = savedDtype; - gpool_features = castFixed(block, gpOut32, "fp16", {-1, block_desc.gpool_conv.out_channels * 3}); - } else { - addGlobalPoolingOps(block, gpool_bn_out, mask, block_desc.gpool_conv.out_channels, gpool_features); - } + addGlobalPoolingFp32(block, gpool_bn_out, mask, block_desc.gpool_conv.out_channels, gpool_features, + /*valueVariant=*/false); // Project to bias std::string gpool_bias = genVarName(prefix + "_gpool_bias"); @@ -2759,22 +2785,9 @@ void MILBuilder::buildPolicyHead(CoreML::Specification::MILSpec::Block* block, std::string g1 = genVarName("policy_g1"); addBatchNormActivationOps(block, g1_conv, ph.g1_bn, ph.g1_activation, mask, g1); - // Global pooling on G1 — non-spatial per KataGo's FP16 convention -> FP32 (the FP16 spatial sum - // loses precision; feeds the policy bias, affecting policyKLDiv). No addConstOp in pooling. + // Global pooling on G1 (FP32 when m_nonspatial_fp32; feeds the policy bias / policyKLDiv). std::string g1_pooled = genVarName("policy_g1_pool"); - if (m_nonspatial_fp32) { - auto savedDtype = m_weight_dtype; - std::string gpIn32 = castFixed(block, g1, "fp32", {-1, ph.g1_conv.out_channels, m_board_y_size, m_board_x_size}); - std::string gpMask = m_optimize_identity_mask ? mask - : castFixed(block, mask, "fp32", {-1, 1, m_board_y_size, m_board_x_size}); - m_weight_dtype = CoreML::Specification::MILSpec::DataType::FLOAT32; - std::string gpOut32 = genVarName("policy_g1_pool_f32"); - addGlobalPoolingOps(block, gpIn32, gpMask, ph.g1_conv.out_channels, gpOut32); - m_weight_dtype = savedDtype; - addCastOp(block, gpOut32, g1_pooled, "fp16", {-1, ph.g1_conv.out_channels * 3}); - } else { - addGlobalPoolingOps(block, g1, mask, ph.g1_conv.out_channels, g1_pooled); - } + addGlobalPoolingFp32(block, g1, mask, ph.g1_conv.out_channels, g1_pooled, /*valueVariant=*/false); // Project to spatial bias std::string gpool_bias = genVarName("policy_gpool_bias"); @@ -2864,21 +2877,9 @@ void MILBuilder::buildValueHead(CoreML::Specification::MILSpec::Block* block, std::string v1 = genVarName("value_v1"); addBatchNormActivationOps(block, v1_conv, vh.v1_bn, vh.v1_activation, mask, v1); - // Global pooling (value head version) — non-spatial -> FP32 (KataGo FP16 convention). + // Global pooling (value head version; FP32 when m_nonspatial_fp32). std::string v1_pooled = genVarName("value_v1_pool"); - if (m_nonspatial_fp32) { - auto savedDtype = m_weight_dtype; - std::string vpIn32 = castFixed(block, v1, "fp32", {-1, vh.v1_conv.out_channels, m_board_y_size, m_board_x_size}); - std::string vpMask = m_optimize_identity_mask ? mask - : castFixed(block, mask, "fp32", {-1, 1, m_board_y_size, m_board_x_size}); - m_weight_dtype = CoreML::Specification::MILSpec::DataType::FLOAT32; - std::string vpOut32 = genVarName("value_v1_pool_f32"); - addGlobalPoolingValueOps(block, vpIn32, vpMask, vh.v1_conv.out_channels, vpOut32); - m_weight_dtype = savedDtype; - addCastOp(block, vpOut32, v1_pooled, "fp16", {-1, vh.v1_conv.out_channels * 3}); - } else { - addGlobalPoolingValueOps(block, v1, mask, vh.v1_conv.out_channels, v1_pooled); - } + addGlobalPoolingFp32(block, v1, mask, vh.v1_conv.out_channels, v1_pooled, /*valueVariant=*/true); // V2: linear + activation (fused matmul+bias -> linear) std::string v2_bias = genVarName("value_v2_bias"); diff --git a/cpp/external/katagocoreml/src/builder/MILBuilder.hpp b/cpp/external/katagocoreml/src/builder/MILBuilder.hpp index ad67f150f..fe63b442f 100644 --- a/cpp/external/katagocoreml/src/builder/MILBuilder.hpp +++ b/cpp/external/katagocoreml/src/builder/MILBuilder.hpp @@ -118,6 +118,15 @@ class MILBuilder { const std::string& dtype, const std::vector& dims); + // Emit global pooling, running it in FP32 when m_nonspatial_fp32 (cast input/mask up, pool, + // cast the pooled features back to FP16). valueVariant selects the value-head pooling variant. + void addGlobalPoolingFp32(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const std::string& mask, + int channels, + const std::string& output, + bool valueVariant); + void addConvOp(CoreML::Specification::MILSpec::Block* block, const std::string& input, const ConvLayerDesc& layer, From d052d2a1be587360e89f9516673fcb6cde4b707c Mon Sep 17 00:00:00 2001 From: Chin-Chang Yang <2770271+ChinChangYang@users.noreply.github.com> Date: Tue, 2 Jun 2026 09:44:59 +0800 Subject: [PATCH 07/10] Fix CoreML/ANE convnet regression: scope FP32 tiers to transformers The width-keyed precision tiers (commit 3839e529) forced FP32 ops off the FP16-only ANE on plain production convnets, not just transformers. b18c384nbt ran ~2.6x slower on the ANE path (160 vs 416 visits/s) with no accuracy benefit. The dominant cost is the per-block global-pooling FP32 (non-spatial), which breaks the ANE pipeline once per gpool-residual block; conv-FP32 is secondary. Add a recursive blocksContainTransformer() helper and gate all three escalations (full-FP32, non-spatial-FP32, conv-FP32) on transformer-block presence. Convnets now run pure FP16 on the ANE (the long-standing pre-tier path); for transformer models the added "&& hasTransformer" is always true, so their emitted MIL is byte-identical and behavior is unchanged. Verified on the ANE FP16 path: b18c384nbt testgpuerror passes (winrate 99%=0.57%, max=0.87%) and recovers full throughput (424 visits/s); b28c512nbt passes (99%=0.41%); all 4 transformer test models x sizes 9/13/19 pass with numbers byte-identical to before; runtests and runnnlayertests pass. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../katagocoreml/src/builder/MILBuilder.cpp | 45 ++++++++++++++----- .../katagocoreml/src/builder/MILBuilder.hpp | 18 ++++---- 2 files changed, 45 insertions(+), 18 deletions(-) diff --git a/cpp/external/katagocoreml/src/builder/MILBuilder.cpp b/cpp/external/katagocoreml/src/builder/MILBuilder.cpp index b4974aa39..09ab365ff 100644 --- a/cpp/external/katagocoreml/src/builder/MILBuilder.cpp +++ b/cpp/external/katagocoreml/src/builder/MILBuilder.cpp @@ -23,6 +23,25 @@ struct ScopedFp32 { ScopedFp32(const ScopedFp32&) = delete; ScopedFp32& operator=(const ScopedFp32&) = delete; }; + +// True if any block in this list is a transformer (attention/FFN), recursing into nested-bottleneck +// blocks (which can themselves contain transformer blocks). Used to scope the off-ANE FP32 +// escalations to transformer trunks only. +bool blocksContainTransformer(const std::vector& blocks) { + for (const auto& entry : blocks) { + if (entry.block_kind == TRANSFORMER_ATTENTION_BLOCK_KIND || + entry.block_kind == TRANSFORMER_FFN_BLOCK_KIND) { + return true; + } + if (entry.block_kind == NESTED_BOTTLENECK_BLOCK_KIND) { + const auto& nbt = std::get(*entry.block); + if (blocksContainTransformer(nbt.blocks)) { + return true; + } + } + } + return false; +} } // namespace MILBuilder::MILBuilder(const KataGoModelDesc& model, @@ -46,22 +65,28 @@ MILBuilder::MILBuilder(const KataGoModelDesc& model, : CoreML::Specification::MILSpec::DataType::FLOAT32) , m_ops(board_x_size, board_y_size, optimize_identity_mask) , m_var_counter(0) { - // Precision tiers in FP16 mode (the ANE accumulates FP16 in FP16; FP32 ops run off the FP16-only - // ANE). NARROW transformer trunks are unreliable on the FP16 ANE: their policy/value metrics sit - // right on the testgpuerror thresholds and no partial-FP32 config passes all board sizes (partial - // FP32 leaves a noisy FP16 spatial stream). So build narrow trunks FULLY in FP32 (off-ANE, but - // cheap since narrow models are small; correct because it equals the FP32 reference). Weights are - // stored FP32 via per-weight serialization. Wider trunks use partial FP32: non-spatial (matmuls + - // pooling) always FP32; convs FP32 only for very wide trunks (kept on the ANE for narrower ones). + // Precision in FP16 mode. The ANE accumulates FP16 in FP16, so any FP32 op runs OFF the FP16-only + // ANE (on CPU/GPU), breaking the ANE pipeline. These off-ANE FP32 escalations are applied ONLY to + // transformer trunks, whose attention blocks widen the activation range enough to overflow FP16 + // accumulation. Plain convnets stay PURE FP16 on the ANE -- the long-standing pre-tier path, verified + // to pass testgpuerror (b18c384nbt, b28c512nbt) and ~2.6x faster than forcing their per-block global + // pooling and convs to FP32 (measured: the per-block pooling round-trips, not the convs, dominate the + // slowdown). For transformers: + // - NARROW trunks (<256ch) build FULLY in FP32: their policy/value metrics sit right on the + // testgpuerror thresholds and no partial-FP32 config passes all board sizes (partial FP32 leaves a + // noisy FP16 spatial stream). Off-ANE but cheap since narrow; equals the FP32 reference. Weights + // stored FP32 (per-weight serialization). + // - WIDER trunks use partial FP32: non-spatial (matmuls + pooling) always, convs only for >=320ch. const int trunkChannels = model.trunk.trunk_num_channels; - const bool full_fp32 = use_fp16 && trunkChannels < FULL_FP32_MAX_TRUNK_CHANNELS; + const bool hasTransformer = blocksContainTransformer(model.trunk.blocks); + const bool full_fp32 = use_fp16 && hasTransformer && trunkChannels < FULL_FP32_MAX_TRUNK_CHANNELS; if (full_fp32) { m_use_fp16 = false; m_use_fp16_io = false; m_weight_dtype = CoreML::Specification::MILSpec::DataType::FLOAT32; } - m_nonspatial_fp32 = m_use_fp16; - m_conv_fp32 = m_use_fp16 && trunkChannels >= CONV_FP32_MIN_TRUNK_CHANNELS; + m_nonspatial_fp32 = m_use_fp16 && hasTransformer; + m_conv_fp32 = m_use_fp16 && hasTransformer && trunkChannels >= CONV_FP32_MIN_TRUNK_CHANNELS; } void MILBuilder::setBatchDimension(CoreML::Specification::MILSpec::TensorType* tensor_type) { diff --git a/cpp/external/katagocoreml/src/builder/MILBuilder.hpp b/cpp/external/katagocoreml/src/builder/MILBuilder.hpp index fe63b442f..e38afb05e 100644 --- a/cpp/external/katagocoreml/src/builder/MILBuilder.hpp +++ b/cpp/external/katagocoreml/src/builder/MILBuilder.hpp @@ -43,14 +43,16 @@ class MILBuilder { bool m_optimize_identity_mask; bool m_use_fp16; bool m_use_fp16_io; - // FP32 in FP16 mode follows KataGo's FP16 convention (spatial convs FP16, non-spatial FP32), but - // FP32 ops run off the FP16-only ANE, so convs are channel-gated to only the wide trunks that - // need it. RMSNorm reductions: always FP32 (cheap, needed by all). Non-spatial matmuls+pooling: - // always FP32 (every width needs it at some board size). Convs: FP32 only for wide trunks. - static constexpr int CONV_FP32_MIN_TRUNK_CHANNELS = 320; // convs run FP32 at/above this width - static constexpr int FULL_FP32_MAX_TRUNK_CHANNELS = 256; // trunks below this build fully FP32 - bool m_nonspatial_fp32 = false; // = m_use_fp16 (matmuls + global pooling) - bool m_conv_fp32 = false; // = m_use_fp16 && trunk_channels >= CONV_FP32_MIN_... + // FP32-in-FP16-mode escalations all run off the FP16-only ANE, so they apply ONLY to transformer + // trunks (attention widens activation range, overflowing FP16 conv/matmul/pooling accumulation). + // Plain convnets run pure FP16 on the ANE -- the long-standing pre-tier path, verified to pass + // testgpuerror (b18c384nbt) and ~2.3x faster than forcing their per-block global pooling to FP32. + // For transformers: narrow trunks (<256) build fully FP32; wider ones use non-spatial FP32 (matmuls + + // pooling) plus, for very wide trunks (>=320), conv FP32. RMSNorm reductions: FP32 when m_use_fp16. + static constexpr int CONV_FP32_MIN_TRUNK_CHANNELS = 320; // transformer convs run FP32 at/above this width + static constexpr int FULL_FP32_MAX_TRUNK_CHANNELS = 256; // transformer trunks below this build fully FP32 + bool m_nonspatial_fp32 = false; // = m_use_fp16 && hasTransformer (matmuls + global pooling) + bool m_conv_fp32 = false; // = m_use_fp16 && hasTransformer && trunk_channels >= CONV_FP32_MIN_... int m_min_batch_size; int m_max_batch_size; CoreML::Specification::MILSpec::DataType m_weight_dtype; From 145902423e700bbb3348bba45346ce9ade711d46 Mon Sep 17 00:00:00 2001 From: Chin-Chang Yang <2770271+ChinChangYang@users.noreply.github.com> Date: Tue, 2 Jun 2026 19:08:35 +0800 Subject: [PATCH 08/10] Fix RoPE cos/sin table buffer leak in Metal backend makeRopeTables allocated cosBuf/sinBuf with UnsafeMutablePointer.allocate and handed them to Data(floatsNoCopy:), which uses deallocator: .none, so the buffers were never freed -- a leak on every graph build (per attention block, per board size). Unlike the other floatsNoCopy callers (weights/gamma/beta), which point at C++-descriptor memory that lives for the model's lifetime, these tables have no persistent owner. Switch to managed [Float32] arrays and copy into the Data via Data(buffer:) so MPSGraph owns the bytes -- avoids both the leak and a use-after-free that a naive deallocate() on the no-copy path would cause. Output-neutral: testgpuerror on the GQA + learnable-RoPE model (b7c96h6kv3qk32v16tflrs, board 19) vs Eigen FP32 reference matches to 0.00028% max winrate error over 2247 positions. Co-Authored-By: Claude Opus 4.8 (1M context) --- cpp/neuralnet/metallayers.swift | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/cpp/neuralnet/metallayers.swift b/cpp/neuralnet/metallayers.swift index dc4d22d53..96667f3ed 100644 --- a/cpp/neuralnet/metallayers.swift +++ b/cpp/neuralnet/metallayers.swift @@ -1438,8 +1438,11 @@ struct TransformerAttentionBlock { nHeads: Int, seq: Int, numPairs: Int, nnX: Int, nnY: Int, qHeadDim: Int, dataType: MPSDataType, kvIndexForHead: (Int) -> Int) -> (MPSGraphTensor, MPSGraphTensor) { let count = nHeads * seq * numPairs - let cosBuf = UnsafeMutablePointer.allocate(capacity: count) - let sinBuf = UnsafeMutablePointer.allocate(capacity: count) + // Managed arrays (freed on return). Unlike the weight constants elsewhere, which point at + // C++-owned descriptor memory and so use floatsNoCopy, these tables have no persistent owner; + // we copy them into the Data below so MPSGraph owns the bytes (avoids a leak / use-after-free). + var cosBuf = [Float32](repeating: 0, count: count) + var sinBuf = [Float32](repeating: 0, count: count) let numPairsPerDim = numPairs / 2 let dimHalf = qHeadDim / 2 for h in 0.. Date: Tue, 2 Jun 2026 21:54:45 +0800 Subject: [PATCH 09/10] Guard non-SwiGLU transformer FFN in Metal backend The Metal forward pass (metallayers.swift TransformerFFNBlock) only implements the SwiGLU path (SiLU(linear1) * gate). A non-SwiGLU model carries no gate weights, so building the Swift descriptor from the empty linearGate would crash obscurely (or silently misbehave). Eigen (eigenbackend.cpp) and CoreML (katagocoreml MILBuilder) both throw a clear "non-SwiGLU transformer FFN not supported" error in this case; the Metal GPU path had no such guard. Add the matching StringError at the FFN descriptor conversion so all three backends fail loudly and consistently. No behavior change for any current model (all use useSwiGLU=true): the guard sits on an untaken path. Verified the SwiGLU model b10c384h6nbttflrs still passes testgpuerror on both GPU (0.00005% winrate) and ANE (unchanged from baseline); runtests and runnnlayertests pass. Co-Authored-By: Claude Opus 4.8 (1M context) --- cpp/neuralnet/metalbackend.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/cpp/neuralnet/metalbackend.cpp b/cpp/neuralnet/metalbackend.cpp index 75a4e9a93..10ee50b62 100644 --- a/cpp/neuralnet/metalbackend.cpp +++ b/cpp/neuralnet/metalbackend.cpp @@ -256,6 +256,11 @@ SWTransformerAttentionBlockDesc transformerAttentionBlockDescToSwift(const Trans /// Convert a transformer FFN block description from C++ to Swift SWTransformerFFNBlockDesc transformerFFNBlockDescToSwift(const TransformerFFNDesc* desc) { + // The Metal forward pass (metallayers.swift TransformerFFNBlock) only implements the SwiGLU path + // (SiLU(linear1) * gate); a non-SwiGLU model has no gate weights, so guard here as Eigen and CoreML + // do (eigenbackend.cpp / katagocoreml MILBuilder) instead of crashing on the empty gate descriptor. + if(!desc->useSwiGLU) + throw StringError(desc->name + ": non-SwiGLU transformer FFN not supported in Metal backend"); SWTransformerRMSNormDesc preLN = transformerRMSNormDescToSwift(&desc->preLN); SWMatMulLayerDesc linear1 = matMulLayerDescToSwift(&desc->linear1); SWMatMulLayerDesc linearGate = matMulLayerDescToSwift(&desc->linearGate); From 39f82f6d2ef013e73ac56968c7bd8cd8e4b9576e Mon Sep 17 00:00:00 2001 From: Chin-Chang Yang <2770271+ChinChangYang@users.noreply.github.com> Date: Tue, 2 Jun 2026 22:31:04 +0800 Subject: [PATCH 10/10] Use named constant for trunk norm kind in Metal backend The trunk-tip dispatch in metallayers.swift compared trunkNormKind against the literal 1, while the rest of the codebase uses the named constants from desc.h (TRUNK_NORM_KIND_STANDARD/_RMSNORM). Add matching Swift constants and use TRUNK_NORM_KIND_RMSNORM at the comparison site. Pure literal-to-named-constant rename; no behavior change. Verified both branches still pass testgpuerror at GPU-level accuracy: RMSNorm tip (b10c384h6nbttflrs) 0.00005% winrate, BatchNorm tip (b7c96h6kv3 GQA) 0.00003%; runtests and runnnlayertests pass. Co-Authored-By: Claude Opus 4.8 (1M context) --- cpp/neuralnet/metallayers.swift | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/cpp/neuralnet/metallayers.swift b/cpp/neuralnet/metallayers.swift index 96667f3ed..e1324df96 100644 --- a/cpp/neuralnet/metallayers.swift +++ b/cpp/neuralnet/metallayers.swift @@ -2014,6 +2014,10 @@ class SGFMetadataEncoder { // MARK: - Trunk +/// Trunk-tip normalization kind, mirroring desc.h TRUNK_NORM_KIND_* (the value is serialized in the model). +let TRUNK_NORM_KIND_STANDARD = 0 // BatchNorm or BiasMask (existing) +let TRUNK_NORM_KIND_RMSNORM = 1 // RMSNorm + /// A class that describes a trunk for a neural network public class SWTrunkDesc { let version: Int @@ -2184,9 +2188,8 @@ struct Trunk { nnYLen: nnYLen, optimizeIdentityMask: optimizeIdentityMask) - // TRUNK_NORM_KIND_RMSNORM == 1: trunk tip uses RMSNorm with a fused activation. - // Otherwise (standard): BatchNorm followed by a separate activation. - if descriptor.trunkNormKind == 1 { + // RMSNorm trunk tip uses a fused activation; standard uses BatchNorm followed by a separate activation. + if descriptor.trunkNormKind == TRUNK_NORM_KIND_RMSNORM { let trunkTipRMSNorm = TrunkRMSNormLayer( graph: graph, sourceTensor: blocks.resultTensor,