diff --git a/cpp/external/katagocoreml/src/builder/MILBuilder.cpp b/cpp/external/katagocoreml/src/builder/MILBuilder.cpp index db0c6c4b1..09ab365ff 100644 --- a/cpp/external/katagocoreml/src/builder/MILBuilder.cpp +++ b/cpp/external/katagocoreml/src/builder/MILBuilder.cpp @@ -4,12 +4,46 @@ #include "MILBuilder.hpp" #include "MILBlob/Fp16.hpp" #include +#include // Include generated protobuf headers #include "MIL.pb.h" namespace katagocoreml { +namespace { +// RAII: set a dtype slot to FLOAT32 for the current scope and restore it on exit. Used to emit a +// sub-region of ops in FP32 inside an otherwise-FP16 model. +struct ScopedFp32 { + CoreML::Specification::MILSpec::DataType& slot; + CoreML::Specification::MILSpec::DataType saved; + explicit ScopedFp32(CoreML::Specification::MILSpec::DataType& s) + : slot(s), saved(s) { s = CoreML::Specification::MILSpec::DataType::FLOAT32; } + ~ScopedFp32() { slot = saved; } + ScopedFp32(const ScopedFp32&) = delete; + ScopedFp32& operator=(const ScopedFp32&) = delete; +}; + +// True if any block in this list is a transformer (attention/FFN), recursing into nested-bottleneck +// blocks (which can themselves contain transformer blocks). Used to scope the off-ANE FP32 +// escalations to transformer trunks only. +bool blocksContainTransformer(const std::vector& blocks) { + for (const auto& entry : blocks) { + if (entry.block_kind == TRANSFORMER_ATTENTION_BLOCK_KIND || + entry.block_kind == TRANSFORMER_FFN_BLOCK_KIND) { + return true; + } + if (entry.block_kind == NESTED_BOTTLENECK_BLOCK_KIND) { + const auto& nbt = std::get(*entry.block); + if (blocksContainTransformer(nbt.blocks)) { + return true; + } + } + } + return false; +} +} // namespace + MILBuilder::MILBuilder(const KataGoModelDesc& model, int board_x_size, int board_y_size, @@ -30,7 +64,30 @@ MILBuilder::MILBuilder(const KataGoModelDesc& model, ? CoreML::Specification::MILSpec::DataType::FLOAT16 : CoreML::Specification::MILSpec::DataType::FLOAT32) , m_ops(board_x_size, board_y_size, optimize_identity_mask) - , m_var_counter(0) {} + , m_var_counter(0) { + // Precision in FP16 mode. The ANE accumulates FP16 in FP16, so any FP32 op runs OFF the FP16-only + // ANE (on CPU/GPU), breaking the ANE pipeline. These off-ANE FP32 escalations are applied ONLY to + // transformer trunks, whose attention blocks widen the activation range enough to overflow FP16 + // accumulation. Plain convnets stay PURE FP16 on the ANE -- the long-standing pre-tier path, verified + // to pass testgpuerror (b18c384nbt, b28c512nbt) and ~2.6x faster than forcing their per-block global + // pooling and convs to FP32 (measured: the per-block pooling round-trips, not the convs, dominate the + // slowdown). For transformers: + // - NARROW trunks (<256ch) build FULLY in FP32: their policy/value metrics sit right on the + // testgpuerror thresholds and no partial-FP32 config passes all board sizes (partial FP32 leaves a + // noisy FP16 spatial stream). Off-ANE but cheap since narrow; equals the FP32 reference. Weights + // stored FP32 (per-weight serialization). + // - WIDER trunks use partial FP32: non-spatial (matmuls + pooling) always, convs only for >=320ch. + const int trunkChannels = model.trunk.trunk_num_channels; + const bool hasTransformer = blocksContainTransformer(model.trunk.blocks); + const bool full_fp32 = use_fp16 && hasTransformer && trunkChannels < FULL_FP32_MAX_TRUNK_CHANNELS; + if (full_fp32) { + m_use_fp16 = false; + m_use_fp16_io = false; + m_weight_dtype = CoreML::Specification::MILSpec::DataType::FLOAT32; + } + m_nonspatial_fp32 = m_use_fp16 && hasTransformer; + m_conv_fp32 = m_use_fp16 && hasTransformer && trunkChannels >= CONV_FP32_MIN_TRUNK_CHANNELS; +} void MILBuilder::setBatchDimension(CoreML::Specification::MILSpec::TensorType* tensor_type) { auto* dim = tensor_type->add_dimensions(); @@ -212,8 +269,10 @@ void MILBuilder::addConstOp(CoreML::Specification::MILSpec::Block* block, const std::string& name, const std::vector& data, const std::vector& shape) { - // Register weight for blob storage - m_ops.registerWeight(name, data, shape); + // Register weight for blob storage. Mark FP32 storage when this const is declared FP32 (e.g. + // inside an FP32 sub-region of an otherwise-FP16 model) so storage matches the declared type. + m_ops.registerWeight(name, data, shape, + m_weight_dtype == CoreML::Specification::MILSpec::DataType::FLOAT32); // Add const operation auto* op = block->add_operations(); @@ -328,7 +387,11 @@ void MILBuilder::addFloatScalarConstOp(CoreML::Specification::MILSpec::Block* bl val_type->set_datatype(m_weight_dtype); val_type->set_rank(0); - if (m_use_fp16) { + // Key the storage format off the DECLARED dtype (m_weight_dtype), not the global m_use_fp16: + // a temporarily-flipped FP32 sub-region (m_weight_dtype=FLOAT32 while m_use_fp16 stays true) + // must store FP32 floats, or CoreML rejects the model ("storage and type have different number + // of elements"). For all non-flipped calls m_weight_dtype tracks m_use_fp16, so this is a no-op. + if (m_weight_dtype == CoreML::Specification::MILSpec::DataType::FLOAT16) { // For FP16, use bytes storage with FP16 representation MILBlob::Fp16 fp16_val = MILBlob::Fp16::FromFloat(value); std::string bytes_data(reinterpret_cast(&fp16_val.bytes), sizeof(fp16_val.bytes)); @@ -426,6 +489,68 @@ void MILBuilder::addCastOp(CoreML::Specification::MILSpec::Block* block, } } +std::string MILBuilder::castFixed(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const std::string& dtype, + const std::vector& dims) { + std::string out = genVarName(input + "_cast"); + std::string dtName = out + "_dt"; + { + auto* op = block->add_operations(); + op->set_type("const"); + auto& na = (*op->mutable_attributes())["name"]; + na.mutable_type()->mutable_tensortype()->set_datatype(CoreML::Specification::MILSpec::DataType::STRING); + na.mutable_immediatevalue()->mutable_tensor()->mutable_strings()->add_values(dtName); + auto& va = (*op->mutable_attributes())["val"]; + va.mutable_type()->mutable_tensortype()->set_datatype(CoreML::Specification::MILSpec::DataType::STRING); + va.mutable_immediatevalue()->mutable_tensor()->mutable_strings()->add_values(dtype); + auto* o = op->add_outputs(); + o->set_name(dtName); + o->mutable_type()->mutable_tensortype()->set_datatype(CoreML::Specification::MILSpec::DataType::STRING); + } + auto* op = block->add_operations(); + op->set_type("cast"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(input); + (*op->mutable_inputs())["dtype"].add_arguments()->set_name(dtName); + auto* o = op->add_outputs(); + o->set_name(out); + auto* tt = o->mutable_type()->mutable_tensortype(); + tt->set_datatype(dtype == "fp32" ? CoreML::Specification::MILSpec::DataType::FLOAT32 + : CoreML::Specification::MILSpec::DataType::FLOAT16); + tt->set_rank(static_cast(dims.size())); + for (int64_t d : dims) { + if (d < 0) tt->add_dimensions()->mutable_unknown()->set_variadic(false); + else tt->add_dimensions()->mutable_constant()->set_size(d); + } + return out; +} + +void MILBuilder::addGlobalPoolingFp32(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const std::string& mask, + int channels, + const std::string& output, + bool valueVariant) { + auto pool = [&](const std::string& in, const std::string& msk, const std::string& out) { + if (valueVariant) addGlobalPoolingValueOps(block, in, msk, channels, out); + else addGlobalPoolingOps(block, in, msk, channels, out); + }; + // Non-spatial per KataGo's FP16 convention -> FP32 (the FP16 spatial sum over H*W loses too much + // precision at larger board sizes). No addConstOp in the pooling, so flipping m_weight_dtype is + // safe. Cast input/mask up, pool in FP32, cast the [N, channels*3] features back to FP16. + if (!m_nonspatial_fp32) { + pool(input, mask, output); + return; + } + std::string in32 = castFixed(block, input, "fp32", {-1, channels, m_board_y_size, m_board_x_size}); + std::string mask32 = m_optimize_identity_mask + ? mask + : castFixed(block, mask, "fp32", {-1, 1, m_board_y_size, m_board_x_size}); + std::string out32 = genVarName(output + "_f32"); + { ScopedFp32 g(m_weight_dtype); pool(in32, mask32, out32); } + addCastOp(block, out32, output, "fp16", {-1, channels * 3}); +} + void MILBuilder::addConvOp(CoreML::Specification::MILSpec::Block* block, const std::string& input, const ConvLayerDesc& layer, @@ -566,6 +691,21 @@ void MILBuilder::addConvOp(CoreML::Specification::MILSpec::Block* block, tt->add_dimensions()->mutable_constant()->set_size(4); } + // Channel-gated FP32 convs. The ANE accumulates FP16 convs in FP16, which loses too much + // precision for WIDE trunks and fails testgpuerror at large board sizes (validated: 384ch + // fails, <=256ch is fine FP16-on-ANE). For wide trunks (>= threshold) run convs in FP32 (weights + // cast up at runtime, stored fp16). FP32 convs can't run on the fp16-only ANE, so only the wide + // models that actually need it pay that off-ANE cost; narrow models keep convs on the ANE. + const bool convFp32 = m_conv_fp32; + std::string convX = input, convW = weight_name, convOut = output; + auto savedConvDtype = m_weight_dtype; + if (convFp32) { + convX = castFixed(block, input, "fp32", {-1, layer.in_channels, m_board_y_size, m_board_x_size}); + convW = castFixed(block, weight_name, "fp32", layer.getWeightShape()); + m_weight_dtype = CoreML::Specification::MILSpec::DataType::FLOAT32; + convOut = output + "_cf32"; + } + // Add conv operation referencing all const parameters auto* op = block->add_operations(); op->set_type("conv"); @@ -577,12 +717,12 @@ void MILBuilder::addConvOp(CoreML::Specification::MILSpec::Block* block, inputs["pad"].add_arguments()->set_name(pad_name); inputs["pad_type"].add_arguments()->set_name(pad_type_name); inputs["strides"].add_arguments()->set_name(strides_name); - inputs["weight"].add_arguments()->set_name(weight_name); - inputs["x"].add_arguments()->set_name(input); + inputs["weight"].add_arguments()->set_name(convW); + inputs["x"].add_arguments()->set_name(convX); // Output with dimensions [batch, out_channels, height, width] auto* out = op->add_outputs(); - out->set_name(output); + out->set_name(convOut); auto* out_type = out->mutable_type()->mutable_tensortype(); out_type->set_datatype(m_weight_dtype); out_type->set_rank(4); @@ -590,6 +730,11 @@ void MILBuilder::addConvOp(CoreML::Specification::MILSpec::Block* block, out_type->add_dimensions()->mutable_constant()->set_size(layer.out_channels); out_type->add_dimensions()->mutable_constant()->set_size(m_board_y_size); out_type->add_dimensions()->mutable_constant()->set_size(m_board_x_size); + + if (convFp32) { + m_weight_dtype = savedConvDtype; + addCastOp(block, convOut, output, "fp16", {-1, layer.out_channels, m_board_y_size, m_board_x_size}); + } } // Helper: Set output tensor type with 4D shape [batch, C, H, W] @@ -732,6 +877,63 @@ void MILBuilder::addBatchNormActivationOps(CoreML::Specification::MILSpec::Block setTensorOutput4D(op, output, bn.num_channels, m_board_y_size, m_board_x_size); } else if (act.activation_type == ActivationType::Mish) { addMishOps(block, bn_output, output, 4, bn.num_channels); + } else if (act.activation_type == ActivationType::Silu) { + addSiluOps(block, bn_output, output, 4, bn.num_channels); + } +} + +void MILBuilder::addSiluOps(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const std::string& output, + int rank, + int channels) { + // SiLU / Swish: x * sigmoid(x) + auto setOutputType = [this, rank, channels](CoreML::Specification::MILSpec::Operation* op, const std::string& name) { + auto* out = op->add_outputs(); + out->set_name(name); + auto* out_type = out->mutable_type()->mutable_tensortype(); + out_type->set_datatype(m_weight_dtype); + out_type->set_rank(rank); + setBatchDimension(out_type); + out_type->add_dimensions()->mutable_constant()->set_size(channels); + if (rank == 4) { + out_type->add_dimensions()->mutable_constant()->set_size(m_board_y_size); + out_type->add_dimensions()->mutable_constant()->set_size(m_board_x_size); + } + }; + + std::string sig = output + "_sigmoid"; + { + auto* op = block->add_operations(); + op->set_type("sigmoid"); + auto& inputs = *op->mutable_inputs(); + inputs["x"].add_arguments()->set_name(input); + setOutputType(op, sig); + } + { + auto* op = block->add_operations(); + op->set_type("mul"); + auto& inputs = *op->mutable_inputs(); + inputs["x"].add_arguments()->set_name(input); + inputs["y"].add_arguments()->set_name(sig); + setOutputType(op, output); + } +} + +void MILBuilder::setShape(CoreML::Specification::MILSpec::Operation* op, + const std::string& name, + const std::vector& dims) { + auto* out = op->add_outputs(); + out->set_name(name); + auto* t = out->mutable_type()->mutable_tensortype(); + t->set_datatype(m_weight_dtype); + t->set_rank(static_cast(dims.size())); + for (int64_t d : dims) { + auto* dim = t->add_dimensions(); + if (d < 0) + dim->mutable_unknown()->set_variadic(false); + else + dim->mutable_constant()->set_size(d); } } @@ -887,23 +1089,38 @@ void MILBuilder::addMatMulOp(CoreML::Specification::MILSpec::Block* block, CoreML::Specification::MILSpec::DataType::BOOL); } + // Non-spatial matmul in FP32 (KataGo FP16 convention; weights cast up at runtime, stored fp16). + std::string mmIn = input, mmW = weight_name, mmOut = output; + auto savedMmDtype = m_weight_dtype; + if (m_nonspatial_fp32) { + mmIn = castFixed(block, input, "fp32", {-1, layer.in_channels}); + mmW = castFixed(block, weight_name, "fp32", layer.getWeightShape()); + m_weight_dtype = CoreML::Specification::MILSpec::DataType::FLOAT32; + mmOut = output + "_mmf32"; + } + // Add matmul operation auto* op = block->add_operations(); op->set_type("matmul"); auto& inputs = *op->mutable_inputs(); inputs["transpose_x"].add_arguments()->set_name(transpose_x_name); inputs["transpose_y"].add_arguments()->set_name(transpose_y_name); - inputs["x"].add_arguments()->set_name(input); - inputs["y"].add_arguments()->set_name(weight_name); + inputs["x"].add_arguments()->set_name(mmIn); + inputs["y"].add_arguments()->set_name(mmW); // Output with 2D shape [batch, out_channels] auto* out = op->add_outputs(); - out->set_name(output); + out->set_name(mmOut); auto* out_type = out->mutable_type()->mutable_tensortype(); out_type->set_datatype(m_weight_dtype); out_type->set_rank(2); setBatchDimension(out_type); out_type->add_dimensions()->mutable_constant()->set_size(layer.out_channels); + + if (m_nonspatial_fp32) { + m_weight_dtype = savedMmDtype; + addCastOp(block, mmOut, output, "fp16", {-1, layer.out_channels}); + } } void MILBuilder::addMatBiasOp(CoreML::Specification::MILSpec::Block* block, @@ -964,7 +1181,9 @@ void MILBuilder::addLinearOp(CoreML::Specification::MILSpec::Block* block, std::vector bias_shape = {static_cast(bias.num_channels)}; addConstOp(block, bias_name, bias.weights, bias_shape); - // Add linear operation + // NOTE: the MIL `linear` op requires const weight/bias, so the runtime-cast-to-FP32 trick can't + // be applied here (unlike `matmul`). Value-head linear stays FP16; if a model ever needs it in + // FP32, rewrite as matmul+add (matmul accepts cast inputs). auto* op = block->add_operations(); op->set_type("linear"); auto& inputs = *op->mutable_inputs(); @@ -1637,6 +1856,636 @@ void MILBuilder::addGlobalPoolingValueOps(CoreML::Specification::MILSpec::Block* // Network Component Builders // ============================================================================ +// --------------------------------------------------------------------------- +// Transformer blocks (MIL). Layout is NCHW [B, C, H, W]; spatial positions +// (H*W, ordered y*W+x) are treated as the attention sequence. RoPE is applied +// via a fixed pair-rotation matmul plus host-precomputed cos/sin tables, which +// keeps every tensor rank <= 4 (ANE-friendly). +// --------------------------------------------------------------------------- + +std::string MILBuilder::addTransformerRMSNorm(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const TransformerRMSNormDesc& desc, + const std::string& mask, + const std::string& prefix) { + const int C = desc.num_channels; + const int H = m_board_y_size, W = m_board_x_size; + auto emit2 = [&](const std::string& type, const std::string& x, const std::string& y, + const std::string& out, const std::vector& dims) { + auto* op = block->add_operations(); + op->set_type(type); + (*op->mutable_inputs())["x"].add_arguments()->set_name(x); + (*op->mutable_inputs())["y"].add_arguments()->set_name(y); + setShape(op, out, dims); + }; + + // RMSNorm reduction core: square -> mean over channels -> rsqrt. In FP16 mode compute this + // core in FP32 (cast input up, flip the working dtype so the core's op outputs + eps scalar are + // FP32, then cast 1/rms back down). The FP16 channel reduction loses too much precision on the + // ANE; only this core is FP32 - the scaling/weight/mask below stay FP16. No addConstOp lives in + // the flipped window, so weight serialization is unaffected. + auto savedDtype = m_weight_dtype; + std::string sqSrc = input; + if (m_use_fp16) { + sqSrc = genVarName(prefix + "_in32"); + addCastOp(block, input, sqSrc, "fp32", {-1, C, H, W}); + m_weight_dtype = CoreML::Specification::MILSpec::DataType::FLOAT32; + } + std::string sq = genVarName(prefix + "_sq"); + emit2("mul", sqSrc, sqSrc, sq, {-1, C, H, W}); + // meanSq = reduce_mean(sq, axes=[1]) over channels. reduce_mean (not reduce_sum) is used so + // the accumulator stays ~O(activation^2) instead of summing hundreds of channels, which can + // overflow FP16 (and the FP16 accumulation on ANE) for large activations. + std::string meanSq = genVarName(prefix + "_meansq"); + { + std::string axesName = meanSq + "_axes"; + std::string keepName = meanSq + "_keep"; + addIntArrayConstOp(block, axesName, {1}); + addBoolScalarConstOp(block, keepName, true); + auto* op = block->add_operations(); + op->set_type("reduce_mean"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(sq); + (*op->mutable_inputs())["axes"].add_arguments()->set_name(axesName); + (*op->mutable_inputs())["keep_dims"].add_arguments()->set_name(keepName); + setShape(op, meanSq, {-1, 1, H, W}); + } + // MIL rsqrt computes 1/sqrt(x + epsilon); supply epsilon directly. + std::string epsName = prefix + "_eps"; + addFloatScalarConstOp(block, epsName, desc.epsilon); + std::string invCore = genVarName(prefix + "_inv"); + { + auto* op = block->add_operations(); + op->set_type("rsqrt"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(meanSq); + (*op->mutable_inputs())["epsilon"].add_arguments()->set_name(epsName); + setShape(op, invCore, {-1, 1, H, W}); + } + std::string inv = invCore; + if (m_use_fp16) { + m_weight_dtype = savedDtype; + inv = genVarName(prefix + "_inv16"); + addCastOp(block, invCore, inv, "fp16", {-1, 1, H, W}); + } + std::string normalized = genVarName(prefix + "_norm"); + emit2("mul", input, inv, normalized, {-1, C, H, W}); + std::string weightName = prefix + "_weight"; + addConstOp(block, weightName, desc.weight, {1, static_cast(C), 1, 1}); + std::string scaled = genVarName(prefix + "_scaled"); + emit2("mul", normalized, weightName, scaled, {-1, C, H, W}); + std::string out = genVarName(prefix + "_out"); + emit2("mul", scaled, mask, out, {-1, C, H, W}); + return out; +} + +std::string MILBuilder::addTrunkRMSNorm(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const RMSNormLayerDesc& desc, + const ActivationLayerDesc& act, + const std::string& mask, + const std::string& prefix) { + const int C = desc.num_channels; + const int H = m_board_y_size, W = m_board_x_size; + auto emit2 = [&](const std::string& type, const std::string& x, const std::string& y, + const std::string& out, const std::vector& dims) { + auto* op = block->add_operations(); + op->set_type(type); + (*op->mutable_inputs())["x"].add_arguments()->set_name(x); + (*op->mutable_inputs())["y"].add_arguments()->set_name(y); + setShape(op, out, dims); + }; + auto reduceSum = [&](const std::string& x, const std::string& out, const std::vector& axes, + const std::vector& dims) { + std::string axesName = out + "_axes"; + std::string keepName = out + "_keep"; + addIntArrayConstOp(block, axesName, axes); + addBoolScalarConstOp(block, keepName, true); + auto* op = block->add_operations(); + op->set_type("reduce_sum"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(x); + (*op->mutable_inputs())["axes"].add_arguments()->set_name(axesName); + (*op->mutable_inputs())["keep_dims"].add_arguments()->set_name(keepName); + setShape(op, out, dims); + }; + + // Variance core (mask -> square -> reduce -> rsqrt) in FP32 when in FP16 mode. The trunk-tip + // norm in particular reduces over many elements and loses too much precision in FP16 on the + // ANE; compute the core in FP32 and cast 1/rms back to FP16. Only the core is FP32 - gamma/beta, + // the activation and the final mask below stay FP16. No addConstOp lives in the flipped window. + auto savedDtype = m_weight_dtype; + std::string tinput = input; + std::string tmask = mask; + if (m_use_fp16) { + tinput = genVarName(prefix + "_in32"); + addCastOp(block, input, tinput, "fp32", {-1, C, H, W}); + tmask = genVarName(prefix + "_mask32"); + addCastOp(block, mask, tmask, "fp32", {-1, 1, H, W}); + m_weight_dtype = CoreML::Specification::MILSpec::DataType::FLOAT32; + } + std::string masked = genVarName(prefix + "_premask"); + emit2("mul", tinput, tmask, masked, {-1, C, H, W}); + std::string sq = genVarName(prefix + "_sq"); + emit2("mul", masked, masked, sq, {-1, C, H, W}); + + std::string meanSq; + std::vector denomDims; + if (desc.spatial) { + // Mean of squares over valid positions and channels. A reduce_sum over C*H*W elements + // overflows FP16 (e.g. trunk tip with large activations on ANE -> inf -> rsqrt 0 -> + // collapse). Instead take reduce_mean over all of C,H,W (masked positions are zero) and + // rescale by totalPositions/validCount to restrict the mean to valid positions. + std::string meanAll = genVarName(prefix + "_meanall"); + { + std::string axesName = meanAll + "_axes", keepName = meanAll + "_keep"; + addIntArrayConstOp(block, axesName, {1, 2, 3}); + addBoolScalarConstOp(block, keepName, true); + auto* op = block->add_operations(); + op->set_type("reduce_mean"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(sq); + (*op->mutable_inputs())["axes"].add_arguments()->set_name(axesName); + (*op->mutable_inputs())["keep_dims"].add_arguments()->set_name(keepName); + setShape(op, meanAll, {-1, 1, 1, 1}); + } + std::string count = genVarName(prefix + "_count"); + reduceSum(tmask, count, {1, 2, 3}, {-1, 1, 1, 1}); // valid positions (<= H*W, no overflow) + std::string totalPosName = prefix + "_totalpos"; + addFloatScalarConstOp(block, totalPosName, static_cast(H * W)); + std::string scaleF = genVarName(prefix + "_scalef"); + emit2("real_div", totalPosName, count, scaleF, {-1, 1, 1, 1}); // totalPos / validCount + meanSq = genVarName(prefix + "_meansq"); + emit2("mul", meanAll, scaleF, meanSq, {-1, 1, 1, 1}); + denomDims = {-1, 1, 1, 1}; + } else { + meanSq = genVarName(prefix + "_meansq"); + std::string axesName = meanSq + "_axes"; + std::string keepName = meanSq + "_keep"; + addIntArrayConstOp(block, axesName, {1}); + addBoolScalarConstOp(block, keepName, true); + auto* op = block->add_operations(); + op->set_type("reduce_mean"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(sq); + (*op->mutable_inputs())["axes"].add_arguments()->set_name(axesName); + (*op->mutable_inputs())["keep_dims"].add_arguments()->set_name(keepName); + setShape(op, meanSq, {-1, 1, H, W}); + denomDims = {-1, 1, H, W}; + } + + // MIL rsqrt computes 1/sqrt(x + epsilon); supply epsilon directly. + std::string epsName = prefix + "_eps"; + addFloatScalarConstOp(block, epsName, desc.epsilon); + std::string invCore = genVarName(prefix + "_inv"); + { + auto* op = block->add_operations(); + op->set_type("rsqrt"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(meanSq); + (*op->mutable_inputs())["epsilon"].add_arguments()->set_name(epsName); + setShape(op, invCore, denomDims); + } + std::string inv = invCore; + if (m_use_fp16) { + m_weight_dtype = savedDtype; + inv = genVarName(prefix + "_inv16"); + addCastOp(block, invCore, inv, "fp16", denomDims); + } + std::string normalized = genVarName(prefix + "_norm"); + emit2("mul", input, inv, normalized, {-1, C, H, W}); + std::string gammaName = prefix + "_gamma"; + std::string betaName = prefix + "_beta"; + addConstOp(block, gammaName, desc.gamma, {1, static_cast(C), 1, 1}); + addConstOp(block, betaName, desc.beta, {1, static_cast(C), 1, 1}); + std::string scaled = genVarName(prefix + "_scaled"); + emit2("mul", normalized, gammaName, scaled, {-1, C, H, W}); + std::string biased = genVarName(prefix + "_biased"); + emit2("add", scaled, betaName, biased, {-1, C, H, W}); + + std::string activated; + if (act.activation_type == ActivationType::Silu) { + activated = genVarName(prefix + "_act"); + addSiluOps(block, biased, activated, 4, C); + } else if (act.activation_type == ActivationType::Mish) { + activated = genVarName(prefix + "_act"); + addMishOps(block, biased, activated, 4, C); + } else if (act.activation_type == ActivationType::ReLU) { + activated = genVarName(prefix + "_act"); + auto* op = block->add_operations(); + op->set_type("relu"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(biased); + setShape(op, activated, {-1, C, H, W}); + } else { + activated = biased; + } + std::string out = genVarName(prefix + "_out"); + emit2("mul", activated, mask, out, {-1, C, H, W}); + return out; +} + +std::string MILBuilder::buildTransformerAttentionBlock(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const TransformerAttentionBlockDesc& desc, + const std::string& mask, + const std::string& prefix) { + const int C = desc.q_proj.in_channels; + const int H = m_board_y_size, W = m_board_x_size; + const int seq = H * W; + const int numHeads = desc.num_heads, numKVHeads = desc.num_kv_heads; + const int qHeadDim = desc.q_head_dim, vHeadDim = desc.v_head_dim; + const int qTotal = numHeads * qHeadDim, kTotal = numKVHeads * qHeadDim, vTotal = numKVHeads * vHeadDim; + + auto reshape = [&](const std::string& in, const std::string& out, const std::vector& shapeVals, + const std::vector& dims) { + std::string shapeName = out + "_shape"; + addIntArrayConstOp(block, shapeName, shapeVals); + auto* op = block->add_operations(); + op->set_type("reshape"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(in); + (*op->mutable_inputs())["shape"].add_arguments()->set_name(shapeName); + setShape(op, out, dims); + }; + auto transpose = [&](const std::string& in, const std::string& out, const std::vector& perm, + const std::vector& dims) { + std::string permName = out + "_perm"; + addIntArrayConstOp(block, permName, perm); + auto* op = block->add_operations(); + op->set_type("transpose"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(in); + (*op->mutable_inputs())["perm"].add_arguments()->set_name(permName); + setShape(op, out, dims); + }; + auto matmul = [&](const std::string& x, const std::string& y, const std::string& out, + const std::vector& dims, bool transX, bool transY) { + std::string txName = out + "_tx", tyName = out + "_ty"; + addBoolScalarConstOp(block, txName, transX); + addBoolScalarConstOp(block, tyName, transY); + auto* op = block->add_operations(); + op->set_type("matmul"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(x); + (*op->mutable_inputs())["y"].add_arguments()->set_name(y); + (*op->mutable_inputs())["transpose_x"].add_arguments()->set_name(txName); + (*op->mutable_inputs())["transpose_y"].add_arguments()->set_name(tyName); + setShape(op, out, dims); + }; + auto binary = [&](const std::string& type, const std::string& x, const std::string& y, + const std::string& out, const std::vector& dims) { + auto* op = block->add_operations(); + op->set_type(type); + (*op->mutable_inputs())["x"].add_arguments()->set_name(x); + (*op->mutable_inputs())["y"].add_arguments()->set_name(y); + setShape(op, out, dims); + }; + + std::string normed = addTransformerRMSNorm(block, input, desc.pre_ln, mask, prefix + "_ln"); + std::string nhwc = genVarName(prefix + "_nhwc"); + transpose(normed, nhwc, {0, 2, 3, 1}, {-1, H, W, C}); + std::string x2d = genVarName(prefix + "_x2d"); + reshape(nhwc, x2d, {-1, C}, {-1, C}); + // Q/K/V projection matmuls in FP32 (non-spatial, per KataGo's FP16 convention): they reduce over + // C channels and the ANE's FP16 accumulation loses too much precision for wide models. Weights + // stay fp16-stored (cast up at runtime); output cast back to FP16 for the FP16 head reshapes. + auto proj = [&](const MatMulLayerDesc& w, const std::string& nm, int total) { + std::string wName = nm + "_w"; + addConstOp(block, wName, w.weights, w.getWeightShape()); + std::string out = genVarName(nm); + if (m_nonspatial_fp32) { + std::string x32 = castFixed(block, x2d, "fp32", {-1, C}); + std::string w32 = castFixed(block, wName, "fp32", w.getWeightShape()); + auto sd = m_weight_dtype; + m_weight_dtype = CoreML::Specification::MILSpec::DataType::FLOAT32; + std::string o32 = genVarName(nm + "_f32"); + matmul(x32, w32, o32, {-1, total}, false, false); + m_weight_dtype = sd; + out = castFixed(block, o32, "fp16", {-1, total}); + } else { + matmul(x2d, wName, out, {-1, total}, false, false); + } + return out; + }; + std::string q2d = proj(desc.q_proj, prefix + "_q", qTotal); + std::string k2d = proj(desc.k_proj, prefix + "_k", kTotal); + std::string v2d = proj(desc.v_proj, prefix + "_v", vTotal); + auto toHeads = [&](const std::string& in2d, const std::string& nm, int nh, int hd) { + std::string r = genVarName(nm + "_r"); + reshape(in2d, r, {-1, seq, nh, hd}, {-1, seq, nh, hd}); + std::string t = genVarName(nm + "_t"); + transpose(r, t, {0, 2, 1, 3}, {-1, nh, seq, hd}); + return t; + }; + std::string qh = toHeads(q2d, prefix + "_qh", numHeads, qHeadDim); + std::string kh = toHeads(k2d, prefix + "_kh", numKVHeads, qHeadDim); + std::string vh = toHeads(v2d, prefix + "_vh", numKVHeads, vHeadDim); + + if (desc.use_rope) { + const int numPairs = qHeadDim / 2; + const int numPairsPerDim = numPairs / 2; + const int dimHalf = qHeadDim / 2; + auto applyRope = [&](const std::string& x, int nh, const std::string& tag) { + std::vector cosFull(static_cast(nh) * seq * qHeadDim, 0.0f); + std::vector sinFull(static_cast(nh) * seq * qHeadDim, 0.0f); + for (int h = 0; h < nh; h++) { + int kvh = (h * numKVHeads) / nh; + for (int xy = 0; xy < seq; xy++) { + int y = xy / W; + int x = xy % W; + for (int p = 0; p < numPairs; p++) { + float angle = 0.0f; + if (desc.learnable_rope) { + float fx = desc.rope_freqs[(kvh * numPairs + p) * 2 + 0]; + float fy = desc.rope_freqs[(kvh * numPairs + p) * 2 + 1]; + angle = static_cast(x) * fx + static_cast(y) * fy; + } else { + if (p < numPairsPerDim) { + float freq = 1.0f / std::pow(desc.rope_theta, static_cast(2 * p) / dimHalf); + angle = static_cast(y) * freq; + } else { + int pAdj = p - numPairsPerDim; + float freq = 1.0f / std::pow(desc.rope_theta, static_cast(2 * pAdj) / dimHalf); + angle = static_cast(x) * freq; + } + } + float c = std::cos(angle), s = std::sin(angle); + size_t base = (static_cast(h) * seq + xy) * qHeadDim + 2 * p; + cosFull[base] = c; cosFull[base + 1] = c; + sinFull[base] = s; sinFull[base + 1] = s; + } + } + } + std::vector R(static_cast(qHeadDim) * qHeadDim, 0.0f); + for (int p = 0; p < numPairs; p++) { + R[(2 * p) * qHeadDim + (2 * p + 1)] = 1.0f; + R[(2 * p + 1) * qHeadDim + (2 * p)] = -1.0f; + } + std::string cosName = prefix + "_" + tag + "_cos"; + std::string sinName = prefix + "_" + tag + "_sin"; + std::string rName = prefix + "_" + tag + "_R"; + addConstOp(block, cosName, cosFull, {1, nh, seq, qHeadDim}); + addConstOp(block, sinName, sinFull, {1, nh, seq, qHeadDim}); + // Rank-4 [1,1,qd,qd] so matmul batch dims broadcast cleanly against [B,nh,seq,qd]. + addConstOp(block, rName, R, {1, 1, qHeadDim, qHeadDim}); + std::string rotated = genVarName(prefix + "_" + tag + "_rot"); + matmul(x, rName, rotated, {-1, nh, seq, qHeadDim}, false, false); + std::string xc = genVarName(prefix + "_" + tag + "_xc"); + binary("mul", x, cosName, xc, {-1, nh, seq, qHeadDim}); + std::string rs = genVarName(prefix + "_" + tag + "_rs"); + binary("mul", rotated, sinName, rs, {-1, nh, seq, qHeadDim}); + std::string out = genVarName(prefix + "_" + tag + "_rope"); + binary("add", xc, rs, out, {-1, nh, seq, qHeadDim}); + return out; + }; + qh = applyRope(qh, numHeads, "q"); + kh = applyRope(kh, numKVHeads, "k"); + } + + // GQA: when numKVHeads < numHeads, repeat each KV head groupSize times along the head + // axis (axis 1) so query head h consumes kv head (h / groupSize). RoPE has already been + // applied above to the unexpanded kh (kh = applyRope(kh, numKVHeads, "k")), mirroring the + // GPU path (metallayers.swift repeatKVHeads runs AFTER applyRope). We slice each KV head + // and concat its copies consecutively, so the resulting head index is kv*groupSize + g; + // query head h then maps to kv = h/groupSize == (h*numKVHeads)/numHeads (exact divisor, + // the same formula the qh RoPE table uses) == Eigen's kvh = h/kvGroupSize. slice_by_size + + // concat (not reshape+broadcast) avoids the dynamic -1 batch broadcast pitfall, same as the + // GPU code. The repeat is required so the scores (qh@kh^T) and attnOut (attn@vh) matmuls see + // matching [B,numHeads,...] batch dims instead of numHeads vs numKVHeads (no broadcast). + if (numKVHeads != numHeads) { + const int groupSize = numHeads / numKVHeads; + auto repeatKVHeads = [&](const std::string& x, const std::string& tag, int headDim) { + std::vector parts; + parts.reserve(static_cast(numKVHeads) * groupSize); + for (int kv = 0; kv < numKVHeads; kv++) { + for (int g = 0; g < groupSize; g++) { + std::string part = genVarName(prefix + "_" + tag + "_slc"); + std::string beginName = part + "_begin", sizeName = part + "_size"; + addIntArrayConstOp(block, beginName, {0, kv, 0, 0}); + addIntArrayConstOp(block, sizeName, {-1, 1, seq, headDim}); + auto* sop = block->add_operations(); + sop->set_type("slice_by_size"); + (*sop->mutable_inputs())["x"].add_arguments()->set_name(x); + (*sop->mutable_inputs())["begin"].add_arguments()->set_name(beginName); + (*sop->mutable_inputs())["size"].add_arguments()->set_name(sizeName); + setShape(sop, part, {-1, 1, seq, headDim}); + parts.push_back(part); + } + } + std::string out = genVarName(prefix + "_" + tag + "_exp"); + std::string axisName = out + "_axis", interleaveName = out + "_interleave"; + addIntScalarConstOp(block, axisName, 1); + addBoolScalarConstOp(block, interleaveName, false); + auto* cop = block->add_operations(); + cop->set_type("concat"); + auto& cin = *cop->mutable_inputs(); + for (const std::string& part : parts) + cin["values"].add_arguments()->set_name(part); + cin["axis"].add_arguments()->set_name(axisName); + cin["interleave"].add_arguments()->set_name(interleaveName); + setShape(cop, out, {-1, numHeads, seq, headDim}); + return out; + }; + kh = repeatKVHeads(kh, "khrep", qHeadDim); + vh = repeatKVHeads(vh, "vhrep", vHeadDim); + } + + std::string scores = genVarName(prefix + "_scores"); + matmul(qh, kh, scores, {-1, numHeads, seq, seq}, false, true); + std::string scaleName = prefix + "_scale"; + addFloatScalarConstOp(block, scaleName, 1.0f / std::sqrt(static_cast(qHeadDim))); + std::string scaled = genVarName(prefix + "_sc"); + binary("mul", scores, scaleName, scaled, {-1, numHeads, seq, seq}); + + // mask [B,1,H,W] -> [B,1,1,seq] directly (contiguous reshape; H,W already trailing so the + // row-major flatten gives seq index xy=y*W+x). No transpose -> avoids the reshape-after- + // transpose issue, and is also correct for non-full boards. + std::string maskSeq = genVarName(prefix + "_mseq"); + reshape(mask, maskSeq, {-1, 1, 1, seq}, {-1, 1, 1, seq}); + std::string oneName = prefix + "_one"; + addFloatScalarConstOp(block, oneName, 1.0f); + std::string mm1 = genVarName(prefix + "_mm1"); + binary("sub", maskSeq, oneName, mm1, {-1, 1, 1, seq}); + // Use an FP16-safe magnitude: 1e9 overflows FP16 to +inf, and for valid keys + // (maskSeq-1 == 0) the product 0 * inf becomes NaN, poisoning the whole softmax. + // 1e4 is well within FP16 range and exp(score - 1e4) still underflows to 0. + std::string bigName = prefix + "_big"; + addFloatScalarConstOp(block, bigName, 1.0e4f); + std::string keyBias = genVarName(prefix + "_kb"); + binary("mul", mm1, bigName, keyBias, {-1, 1, 1, seq}); + std::string scoresMasked = genVarName(prefix + "_scm"); + binary("add", scaled, keyBias, scoresMasked, {-1, numHeads, seq, seq}); + + std::string attn = genVarName(prefix + "_attn"); + { + std::string axisName = attn + "_axis"; + addIntScalarConstOp(block, axisName, 3); + auto* op = block->add_operations(); + op->set_type("softmax"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(scoresMasked); + (*op->mutable_inputs())["axis"].add_arguments()->set_name(axisName); + setShape(op, attn, {-1, numHeads, seq, seq}); + } + + std::string attnOut = genVarName(prefix + "_ao"); + matmul(attn, vh, attnOut, {-1, numHeads, seq, vHeadDim}, false, false); + + // Output projection, done per-head to avoid reshape-after-transpose: CoreML's reshape + // ignores an immediately-preceding transpose, so merging [head,dim]->channels after a + // transpose scrambles the data. Instead slice each head from attnOut (head is the + // contiguous axis 1), reshape (leading-merge only), matmul its weight slice, and sum. + // out[b,s,c] = sum_h sum_d attnOut[b,h,s,d] * outProj.weights[(h*vHeadDim+d)*outC + c] + const int outC = desc.out_proj.out_channels; + std::string proj2d; + for (int h = 0; h < numHeads; h++) { + std::string aoh = genVarName(prefix + "_aoh"); + { + std::string beginName = aoh + "_begin", sizeName = aoh + "_size"; + addIntArrayConstOp(block, beginName, {0, h, 0, 0}); + addIntArrayConstOp(block, sizeName, {-1, 1, seq, vHeadDim}); + auto* op = block->add_operations(); + op->set_type("slice_by_size"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(attnOut); + (*op->mutable_inputs())["begin"].add_arguments()->set_name(beginName); + (*op->mutable_inputs())["size"].add_arguments()->set_name(sizeName); + setShape(op, aoh, {-1, 1, seq, vHeadDim}); + } + std::string aoh2d = genVarName(prefix + "_aoh2d"); + reshape(aoh, aoh2d, {-1, vHeadDim}, {-1, vHeadDim}); // [B*seq, vHeadDim] + std::string wh = prefix + "_ow" + std::to_string(h); + std::vector whData(static_cast(vHeadDim) * outC); + for (int d = 0; d < vHeadDim; d++) + for (int c = 0; c < outC; c++) + whData[d * outC + c] = desc.out_proj.weights[static_cast(h * vHeadDim + d) * outC + c]; + addConstOp(block, wh, whData, {vHeadDim, outC}); + std::string contrib = genVarName(prefix + "_contrib"); + matmul(aoh2d, wh, contrib, {-1, outC}, false, false); + if (h == 0) { + proj2d = contrib; + } else { + std::string acc = genVarName(prefix + "_acc"); + binary("add", proj2d, contrib, acc, {-1, outC}); + proj2d = acc; + } + } + std::string projNHWC = genVarName(prefix + "_pnhwc"); + reshape(proj2d, projNHWC, {-1, H, W, C}, {-1, H, W, C}); + std::string projNCHW = genVarName(prefix + "_pnchw"); + transpose(projNHWC, projNCHW, {0, 3, 1, 2}, {-1, C, H, W}); + std::string maskedOut = genVarName(prefix + "_masked"); + binary("mul", projNCHW, mask, maskedOut, {-1, C, H, W}); + std::string out = genVarName(prefix + "_out"); + binary("add", input, maskedOut, out, {-1, C, H, W}); + return out; +} + +std::string MILBuilder::buildTransformerFFNBlock(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const TransformerFFNBlockDesc& desc, + const std::string& mask, + const std::string& prefix) { + const int C = desc.num_channels; + const int ffn = desc.ffn_channels; + const int H = m_board_y_size, W = m_board_x_size; + + if (!desc.use_swiglu) { + throw std::runtime_error(desc.name + ": non-SwiGLU transformer FFN not supported in CoreML backend"); + } + + auto reshape = [&](const std::string& in, const std::string& out, const std::vector& shapeVals, + const std::vector& dims) { + std::string shapeName = out + "_shape"; + addIntArrayConstOp(block, shapeName, shapeVals); + auto* op = block->add_operations(); + op->set_type("reshape"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(in); + (*op->mutable_inputs())["shape"].add_arguments()->set_name(shapeName); + setShape(op, out, dims); + }; + auto transpose = [&](const std::string& in, const std::string& out, const std::vector& perm, + const std::vector& dims) { + std::string permName = out + "_perm"; + addIntArrayConstOp(block, permName, perm); + auto* op = block->add_operations(); + op->set_type("transpose"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(in); + (*op->mutable_inputs())["perm"].add_arguments()->set_name(permName); + setShape(op, out, dims); + }; + auto matmul = [&](const std::string& x, const std::string& y, const std::string& out, + const std::vector& dims) { + std::string txName = out + "_tx", tyName = out + "_ty"; + addBoolScalarConstOp(block, txName, false); + addBoolScalarConstOp(block, tyName, false); + auto* op = block->add_operations(); + op->set_type("matmul"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(x); + (*op->mutable_inputs())["y"].add_arguments()->set_name(y); + (*op->mutable_inputs())["transpose_x"].add_arguments()->set_name(txName); + (*op->mutable_inputs())["transpose_y"].add_arguments()->set_name(tyName); + setShape(op, out, dims); + }; + auto binary = [&](const std::string& type, const std::string& x, const std::string& y, + const std::string& out, const std::vector& dims) { + auto* op = block->add_operations(); + op->set_type(type); + (*op->mutable_inputs())["x"].add_arguments()->set_name(x); + (*op->mutable_inputs())["y"].add_arguments()->set_name(y); + setShape(op, out, dims); + }; + + std::string normed = addTransformerRMSNorm(block, input, desc.pre_ln, mask, prefix + "_ln"); + std::string nhwc = genVarName(prefix + "_nhwc"); + transpose(normed, nhwc, {0, 2, 3, 1}, {-1, H, W, C}); + std::string x2d = genVarName(prefix + "_x2d"); + reshape(nhwc, x2d, {-1, C}, {-1, C}); + + // FFN matmuls in FP32 (weights cast up at runtime, stored fp16) — KataGo's FP16 convention is + // spatial(convs)=FP16, non-spatial(matmuls)=FP32 (see openclbackend.cpp). The ANE accumulates + // FP16 matmuls in FP16, which loses too much precision over C/ffn; run them in FP32 instead. + std::string w1 = prefix + "_w1"; + addConstOp(block, w1, desc.linear1.weights, desc.linear1.getWeightShape()); + std::string wg = prefix + "_wg"; + addConstOp(block, wg, desc.linear_gate.weights, desc.linear_gate.getWeightShape()); + std::string w2 = prefix + "_w2"; + addConstOp(block, w2, desc.linear2.weights, desc.linear2.getWeightShape()); + + auto savedDtype = m_weight_dtype; + std::string mx2d = x2d, mw1 = w1, mwg = wg, mw2 = w2; + if (m_nonspatial_fp32) { + mx2d = castFixed(block, x2d, "fp32", {-1, C}); + mw1 = castFixed(block, w1, "fp32", desc.linear1.getWeightShape()); + mwg = castFixed(block, wg, "fp32", desc.linear_gate.getWeightShape()); + mw2 = castFixed(block, w2, "fp32", desc.linear2.getWeightShape()); + m_weight_dtype = CoreML::Specification::MILSpec::DataType::FLOAT32; + } + std::string a = genVarName(prefix + "_a"); + matmul(mx2d, mw1, a, {-1, ffn}); + std::string g = genVarName(prefix + "_g"); + matmul(mx2d, mwg, g, {-1, ffn}); + + std::string sig = genVarName(prefix + "_sig"); + { + auto* op = block->add_operations(); + op->set_type("sigmoid"); + (*op->mutable_inputs())["x"].add_arguments()->set_name(a); + setShape(op, sig, {-1, ffn}); + } + std::string siluA = genVarName(prefix + "_silu"); + binary("mul", a, sig, siluA, {-1, ffn}); + std::string h = genVarName(prefix + "_h"); + binary("mul", siluA, g, h, {-1, ffn}); + + std::string oCore = genVarName(prefix + "_o"); + matmul(h, mw2, oCore, {-1, C}); + std::string o = oCore; + if (m_nonspatial_fp32) { + m_weight_dtype = savedDtype; + o = castFixed(block, oCore, "fp16", {-1, C}); + } + + std::string oNHWC = genVarName(prefix + "_onhwc"); + reshape(o, oNHWC, {-1, H, W, C}, {-1, H, W, C}); + std::string oNCHW = genVarName(prefix + "_onchw"); + transpose(oNHWC, oNCHW, {0, 3, 1, 2}, {-1, C, H, W}); + std::string maskedOut = genVarName(prefix + "_masked"); + binary("mul", oNCHW, mask, maskedOut, {-1, C, H, W}); + std::string out = genVarName(prefix + "_out"); + binary("add", input, maskedOut, out, {-1, C, H, W}); + return out; +} + std::string MILBuilder::buildTrunk(CoreML::Specification::MILSpec::Block* block, const std::string& spatial_input, const std::string& global_input, @@ -1747,12 +2596,23 @@ std::string MILBuilder::buildTrunk(CoreML::Specification::MILSpec::Block* block, } else if (entry.block_kind == NESTED_BOTTLENECK_BLOCK_KIND) { const auto& block_desc = std::get(*entry.block); x = buildNestedBottleneckBlock(block, x, block_desc, mask, prefix); + } else if (entry.block_kind == TRANSFORMER_ATTENTION_BLOCK_KIND) { + const auto& block_desc = std::get(*entry.block); + x = buildTransformerAttentionBlock(block, x, block_desc, mask, prefix); + } else if (entry.block_kind == TRANSFORMER_FFN_BLOCK_KIND) { + const auto& block_desc = std::get(*entry.block); + x = buildTransformerFFNBlock(block, x, block_desc, mask, prefix); } } // Trunk tip - std::string trunk_out = genVarName("trunk_tip"); - addBatchNormActivationOps(block, x, trunk.trunk_tip_bn, trunk.trunk_tip_activation, mask, trunk_out); + std::string trunk_out; + if (trunk.trunk_norm_kind == TRUNK_NORM_KIND_STANDARD) { + trunk_out = genVarName("trunk_tip"); + addBatchNormActivationOps(block, x, trunk.trunk_tip_bn, trunk.trunk_tip_activation, mask, trunk_out); + } else { + trunk_out = addTrunkRMSNorm(block, x, trunk.trunk_tip_rms_norm, trunk.trunk_tip_activation, mask, "trunk_tip_rms"); + } return trunk_out; } @@ -1814,9 +2674,11 @@ std::string MILBuilder::buildGlobalPoolingResidualBlock(CoreML::Specification::M std::string gpool_bn_out = genVarName(prefix + "_gpool_bn"); addBatchNormActivationOps(block, gpool_conv_out, block_desc.gpool_bn, block_desc.gpool_activation, mask, gpool_bn_out); - // Global pooling + // Global pooling (FP32 when m_nonspatial_fp32 -- see addGlobalPoolingFp32). Feeds a bias back + // into the whole trunk, so the FP16 spatial sum must not lose precision for wide trunks. std::string gpool_features = genVarName(prefix + "_gpool_features"); - addGlobalPoolingOps(block, gpool_bn_out, mask, block_desc.gpool_conv.out_channels, gpool_features); + addGlobalPoolingFp32(block, gpool_bn_out, mask, block_desc.gpool_conv.out_channels, gpool_features, + /*valueVariant=*/false); // Project to bias std::string gpool_bias = genVarName(prefix + "_gpool_bias"); @@ -1898,6 +2760,12 @@ std::string MILBuilder::buildNestedBottleneckBlock(CoreML::Specification::MILSpe } else if (entry.block_kind == GLOBAL_POOLING_BLOCK_KIND) { const auto& nested = std::get(*entry.block); x = buildGlobalPoolingResidualBlock(block, x, nested, mask, nested_prefix); + } else if (entry.block_kind == TRANSFORMER_ATTENTION_BLOCK_KIND) { + const auto& nested = std::get(*entry.block); + x = buildTransformerAttentionBlock(block, x, nested, mask, nested_prefix); + } else if (entry.block_kind == TRANSFORMER_FFN_BLOCK_KIND) { + const auto& nested = std::get(*entry.block); + x = buildTransformerFFNBlock(block, x, nested, mask, nested_prefix); } } @@ -1942,9 +2810,9 @@ void MILBuilder::buildPolicyHead(CoreML::Specification::MILSpec::Block* block, std::string g1 = genVarName("policy_g1"); addBatchNormActivationOps(block, g1_conv, ph.g1_bn, ph.g1_activation, mask, g1); - // Global pooling on G1 + // Global pooling on G1 (FP32 when m_nonspatial_fp32; feeds the policy bias / policyKLDiv). std::string g1_pooled = genVarName("policy_g1_pool"); - addGlobalPoolingOps(block, g1, mask, ph.g1_conv.out_channels, g1_pooled); + addGlobalPoolingFp32(block, g1, mask, ph.g1_conv.out_channels, g1_pooled, /*valueVariant=*/false); // Project to spatial bias std::string gpool_bias = genVarName("policy_gpool_bias"); @@ -2002,6 +2870,8 @@ void MILBuilder::buildPolicyHead(CoreML::Specification::MILSpec::Block* block, setTensorOutput2D(op, pass_activated, ph.gpool_to_pass_mul.out_channels); } else if (ph.pass_activation->activation_type == ActivationType::Mish) { addMishOps(block, pass_biased, pass_activated, 2, ph.gpool_to_pass_mul.out_channels); + } else if (ph.pass_activation->activation_type == ActivationType::Silu) { + addSiluOps(block, pass_biased, pass_activated, 2, ph.gpool_to_pass_mul.out_channels); } else { pass_activated = pass_biased; } @@ -2032,9 +2902,9 @@ void MILBuilder::buildValueHead(CoreML::Specification::MILSpec::Block* block, std::string v1 = genVarName("value_v1"); addBatchNormActivationOps(block, v1_conv, vh.v1_bn, vh.v1_activation, mask, v1); - // Global pooling (value head version) + // Global pooling (value head version; FP32 when m_nonspatial_fp32). std::string v1_pooled = genVarName("value_v1_pool"); - addGlobalPoolingValueOps(block, v1, mask, vh.v1_conv.out_channels, v1_pooled); + addGlobalPoolingFp32(block, v1, mask, vh.v1_conv.out_channels, v1_pooled, /*valueVariant=*/true); // V2: linear + activation (fused matmul+bias -> linear) std::string v2_bias = genVarName("value_v2_bias"); @@ -2049,6 +2919,8 @@ void MILBuilder::buildValueHead(CoreML::Specification::MILSpec::Block* block, setTensorOutput2D(op, v2, vh.v2_mul.out_channels); } else if (vh.v2_activation.activation_type == ActivationType::Mish) { addMishOps(block, v2_bias, v2, 2, vh.v2_mul.out_channels); + } else if (vh.v2_activation.activation_type == ActivationType::Silu) { + addSiluOps(block, v2_bias, v2, 2, vh.v2_mul.out_channels); } else { v2 = v2_bias; } @@ -2085,6 +2957,8 @@ std::string MILBuilder::buildSGFMetadataEncoder(CoreML::Specification::MILSpec:: setTensorOutput2D(op, act1, encoder.mul1.out_channels); } else if (encoder.act1.activation_type == ActivationType::Mish) { addMishOps(block, bias1, act1, 2, encoder.mul1.out_channels); + } else if (encoder.act1.activation_type == ActivationType::Silu) { + addSiluOps(block, bias1, act1, 2, encoder.mul1.out_channels); } else { // Identity activation - create identity op to preserve type information auto* op = block->add_operations(); @@ -2107,6 +2981,8 @@ std::string MILBuilder::buildSGFMetadataEncoder(CoreML::Specification::MILSpec:: setTensorOutput2D(op, act2, encoder.mul2.out_channels); } else if (encoder.act2.activation_type == ActivationType::Mish) { addMishOps(block, bias2, act2, 2, encoder.mul2.out_channels); + } else if (encoder.act2.activation_type == ActivationType::Silu) { + addSiluOps(block, bias2, act2, 2, encoder.mul2.out_channels); } else { // Identity activation - create identity op to preserve type information auto* op = block->add_operations(); diff --git a/cpp/external/katagocoreml/src/builder/MILBuilder.hpp b/cpp/external/katagocoreml/src/builder/MILBuilder.hpp index 042f9fc16..e38afb05e 100644 --- a/cpp/external/katagocoreml/src/builder/MILBuilder.hpp +++ b/cpp/external/katagocoreml/src/builder/MILBuilder.hpp @@ -43,6 +43,16 @@ class MILBuilder { bool m_optimize_identity_mask; bool m_use_fp16; bool m_use_fp16_io; + // FP32-in-FP16-mode escalations all run off the FP16-only ANE, so they apply ONLY to transformer + // trunks (attention widens activation range, overflowing FP16 conv/matmul/pooling accumulation). + // Plain convnets run pure FP16 on the ANE -- the long-standing pre-tier path, verified to pass + // testgpuerror (b18c384nbt) and ~2.3x faster than forcing their per-block global pooling to FP32. + // For transformers: narrow trunks (<256) build fully FP32; wider ones use non-spatial FP32 (matmuls + + // pooling) plus, for very wide trunks (>=320), conv FP32. RMSNorm reductions: FP32 when m_use_fp16. + static constexpr int CONV_FP32_MIN_TRUNK_CHANNELS = 320; // transformer convs run FP32 at/above this width + static constexpr int FULL_FP32_MAX_TRUNK_CHANNELS = 256; // transformer trunks below this build fully FP32 + bool m_nonspatial_fp32 = false; // = m_use_fp16 && hasTransformer (matmuls + global pooling) + bool m_conv_fp32 = false; // = m_use_fp16 && hasTransformer && trunk_channels >= CONV_FP32_MIN_... int m_min_batch_size; int m_max_batch_size; CoreML::Specification::MILSpec::DataType m_weight_dtype; @@ -102,6 +112,23 @@ class MILBuilder { const std::string& dtype, const std::vector& shape); + // Cast to a tensor with FULLY-specified dims (no forced batch dim like addCastOp). Use for + // weight tensors (fixed [in,out] dims) when running an otherwise-FP16 op in FP32. Returns the + // new tensor name. dims use -1 for an unknown/batch dim, >=0 for a constant dim. + std::string castFixed(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const std::string& dtype, + const std::vector& dims); + + // Emit global pooling, running it in FP32 when m_nonspatial_fp32 (cast input/mask up, pool, + // cast the pooled features back to FP16). valueVariant selects the value-head pooling variant. + void addGlobalPoolingFp32(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const std::string& mask, + int channels, + const std::string& output, + bool valueVariant); + void addConvOp(CoreML::Specification::MILSpec::Block* block, const std::string& input, const ConvLayerDesc& layer, @@ -120,6 +147,44 @@ class MILBuilder { int rank, int channels); + void addSiluOps(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const std::string& output, + int rank, + int channels); + + // Generic output-shape setter: dims with -1 entries become unknown/dynamic dimensions. + void setShape(CoreML::Specification::MILSpec::Operation* op, + const std::string& name, + const std::vector& dims); + + // Lightweight transformer RMSNorm (weight only, per-position over channels). NCHW in/out. + std::string addTransformerRMSNorm(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const TransformerRMSNormDesc& desc, + const std::string& mask, + const std::string& prefix); + + // Full RMSNorm at trunk tip: gamma/beta, spatial or per-position, fused activation. NCHW in/out. + std::string addTrunkRMSNorm(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const RMSNormLayerDesc& desc, + const ActivationLayerDesc& act, + const std::string& mask, + const std::string& prefix); + + std::string buildTransformerAttentionBlock(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const TransformerAttentionBlockDesc& block_desc, + const std::string& mask, + const std::string& prefix); + + std::string buildTransformerFFNBlock(CoreML::Specification::MILSpec::Block* block, + const std::string& input, + const TransformerFFNBlockDesc& block_desc, + const std::string& mask, + const std::string& prefix); + void addGlobalPoolingOps(CoreML::Specification::MILSpec::Block* block, const std::string& input, const std::string& mask, diff --git a/cpp/external/katagocoreml/src/builder/Operations.cpp b/cpp/external/katagocoreml/src/builder/Operations.cpp index c0c036292..1c625acdd 100644 --- a/cpp/external/katagocoreml/src/builder/Operations.cpp +++ b/cpp/external/katagocoreml/src/builder/Operations.cpp @@ -14,12 +14,14 @@ KataGoOps::KataGoOps(int board_x_size, int board_y_size, bool optimize_identity_ std::string KataGoOps::registerWeight(const std::string& name, const std::vector& data, - const std::vector& shape) { + const std::vector& shape, + bool is_fp32) { WeightEntry entry; entry.name = name; entry.data = data; entry.shape = shape; entry.blob_offset = 0; // Will be set during serialization + entry.is_fp32 = is_fp32; m_weights.push_back(std::move(entry)); return name; } diff --git a/cpp/external/katagocoreml/src/builder/Operations.hpp b/cpp/external/katagocoreml/src/builder/Operations.hpp index 3fc72ad88..a9d2a1466 100644 --- a/cpp/external/katagocoreml/src/builder/Operations.hpp +++ b/cpp/external/katagocoreml/src/builder/Operations.hpp @@ -16,6 +16,8 @@ struct WeightEntry { std::vector data; std::vector shape; uint64_t blob_offset = 0; // Set during serialization + bool is_fp32 = false; // Store as FP32 (set when the const was declared FP32, e.g. inside an + // FP32 sub-region of an otherwise-FP16 model). Else stored per global mode. }; /// Precomputed constants for identity mask optimization @@ -51,10 +53,11 @@ class KataGoOps { /// Get precomputed mask constants const MaskConstants& getMaskConstants() const { return m_mask_constants; } - /// Register a weight tensor and return its reference name + /// Register a weight tensor and return its reference name. is_fp32 marks it for FP32 storage. std::string registerWeight(const std::string& name, const std::vector& data, - const std::vector& shape); + const std::vector& shape, + bool is_fp32 = false); /// Get all registered weights const std::vector& getWeights() const { return m_weights; } diff --git a/cpp/external/katagocoreml/src/parser/KataGoParser.cpp b/cpp/external/katagocoreml/src/parser/KataGoParser.cpp index 68f1a0e56..20d2dee36 100644 --- a/cpp/external/katagocoreml/src/parser/KataGoParser.cpp +++ b/cpp/external/katagocoreml/src/parser/KataGoParser.cpp @@ -315,6 +315,8 @@ ActivationLayerDesc KataGoParser::parseActivationLayer(int model_version) { layer.activation_type = ActivationType::ReLU; } else if (activation_str == "ACTIVATION_MISH") { layer.activation_type = ActivationType::Mish; + } else if (activation_str == "ACTIVATION_SILU") { + layer.activation_type = ActivationType::Silu; } else { throw std::runtime_error("Unknown activation type: " + activation_str); } @@ -420,6 +422,98 @@ static void checkBlockChannels(const std::string& block_name, const std::string& } } +TransformerRMSNormDesc KataGoParser::parseTransformerRMSNorm() { + TransformerRMSNormDesc layer; + layer.name = readString(); + layer.num_channels = readInt(); + layer.epsilon = readFloat(); + if (layer.num_channels < 1) { + throw std::runtime_error(layer.name + ": transformer rmsnorm numChannels must be >= 1"); + } + layer.weight = readFloats(layer.num_channels, layer.name + "/weight"); + return layer; +} + +RMSNormLayerDesc KataGoParser::parseRMSNormLayer() { + RMSNormLayerDesc layer; + layer.name = readString(); + layer.num_channels = readInt(); + layer.epsilon = readFloat(); + layer.spatial = (readInt() != 0); + layer.cgroup_size = readInt(); + if (layer.num_channels < 1) { + throw std::runtime_error(layer.name + ": rmsnorm numChannels must be >= 1"); + } + if (layer.cgroup_size != 0) { + throw std::runtime_error(layer.name + ": grouped spatial RMSNorm is not supported"); + } + layer.gamma = readFloats(layer.num_channels, layer.name + "/gamma"); + layer.beta = readFloats(layer.num_channels, layer.name + "/beta"); + return layer; +} + +TransformerAttentionBlockDesc KataGoParser::parseTransformerAttentionBlock(int model_version) { + TransformerAttentionBlockDesc block; + block.name = readString(); + block.num_heads = readInt(); + block.num_kv_heads = readInt(); + block.q_head_dim = readInt(); + block.v_head_dim = readInt(); + block.use_rope = (readInt() != 0); + block.learnable_rope = (readInt() != 0); + + if (block.num_heads < 1 || block.num_kv_heads < 1 || (block.num_heads % block.num_kv_heads != 0)) { + throw std::runtime_error(block.name + ": invalid numHeads/numKVHeads"); + } + if (block.use_rope && (block.q_head_dim % 2 != 0)) { + throw std::runtime_error(block.name + ": qHeadDim must be even when RoPE is used"); + } + + block.pre_ln = parseTransformerRMSNorm(); + block.q_proj = parseMatMulLayer(); + block.k_proj = parseMatMulLayer(); + block.v_proj = parseMatMulLayer(); + block.out_proj = parseMatMulLayer(); + + if (block.use_rope) { + if (block.learnable_rope) { + readString(); // ropeFreqs name + block.rope_num_kv_heads = readInt(); + block.rope_num_pairs = readInt(); + int rope_dim2 = readInt(); + if (block.rope_num_kv_heads != block.num_kv_heads || + block.rope_num_pairs != block.q_head_dim / 2 || rope_dim2 != 2) { + throw std::runtime_error(block.name + ": invalid learnable rope header"); + } + block.rope_freqs = readFloats( + static_cast(block.rope_num_kv_heads) * block.rope_num_pairs * 2, + block.name + "/rope_freqs"); + } else { + readString(); // ropeTheta name + block.rope_theta = readFloat(); + } + } + return block; +} + +TransformerFFNBlockDesc KataGoParser::parseTransformerFFNBlock(int model_version) { + TransformerFFNBlockDesc block; + block.name = readString(); + block.num_channels = readInt(); + block.ffn_channels = readInt(); + block.use_swiglu = (readInt() != 0); + if (block.num_channels < 1 || block.ffn_channels < 1) { + throw std::runtime_error(block.name + ": transformer ffn channels must be positive"); + } + block.pre_ln = parseTransformerRMSNorm(); + block.linear1 = parseMatMulLayer(); + if (block.use_swiglu) { + block.linear_gate = parseMatMulLayer(); + } + block.linear2 = parseMatMulLayer(); + return block; +} + std::vector KataGoParser::parseBlockStack(int model_version, int num_blocks, int trunk_num_channels) { std::vector blocks; blocks.reserve(num_blocks); @@ -449,6 +543,14 @@ std::vector KataGoParser::parseBlockStack(int model_version, int num desc.pre_bn.num_channels, desc.post_conv.out_channels, trunk_num_channels); entry.block = std::make_shared(std::move(desc)); + } else if (block_kind_name == "transformer_attention_block") { + entry.block_kind = TRANSFORMER_ATTENTION_BLOCK_KIND; + auto desc = parseTransformerAttentionBlock(model_version); + entry.block = std::make_shared(std::move(desc)); + } else if (block_kind_name == "transformer_ffn_block") { + entry.block_kind = TRANSFORMER_FFN_BLOCK_KIND; + auto desc = parseTransformerFFNBlock(model_version); + entry.block = std::make_shared(std::move(desc)); } else { throw std::runtime_error("Unknown block kind: " + block_kind_name); } @@ -506,15 +608,15 @@ TrunkDesc KataGoParser::parseTrunk(int model_version, int meta_encoder_version) } // Version >= 15 writes the trunk norm kind followed by 5 unused int parameters. - // This CoreML parser only supports the standard trunk norm kind (0 = BatchNorm/BiasMask); - // RMSNorm (used by transformer/rmsnorm models, kind != 0) is not implemented here, so reject it - // defensively rather than silently parsing it as standard norm and producing wrong outputs. + // Unlike upstream's CoreML parser (which rejects any non-standard norm), this fork + // implements RMSNorm, so we capture the kind here instead of throwing. The 5 trailing + // ints are reserved and still expected to be zero. if (model_version >= 15) { - int trunk_norm_kind = readInt(); - if (trunk_norm_kind != 0) { - throw std::runtime_error(trunk.name + ": unsupported trunk norm kind " + - std::to_string(trunk_norm_kind) + - " (this CoreML parser only supports standard trunk norm, not RMSNorm)"); + trunk.trunk_norm_kind = readInt(); + if (trunk.trunk_norm_kind != TRUNK_NORM_KIND_STANDARD && + trunk.trunk_norm_kind != TRUNK_NORM_KIND_RMSNORM) { + throw std::runtime_error(trunk.name + ": unknown/unsupported trunk norm kind " + + std::to_string(trunk.trunk_norm_kind)); } for (int i = 0; i < 5; i++) { int unused = readInt(); @@ -561,14 +663,24 @@ TrunkDesc KataGoParser::parseTrunk(int model_version, int meta_encoder_version) // Parse residual blocks trunk.blocks = parseBlockStack(model_version, trunk.num_blocks, trunk.trunk_num_channels); - trunk.trunk_tip_bn = parseBatchNormLayer(); - trunk.trunk_tip_activation = parseActivationLayer(model_version); - if (trunk.trunk_tip_bn.num_channels != trunk.trunk_num_channels) { - throw std::runtime_error(trunk.name + ": trunkTipBN.numChannels (" + - std::to_string(trunk.trunk_tip_bn.num_channels) + - ") != trunkNumChannels (" + - std::to_string(trunk.trunk_num_channels) + ")"); + if (trunk.trunk_norm_kind == TRUNK_NORM_KIND_STANDARD) { + trunk.trunk_tip_bn = parseBatchNormLayer(); + if (trunk.trunk_tip_bn.num_channels != trunk.trunk_num_channels) { + throw std::runtime_error(trunk.name + ": trunkTipBN.numChannels (" + + std::to_string(trunk.trunk_tip_bn.num_channels) + + ") != trunkNumChannels (" + + std::to_string(trunk.trunk_num_channels) + ")"); + } + } else { + trunk.trunk_tip_rms_norm = parseRMSNormLayer(); + if (trunk.trunk_tip_rms_norm.num_channels != trunk.trunk_num_channels) { + throw std::runtime_error(trunk.name + ": trunkTipRMSNorm.numChannels (" + + std::to_string(trunk.trunk_tip_rms_norm.num_channels) + + ") != trunkNumChannels (" + + std::to_string(trunk.trunk_num_channels) + ")"); + } } + trunk.trunk_tip_activation = parseActivationLayer(model_version); return trunk; } diff --git a/cpp/external/katagocoreml/src/parser/KataGoParser.hpp b/cpp/external/katagocoreml/src/parser/KataGoParser.hpp index a7d9f161c..9a00523d1 100644 --- a/cpp/external/katagocoreml/src/parser/KataGoParser.hpp +++ b/cpp/external/katagocoreml/src/parser/KataGoParser.hpp @@ -50,11 +50,15 @@ class KataGoParser { ActivationLayerDesc parseActivationLayer(int model_version); MatMulLayerDesc parseMatMulLayer(); MatBiasLayerDesc parseMatBiasLayer(); + TransformerRMSNormDesc parseTransformerRMSNorm(); + RMSNormLayerDesc parseRMSNormLayer(); // Block parsing functions ResidualBlockDesc parseResidualBlock(int model_version); GlobalPoolingResidualBlockDesc parseGlobalPoolingResidualBlock(int model_version); NestedBottleneckResidualBlockDesc parseNestedBottleneckBlock(int model_version, int trunk_num_channels); + TransformerAttentionBlockDesc parseTransformerAttentionBlock(int model_version); + TransformerFFNBlockDesc parseTransformerFFNBlock(int model_version); std::vector parseBlockStack(int model_version, int num_blocks, int trunk_num_channels); // Component parsing functions diff --git a/cpp/external/katagocoreml/src/serializer/WeightSerializer.cpp b/cpp/external/katagocoreml/src/serializer/WeightSerializer.cpp index 2ac23a3da..e8fe861c8 100644 --- a/cpp/external/katagocoreml/src/serializer/WeightSerializer.cpp +++ b/cpp/external/katagocoreml/src/serializer/WeightSerializer.cpp @@ -15,7 +15,11 @@ size_t WeightSerializer::serialize(std::vector& weights, size_t total_bytes = 0; for (auto& entry : weights) { - if (use_fp16) { + // Per-weight precision: store FP16 only when the global mode is FP16 AND this weight was not + // declared FP32 (entry.is_fp32 marks consts inside an FP32 sub-region of an FP16 model), so + // stored bytes stay consistent with each const's declared dtype. + const bool store_fp16 = use_fp16 && !entry.is_fp32; + if (store_fp16) { // Convert FP32 weights to FP16 std::vector fp16_data(entry.data.size()); for (size_t i = 0; i < entry.data.size(); ++i) { diff --git a/cpp/external/katagocoreml/src/types/KataGoTypes.hpp b/cpp/external/katagocoreml/src/types/KataGoTypes.hpp index 147541a39..1074ad419 100644 --- a/cpp/external/katagocoreml/src/types/KataGoTypes.hpp +++ b/cpp/external/katagocoreml/src/types/KataGoTypes.hpp @@ -20,10 +20,15 @@ namespace katagocoreml { enum class ActivationType : int { Identity = 0, ReLU = 1, - Mish = 2 + Mish = 2, + Silu = 3 // MISH_SCALE8 = 12 is internal optimization, treated as Mish }; +/// Trunk normalization kind (matching KataGo's desc.h) +constexpr int TRUNK_NORM_KIND_STANDARD = 0; +constexpr int TRUNK_NORM_KIND_RMSNORM = 1; + // ============================================================================ // Block Kind Constants // ============================================================================ @@ -32,6 +37,8 @@ enum class ActivationType : int { constexpr int ORDINARY_BLOCK_KIND = 0; constexpr int GLOBAL_POOLING_BLOCK_KIND = 2; constexpr int NESTED_BOTTLENECK_BLOCK_KIND = 3; +constexpr int TRANSFORMER_ATTENTION_BLOCK_KIND = 4; +constexpr int TRANSFORMER_FFN_BLOCK_KIND = 5; // ============================================================================ // Layer Descriptors @@ -99,6 +106,25 @@ struct MatBiasLayerDesc { std::vector weights; // Shape: [num_channels] }; +/// Lightweight RMSNorm used inside transformer blocks (weight only, no bias). +struct TransformerRMSNormDesc { + std::string name; + int num_channels = 0; + float epsilon = 1e-6f; + std::vector weight; // Shape: [num_channels] +}; + +/// Full-featured RMSNorm (gamma/beta, spatial mode) used at the trunk tip. +struct RMSNormLayerDesc { + std::string name; + int num_channels = 0; + float epsilon = 1e-6f; + bool spatial = false; + int cgroup_size = 0; + std::vector gamma; // Shape: [num_channels] + std::vector beta; // Shape: [num_channels] +}; + // ============================================================================ // Block Descriptors // ============================================================================ @@ -107,12 +133,16 @@ struct MatBiasLayerDesc { struct ResidualBlockDesc; struct GlobalPoolingResidualBlockDesc; struct NestedBottleneckResidualBlockDesc; +struct TransformerAttentionBlockDesc; +struct TransformerFFNBlockDesc; /// Block descriptor variant using BlockDesc = std::variant< ResidualBlockDesc, GlobalPoolingResidualBlockDesc, - NestedBottleneckResidualBlockDesc + NestedBottleneckResidualBlockDesc, + TransformerAttentionBlockDesc, + TransformerFFNBlockDesc >; /// Block with its kind @@ -166,6 +196,38 @@ struct NestedBottleneckResidualBlockDesc { ConvLayerDesc post_conv; }; +/// Transformer self-attention block descriptor (pre-norm, multi-head, optional 2D RoPE, GQA). +struct TransformerAttentionBlockDesc { + std::string name; + int num_heads = 0; + int num_kv_heads = 0; + int q_head_dim = 0; + int v_head_dim = 0; + bool use_rope = false; + bool learnable_rope = false; + TransformerRMSNormDesc pre_ln; + MatMulLayerDesc q_proj; + MatMulLayerDesc k_proj; + MatMulLayerDesc v_proj; + MatMulLayerDesc out_proj; + int rope_num_kv_heads = 0; + int rope_num_pairs = 0; + std::vector rope_freqs; // learnable: (num_kv_heads, num_pairs, 2) flattened + float rope_theta = 0.0f; +}; + +/// Transformer feed-forward (SwiGLU) block descriptor. +struct TransformerFFNBlockDesc { + std::string name; + int num_channels = 0; + int ffn_channels = 0; + bool use_swiglu = false; + TransformerRMSNormDesc pre_ln; + MatMulLayerDesc linear1; + MatMulLayerDesc linear_gate; // only used when use_swiglu + MatMulLayerDesc linear2; +}; + // ============================================================================ // SGF Metadata Encoder (v15+) // ============================================================================ @@ -203,7 +265,9 @@ struct TrunkDesc { MatMulLayerDesc initial_matmul; std::optional sgf_metadata_encoder; std::vector blocks; + int trunk_norm_kind = TRUNK_NORM_KIND_STANDARD; BatchNormLayerDesc trunk_tip_bn; + RMSNormLayerDesc trunk_tip_rms_norm; ActivationLayerDesc trunk_tip_activation; }; diff --git a/cpp/neuralnet/metalbackend.cpp b/cpp/neuralnet/metalbackend.cpp index 786ef8290..10ee50b62 100644 --- a/cpp/neuralnet/metalbackend.cpp +++ b/cpp/neuralnet/metalbackend.cpp @@ -130,6 +130,8 @@ ActivationKind activationLayerDescToSwift(const ActivationLayerDesc* desc) { return ActivationKind::mish(); case ACTIVATION_MISH_SCALE8: return ActivationKind::identity(); // Metal/CoreML does not use scaled mish + case ACTIVATION_SILU: + return ActivationKind::silu(); case ACTIVATION_IDENTITY: return ActivationKind::identity(); default: @@ -217,6 +219,63 @@ SWNestedBottleneckResidualBlockDesc nestedBottleneckResidualBlockDescToSwift(con postConv); } +/// Convert a transformer RMSNorm description from C++ to Swift +SWTransformerRMSNormDesc transformerRMSNormDescToSwift(const TransformerRMSNormDesc* desc) { + return createSWTransformerRMSNormDesc( + desc->numChannels, + desc->epsilon, + (float*)desc->weight.data()); +} + +/// Convert a transformer attention block description from C++ to Swift +SWTransformerAttentionBlockDesc transformerAttentionBlockDescToSwift(const TransformerAttentionDesc* desc) { + SWTransformerRMSNormDesc preLN = transformerRMSNormDescToSwift(&desc->preLN); + SWMatMulLayerDesc qProj = matMulLayerDescToSwift(&desc->qProj); + SWMatMulLayerDesc kProj = matMulLayerDescToSwift(&desc->kProj); + SWMatMulLayerDesc vProj = matMulLayerDescToSwift(&desc->vProj); + SWMatMulLayerDesc outProj = matMulLayerDescToSwift(&desc->outProj); + float* ropeFreqs = desc->ropeFreqs.empty() ? nullptr : (float*)desc->ropeFreqs.data(); + + return createSWTransformerAttentionBlockDesc( + desc->numHeads, + desc->numKVHeads, + desc->qHeadDim, + desc->vHeadDim, + desc->useRope, + desc->learnableRope, + preLN, + qProj, + kProj, + vProj, + outProj, + desc->ropeNumKVHeads, + desc->ropeNumPairs, + ropeFreqs, + desc->ropeTheta); +} + +/// Convert a transformer FFN block description from C++ to Swift +SWTransformerFFNBlockDesc transformerFFNBlockDescToSwift(const TransformerFFNDesc* desc) { + // The Metal forward pass (metallayers.swift TransformerFFNBlock) only implements the SwiGLU path + // (SiLU(linear1) * gate); a non-SwiGLU model has no gate weights, so guard here as Eigen and CoreML + // do (eigenbackend.cpp / katagocoreml MILBuilder) instead of crashing on the empty gate descriptor. + if(!desc->useSwiGLU) + throw StringError(desc->name + ": non-SwiGLU transformer FFN not supported in Metal backend"); + SWTransformerRMSNormDesc preLN = transformerRMSNormDescToSwift(&desc->preLN); + SWMatMulLayerDesc linear1 = matMulLayerDescToSwift(&desc->linear1); + SWMatMulLayerDesc linearGate = matMulLayerDescToSwift(&desc->linearGate); + SWMatMulLayerDesc linear2 = matMulLayerDescToSwift(&desc->linear2); + + return createSWTransformerFFNBlockDesc( + desc->numChannels, + desc->ffnChannels, + desc->useSwiGLU, + preLN, + linear1, + linearGate, + linear2); +} + /// Convert residual blocks from C++ to Swift swift::Array residualBlocksToSwift(const vector>& blocks) { auto builder = createBlockDescriptorBuilder(); @@ -230,9 +289,12 @@ swift::Array residualBlocksToSwift(const vector sGFMetadataEncoderDescToSwift(const SG } /// Convert a trunk description from C++ to Swift +SWRMSNormLayerDesc rmsNormLayerDescToSwift(const RMSNormLayerDesc* desc) { + float* gamma = desc->gamma.empty() ? nullptr : (float*)desc->gamma.data(); + float* beta = desc->beta.empty() ? nullptr : (float*)desc->beta.data(); + return createSWRMSNormLayerDesc( + desc->numChannels, + desc->epsilon, + desc->spatial, + gamma, + beta); +} + SWTrunkDesc trunkDescToSwift(const TrunkDesc* trunk) { SWConvLayerDesc initialConv = convLayerDescToSwift(&trunk->initialConv); SWMatMulLayerDesc initialMatMul = matMulLayerDescToSwift(&trunk->initialMatMul); auto sgfMetadataEncoder = sGFMetadataEncoderDescToSwift(&trunk->sgfMetadataEncoder); auto swBlocks = residualBlocksToSwift(trunk->blocks); - if(trunk->trunkNormKind != TRUNK_NORM_KIND_STANDARD) - throw StringError("Trunk RMSNorm is not yet supported by the Metal backend"); SWBatchNormLayerDesc trunkTipBN = batchNormLayerDescToSwift(&trunk->trunkTipBN); + SWRMSNormLayerDesc trunkTipRMSNorm = rmsNormLayerDescToSwift(&trunk->trunkTipRMSNorm); ActivationKind trunkTipActivation = activationLayerDescToSwift(&trunk->trunkTipActivation); return createSWTrunkDesc( @@ -285,7 +357,9 @@ SWTrunkDesc trunkDescToSwift(const TrunkDesc* trunk) { initialMatMul, sgfMetadataEncoder, swBlocks, + trunk->trunkNormKind, trunkTipBN, + trunkTipRMSNorm, trunkTipActivation); } diff --git a/cpp/neuralnet/metallayers.swift b/cpp/neuralnet/metallayers.swift index bbd2255bc..e1324df96 100644 --- a/cpp/neuralnet/metallayers.swift +++ b/cpp/neuralnet/metallayers.swift @@ -76,6 +76,11 @@ extension MPSGraph { return mulTensor } + + /// SiLU / Swish activation: x * sigmoid(x). Numerically stable across FP16/FP32. + func silu(tensor: MPSGraphTensor) -> MPSGraphTensor { + return multiplication(tensor, sigmoid(with: tensor, name: nil), name: nil) + } } // MARK: - Input Shape Utilities @@ -358,6 +363,7 @@ public enum ActivationKind { case identity case relu case mish + case silu } /// A struct that represents a description of convolutional layer. @@ -487,6 +493,63 @@ public func createSWMatBiasLayerDesc( weights: weights) } +/// A lightweight RMSNorm description used inside transformer blocks (weight only, no bias). +public struct SWTransformerRMSNormDesc { + let numChannels: NSNumber + let epsilon: Float + let weight: UnsafeMutablePointer + + init(numChannels: NSNumber, epsilon: Float, weight: UnsafeMutablePointer) { + self.numChannels = numChannels + self.epsilon = epsilon + self.weight = weight + } +} + +public func createSWTransformerRMSNormDesc( + numChannels: Int32, + epsilon: Float, + weight: UnsafeMutablePointer +) -> SWTransformerRMSNormDesc { + return SWTransformerRMSNormDesc( + numChannels: numChannels as NSNumber, + epsilon: epsilon, + weight: weight) +} + +/// A full-featured RMSNorm description (gamma/beta, spatial mode), used at the trunk tip. +public struct SWRMSNormLayerDesc { + let numChannels: NSNumber + let epsilon: Float + let spatial: Bool + let gamma: UnsafeMutablePointer? + let beta: UnsafeMutablePointer? + + init(numChannels: NSNumber, epsilon: Float, spatial: Bool, + gamma: UnsafeMutablePointer?, beta: UnsafeMutablePointer?) { + self.numChannels = numChannels + self.epsilon = epsilon + self.spatial = spatial + self.gamma = gamma + self.beta = beta + } +} + +public func createSWRMSNormLayerDesc( + numChannels: Int32, + epsilon: Float, + spatial: Bool, + gamma: UnsafeMutablePointer?, + beta: UnsafeMutablePointer? +) -> SWRMSNormLayerDesc { + return SWRMSNormLayerDesc( + numChannels: numChannels as NSNumber, + epsilon: epsilon, + spatial: spatial, + gamma: gamma, + beta: beta) +} + // MARK: - Core Layers /// A class that represents a convolutional layer using MPSGraph @@ -612,6 +675,8 @@ struct ActivationLayer { resultTensor = graph.reLU(with: sourceTensor, name: nil) case .mish: resultTensor = graph.mish(tensor: sourceTensor) + case .silu: + resultTensor = graph.silu(tensor: sourceTensor) default: resultTensor = sourceTensor } @@ -987,6 +1052,140 @@ public func createSWNestedBottleneckResidualBlockDesc( postConv: postConv) } +public class SWTransformerAttentionBlockDesc: BlockDescriptor { + let numHeads: Int + let numKVHeads: Int + let qHeadDim: Int + let vHeadDim: Int + let useRope: Bool + let learnableRope: Bool + let preLN: SWTransformerRMSNormDesc + let qProj: SWMatMulLayerDesc + let kProj: SWMatMulLayerDesc + let vProj: SWMatMulLayerDesc + let outProj: SWMatMulLayerDesc + let ropeNumKVHeads: Int + let ropeNumPairs: Int + let ropeFreqs: UnsafeMutablePointer? // learnable: (numKVHeads, numPairs, 2) flattened + let ropeTheta: Float + + init( + numHeads: Int, + numKVHeads: Int, + qHeadDim: Int, + vHeadDim: Int, + useRope: Bool, + learnableRope: Bool, + preLN: SWTransformerRMSNormDesc, + qProj: SWMatMulLayerDesc, + kProj: SWMatMulLayerDesc, + vProj: SWMatMulLayerDesc, + outProj: SWMatMulLayerDesc, + ropeNumKVHeads: Int, + ropeNumPairs: Int, + ropeFreqs: UnsafeMutablePointer?, + ropeTheta: Float + ) { + self.numHeads = numHeads + self.numKVHeads = numKVHeads + self.qHeadDim = qHeadDim + self.vHeadDim = vHeadDim + self.useRope = useRope + self.learnableRope = learnableRope + self.preLN = preLN + self.qProj = qProj + self.kProj = kProj + self.vProj = vProj + self.outProj = outProj + self.ropeNumKVHeads = ropeNumKVHeads + self.ropeNumPairs = ropeNumPairs + self.ropeFreqs = ropeFreqs + self.ropeTheta = ropeTheta + } +} + +public func createSWTransformerAttentionBlockDesc( + numHeads: Int32, + numKVHeads: Int32, + qHeadDim: Int32, + vHeadDim: Int32, + useRope: Bool, + learnableRope: Bool, + preLN: SWTransformerRMSNormDesc, + qProj: SWMatMulLayerDesc, + kProj: SWMatMulLayerDesc, + vProj: SWMatMulLayerDesc, + outProj: SWMatMulLayerDesc, + ropeNumKVHeads: Int32, + ropeNumPairs: Int32, + ropeFreqs: UnsafeMutablePointer?, + ropeTheta: Float +) -> SWTransformerAttentionBlockDesc { + return SWTransformerAttentionBlockDesc( + numHeads: Int(numHeads), + numKVHeads: Int(numKVHeads), + qHeadDim: Int(qHeadDim), + vHeadDim: Int(vHeadDim), + useRope: useRope, + learnableRope: learnableRope, + preLN: preLN, + qProj: qProj, + kProj: kProj, + vProj: vProj, + outProj: outProj, + ropeNumKVHeads: Int(ropeNumKVHeads), + ropeNumPairs: Int(ropeNumPairs), + ropeFreqs: ropeFreqs, + ropeTheta: ropeTheta) +} + +public class SWTransformerFFNBlockDesc: BlockDescriptor { + let numChannels: Int + let ffnChannels: Int + let useSwiGLU: Bool + let preLN: SWTransformerRMSNormDesc + let linear1: SWMatMulLayerDesc + let linearGate: SWMatMulLayerDesc + let linear2: SWMatMulLayerDesc + + init( + numChannels: Int, + ffnChannels: Int, + useSwiGLU: Bool, + preLN: SWTransformerRMSNormDesc, + linear1: SWMatMulLayerDesc, + linearGate: SWMatMulLayerDesc, + linear2: SWMatMulLayerDesc + ) { + self.numChannels = numChannels + self.ffnChannels = ffnChannels + self.useSwiGLU = useSwiGLU + self.preLN = preLN + self.linear1 = linear1 + self.linearGate = linearGate + self.linear2 = linear2 + } +} + +public func createSWTransformerFFNBlockDesc( + numChannels: Int32, + ffnChannels: Int32, + useSwiGLU: Bool, + preLN: SWTransformerRMSNormDesc, + linear1: SWMatMulLayerDesc, + linearGate: SWMatMulLayerDesc, + linear2: SWMatMulLayerDesc +) -> SWTransformerFFNBlockDesc { + return SWTransformerFFNBlockDesc( + numChannels: Int(numChannels), + ffnChannels: Int(ffnChannels), + useSwiGLU: useSwiGLU, + preLN: preLN, + linear1: linear1, + linearGate: linearGate, + linear2: linear2) +} + public class BlockDescriptorBuilder { public var blockDescriptors: [BlockDescriptor] = [] @@ -1001,6 +1200,329 @@ public func createBlockDescriptorBuilder() -> BlockDescriptorBuilder { return BlockDescriptorBuilder() } +// MARK: - Transformer Layers + +/// Lightweight RMSNorm used inside transformer blocks (weight only, no bias). +/// Input/output are NCHW [B, C, H, W]. Normalizes across channels per spatial position, +/// scales by per-channel weight, and masks the output. +struct TransformerRMSNormLayer { + let resultTensor: MPSGraphTensor + + init( + graph: MPSGraph, + sourceTensor: MPSGraphTensor, + maskTensor: MPSGraphTensor, + descriptor: SWTransformerRMSNormDesc + ) { + let numChannels = descriptor.numChannels + let dataType = sourceTensor.dataType + + // meanSq over channel axis (1): [B,1,H,W] + let sq = graph.square(with: sourceTensor, name: nil) + let sumSq = graph.reductionSum(with: sq, axis: 1, name: nil) + let invC = graph.constant(1.0 / numChannels.doubleValue, dataType: dataType) + let meanSq = graph.multiplication(sumSq, invC, name: nil) + let epsTensor = graph.constant(Double(descriptor.epsilon), dataType: dataType) + let denom = graph.squareRoot(with: graph.addition(meanSq, epsTensor, name: nil), name: nil) + let normalized = graph.division(sourceTensor, denom, name: nil) + + // scale by per-channel weight [1, C, 1, 1] + let weightShape: [NSNumber] = [1, numChannels, 1, 1] + let weightData = Data(floatsNoCopy: descriptor.weight, shape: weightShape) + let weightTensor = graph.constant(weightData, shape: weightShape, dataType: dataType) + let scaled = graph.multiplication(normalized, weightTensor, name: nil) + + resultTensor = graph.multiplication(scaled, maskTensor, name: nil) + } +} + +/// Full-featured RMSNorm for the trunk tip: gamma/beta, spatial or per-position mode, and a +/// fused activation. Input/output are NCHW [B, C, H, W]. Mirrors the Eigen RMSNormLayer. +struct TrunkRMSNormLayer { + let resultTensor: MPSGraphTensor + + init( + graph: MPSGraph, + sourceTensor: MPSGraphTensor, + maskTensor: MPSGraphTensor, + descriptor: SWRMSNormLayerDesc, + activationKind: ActivationKind + ) { + let dataType = sourceTensor.dataType + let numChannels = descriptor.numChannels + + // Zero invalid positions before accumulating sum of squares. + let masked = graph.multiplication(sourceTensor, maskTensor, name: nil) + let sq = graph.square(with: masked, name: nil) + + let meanSq: MPSGraphTensor + if descriptor.spatial { + // Normalize over channels AND valid spatial positions per batch element. + let sumSq = graph.reductionSum(with: sq, axes: [1, 2, 3], name: nil) // [B,1,1,1] + let count = graph.reductionSum(with: maskTensor, axes: [1, 2, 3], name: nil) // valid positions + let cTensor = graph.constant(numChannels.doubleValue, dataType: dataType) + let totalElts = graph.multiplication(count, cTensor, name: nil) + meanSq = graph.division(sumSq, totalElts, name: nil) + } else { + // Per-position normalization across channels. + let sumSq = graph.reductionSum(with: sq, axes: [1], name: nil) // [B,1,H,W] + let invC = graph.constant(1.0 / numChannels.doubleValue, dataType: dataType) + meanSq = graph.multiplication(sumSq, invC, name: nil) + } + + let epsTensor = graph.constant(Double(descriptor.epsilon), dataType: dataType) + let denom = graph.squareRoot(with: graph.addition(meanSq, epsTensor, name: nil), name: nil) + let normalized = graph.division(sourceTensor, denom, name: nil) + + let gammaShape: [NSNumber] = [1, numChannels, 1, 1] + let gammaTensor = graph.constant(Data(floatsNoCopy: descriptor.gamma!, shape: gammaShape), shape: gammaShape, dataType: dataType) + let betaTensor = graph.constant(Data(floatsNoCopy: descriptor.beta!, shape: gammaShape), shape: gammaShape, dataType: dataType) + let scaled = graph.addition(graph.multiplication(normalized, gammaTensor, name: nil), betaTensor, name: nil) + + let activated = ActivationLayer(graph: graph, sourceTensor: scaled, activationKind: activationKind).resultTensor + resultTensor = graph.multiplication(activated, maskTensor, name: nil) + } +} + +/// A transformer self-attention block (pre-norm, multi-head, optional 2D RoPE, GQA). +/// Mirrors the Eigen reference: RMSNorm -> Q/K/V projections -> RoPE -> scaled dot-product +/// attention with masked softmax -> output projection -> masked residual. +/// Tensors are NCHW [B, C, H, W]; spatial positions (H*W, ordered y*W+x) are the sequence. +struct TransformerAttentionBlock { + let resultTensor: MPSGraphTensor + + init( + graph: MPSGraph, + sourceTensor: MPSGraphTensor, + maskTensor: MPSGraphTensor, + descriptor: SWTransformerAttentionBlockDesc, + nnXLen: NSNumber, + nnYLen: NSNumber + ) { + let dataType = sourceTensor.dataType + let numHeads = descriptor.numHeads + let numKVHeads = descriptor.numKVHeads + let qHeadDim = descriptor.qHeadDim + let vHeadDim = descriptor.vHeadDim + let nnX = nnXLen.intValue + let nnY = nnYLen.intValue + let seq = nnX * nnY + + // 1. RMSNorm (NCHW) + let normed = TransformerRMSNormLayer( + graph: graph, + sourceTensor: sourceTensor, + maskTensor: maskTensor, + descriptor: descriptor.preLN).resultTensor + + // To NHWC [B,H,W,C] so that reshape [-1, C] groups channels per position. + let normedNHWC = graph.transpose(normed, permutation: [0, 2, 3, 1], name: nil) + + // 2. Q/K/V projections via matmul over channels -> [B*seq, heads*dim] + let q = MatMulLayer(graph: graph, descriptor: descriptor.qProj, sourceTensor: normedNHWC).resultTensor + let k = MatMulLayer(graph: graph, descriptor: descriptor.kProj, sourceTensor: normedNHWC).resultTensor + let v = MatMulLayer(graph: graph, descriptor: descriptor.vProj, sourceTensor: normedNHWC).resultTensor + + // 3. reshape to [B, heads, seq, dim] + var qh = TransformerAttentionBlock.toHeads(graph, q, seq: seq, numHeads: numHeads, headDim: qHeadDim) + var kh = TransformerAttentionBlock.toHeads(graph, k, seq: seq, numHeads: numKVHeads, headDim: qHeadDim) + let vh = TransformerAttentionBlock.toHeads(graph, v, seq: seq, numHeads: numKVHeads, headDim: vHeadDim) + + // 4. RoPE on Q and K + if descriptor.useRope { + let numPairs = qHeadDim / 2 + // Q heads map to KV heads via kvh = h * numKVHeads / numHeads (matches Eigen). + let (qCos, qSin) = TransformerAttentionBlock.makeRopeTables( + graph, descriptor: descriptor, nHeads: numHeads, seq: seq, numPairs: numPairs, + nnX: nnX, nnY: nnY, qHeadDim: qHeadDim, dataType: dataType, + kvIndexForHead: { h in (h * numKVHeads) / numHeads }) + let (kCos, kSin) = TransformerAttentionBlock.makeRopeTables( + graph, descriptor: descriptor, nHeads: numKVHeads, seq: seq, numPairs: numPairs, + nnX: nnX, nnY: nnY, qHeadDim: qHeadDim, dataType: dataType, + kvIndexForHead: { h in h }) + qh = TransformerAttentionBlock.applyRope(graph, qh, cosT: qCos, sinT: qSin, + numHeads: numHeads, seq: seq, numPairs: numPairs) + kh = TransformerAttentionBlock.applyRope(graph, kh, cosT: kCos, sinT: kSin, + numHeads: numKVHeads, seq: seq, numPairs: numPairs) + } + + // GQA: if numKVHeads < numHeads, repeat KV heads so they align with query heads. + var khExp = kh + var vhExp = vh + if numKVHeads != numHeads { + let groupSize = numHeads / numKVHeads + khExp = TransformerAttentionBlock.repeatKVHeads(graph, kh, numKVHeads: numKVHeads, groupSize: groupSize, seq: seq, headDim: qHeadDim) + vhExp = TransformerAttentionBlock.repeatKVHeads(graph, vh, numKVHeads: numKVHeads, groupSize: groupSize, seq: seq, headDim: vHeadDim) + } + + // 5. scores = scale * Q @ K^T -> [B, heads, seq, seq] + let khT = graph.transpose(khExp, permutation: [0, 1, 3, 2], name: nil) + var scores = graph.matrixMultiplication(primary: qh, secondary: khT, name: nil) + let scale = graph.constant(1.0 / Double(qHeadDim).squareRoot(), dataType: dataType) + scores = graph.multiplication(scores, scale, name: nil) + + // Mask keys: add (maskKey - 1) * BIG so masked key columns get ~ -inf before softmax. + // maskTensor [B,1,H,W] -> [B,1,1,seq] + let maskNHWC = graph.transpose(maskTensor, permutation: [0, 2, 3, 1], name: nil) // [B,H,W,1] + let maskSeq = graph.reshape(maskNHWC, shape: [-1, 1, 1, seq as NSNumber], name: nil) + let one = graph.constant(1.0, dataType: dataType) + let big = graph.constant(1.0e9, dataType: dataType) + let keyBias = graph.multiplication(graph.subtraction(maskSeq, one, name: nil), big, name: nil) + scores = graph.addition(scores, keyBias, name: nil) + + // 6. softmax over key axis (last) + let attn = graph.softMax(with: scores, axis: 3, name: nil) + + // 7. out = attn @ V -> [B, heads, seq, vHeadDim] + let attnOut = graph.matrixMultiplication(primary: attn, secondary: vhExp, name: nil) + + // 8. back to [B*seq, heads*vHeadDim] + let outHeadsLast = graph.transpose(attnOut, permutation: [0, 2, 1, 3], name: nil) // [B,seq,heads,vHeadDim] + let outFlat = graph.reshape(outHeadsLast, shape: [-1, (numHeads * vHeadDim) as NSNumber], name: nil) + + // 9. output projection -> [B*seq, C] + let proj = MatMulLayer(graph: graph, descriptor: descriptor.outProj, sourceTensor: outFlat).resultTensor + + // 10. reshape to NHWC then NCHW + let outChannels = descriptor.outProj.outChannels + let projNHWC = graph.reshape(proj, shape: [-1, nnYLen, nnXLen, outChannels], name: nil) + let projNCHW = graph.transpose(projNHWC, permutation: [0, 3, 1, 2], name: nil) + + // 11. masked residual + let masked = graph.multiplication(projNCHW, maskTensor, name: nil) + resultTensor = graph.addition(sourceTensor, masked, name: nil) + } + + /// Reshape [B*seq, numHeads*headDim] -> [B, numHeads, seq, headDim]. + static func toHeads(_ graph: MPSGraph, _ x: MPSGraphTensor, seq: Int, numHeads: Int, headDim: Int) -> MPSGraphTensor { + let reshaped = graph.reshape(x, shape: [-1, seq as NSNumber, numHeads as NSNumber, headDim as NSNumber], name: nil) + return graph.transpose(reshaped, permutation: [0, 2, 1, 3], name: nil) + } + + /// Repeat each KV head groupSize times along the head axis: [B,numKVHeads,seq,dim] -> [B,numKVHeads*groupSize,seq,dim]. + static func repeatKVHeads(_ graph: MPSGraph, _ x: MPSGraphTensor, numKVHeads: Int, groupSize: Int, seq: Int, headDim: Int) -> MPSGraphTensor { + // Repeat each KV head groupSize times consecutively so query head h uses kv = h / groupSize, + // matching the Eigen reference (kvh = h / kvGroupSize). We slice each KV head and concat the + // copies along the head axis. Note: MPSGraph.broadcast(_:shape:) does NOT infer -1, so a + // reshape+broadcast approach with a dynamic batch dim triggers an NDArray INT_MAX assertion; + // slice+concat is shape-safe with no -1 broadcast. + var heads: [MPSGraphTensor] = [] + heads.reserveCapacity(numKVHeads * groupSize) + for kv in 0.. MPSGraphTensor { + let pairsShape: [NSNumber] = [-1, numHeads as NSNumber, seq as NSNumber, numPairs as NSNumber, 2] + let xPairs = graph.reshape(x, shape: pairsShape, name: nil) + let evenShape: [NSNumber] = [-1, numHeads as NSNumber, seq as NSNumber, numPairs as NSNumber] + let xEven = graph.reshape(graph.sliceTensor(xPairs, dimension: 4, start: 0, length: 1, name: nil), shape: evenShape, name: nil) + let xOdd = graph.reshape(graph.sliceTensor(xPairs, dimension: 4, start: 1, length: 1, name: nil), shape: evenShape, name: nil) + let outEven = graph.subtraction(graph.multiplication(xEven, cosT, name: nil), graph.multiplication(xOdd, sinT, name: nil), name: nil) + let outOdd = graph.addition(graph.multiplication(xEven, sinT, name: nil), graph.multiplication(xOdd, cosT, name: nil), name: nil) + let pairShape5: [NSNumber] = [-1, numHeads as NSNumber, seq as NSNumber, numPairs as NSNumber, 1] + let outEvenE = graph.reshape(outEven, shape: pairShape5, name: nil) + let outOddE = graph.reshape(outOdd, shape: pairShape5, name: nil) + let stacked = graph.concatTensors([outEvenE, outOddE], dimension: 4, name: nil) + return graph.reshape(stacked, shape: [-1, numHeads as NSNumber, seq as NSNumber, (numPairs * 2) as NSNumber], name: nil) + } + + /// Build RoPE cos/sin constant tensors of shape [1, nHeads, seq, numPairs]. + static func makeRopeTables(_ graph: MPSGraph, descriptor: SWTransformerAttentionBlockDesc, + nHeads: Int, seq: Int, numPairs: Int, nnX: Int, nnY: Int, qHeadDim: Int, + dataType: MPSDataType, kvIndexForHead: (Int) -> Int) -> (MPSGraphTensor, MPSGraphTensor) { + let count = nHeads * seq * numPairs + // Managed arrays (freed on return). Unlike the weight constants elsewhere, which point at + // C++-owned descriptor memory and so use floatsNoCopy, these tables have no persistent owner; + // we copy them into the Data below so MPSGraph owns the bytes (avoids a leak / use-after-free). + var cosBuf = [Float32](repeating: 0, count: count) + var sinBuf = [Float32](repeating: 0, count: count) + let numPairsPerDim = numPairs / 2 + let dimHalf = qHeadDim / 2 + for h in 0.. SiLU(linear1)*gate -> linear2 -> masked residual. +struct TransformerFFNBlock { + let resultTensor: MPSGraphTensor + + init( + graph: MPSGraph, + sourceTensor: MPSGraphTensor, + maskTensor: MPSGraphTensor, + descriptor: SWTransformerFFNBlockDesc, + nnXLen: NSNumber, + nnYLen: NSNumber + ) { + let numChannels = descriptor.numChannels + + // 1. RMSNorm + let normed = TransformerRMSNormLayer( + graph: graph, + sourceTensor: sourceTensor, + maskTensor: maskTensor, + descriptor: descriptor.preLN).resultTensor + let normedNHWC = graph.transpose(normed, permutation: [0, 2, 3, 1], name: nil) + + // 2. linear1 + gate, both [B*seq, ffnChannels] + let a = MatMulLayer(graph: graph, descriptor: descriptor.linear1, sourceTensor: normedNHWC).resultTensor + let gate = MatMulLayer(graph: graph, descriptor: descriptor.linearGate, sourceTensor: normedNHWC).resultTensor + + // 3. SwiGLU: SiLU(a) * gate, SiLU(a) = a * sigmoid(a) + let siluA = graph.multiplication(a, graph.sigmoid(with: a, name: nil), name: nil) + let h = graph.multiplication(siluA, gate, name: nil) + + // 4. linear2 -> [B*seq, numChannels] + let out = MatMulLayer(graph: graph, descriptor: descriptor.linear2, sourceTensor: h).resultTensor + + // 5. reshape to NHWC then NCHW, masked residual + let outNHWC = graph.reshape(out, shape: [-1, nnYLen, nnXLen, numChannels as NSNumber], name: nil) + let outNCHW = graph.transpose(outNHWC, permutation: [0, 3, 1, 2], name: nil) + let masked = graph.multiplication(outNCHW, maskTensor, name: nil) + resultTensor = graph.addition(sourceTensor, masked, name: nil) + } +} + // MARK: - Block Implementations /// A class that represents a Residual Block layer @@ -1241,6 +1763,26 @@ struct BlockStack { optimizeIdentityMask: optimizeIdentityMask) blockInput = ordinary.resultTensor + case let attnDescriptor as SWTransformerAttentionBlockDesc: + let attn = TransformerAttentionBlock( + graph: graph, + sourceTensor: sourceTensor, + maskTensor: maskTensor, + descriptor: attnDescriptor, + nnXLen: nnXLen, + nnYLen: nnYLen) + + blockInput = attn.resultTensor + case let ffnDescriptor as SWTransformerFFNBlockDesc: + let ffn = TransformerFFNBlock( + graph: graph, + sourceTensor: sourceTensor, + maskTensor: maskTensor, + descriptor: ffnDescriptor, + nnXLen: nnXLen, + nnYLen: nnYLen) + + blockInput = ffn.resultTensor default: blockInput = sourceTensor } @@ -1472,6 +2014,10 @@ class SGFMetadataEncoder { // MARK: - Trunk +/// Trunk-tip normalization kind, mirroring desc.h TRUNK_NORM_KIND_* (the value is serialized in the model). +let TRUNK_NORM_KIND_STANDARD = 0 // BatchNorm or BiasMask (existing) +let TRUNK_NORM_KIND_RMSNORM = 1 // RMSNorm + /// A class that describes a trunk for a neural network public class SWTrunkDesc { let version: Int @@ -1483,7 +2029,9 @@ public class SWTrunkDesc { let initialMatMul: SWMatMulLayerDesc let sgfMetadataEncoder: SWSGFMetadataEncoderDesc? let blockDescriptors: [BlockDescriptor] + let trunkNormKind: Int let trunkTipBN: SWBatchNormLayerDesc + let trunkTipRMSNorm: SWRMSNormLayerDesc let trunkTipActivation: ActivationKind init( @@ -1496,7 +2044,9 @@ public class SWTrunkDesc { initialMatMul: SWMatMulLayerDesc, sgfMetadataEncoder: SWSGFMetadataEncoderDesc?, blockDescriptors: [BlockDescriptor], + trunkNormKind: Int, trunkTipBN: SWBatchNormLayerDesc, + trunkTipRMSNorm: SWRMSNormLayerDesc, trunkTipActivation: ActivationKind ) { self.version = version @@ -1508,7 +2058,9 @@ public class SWTrunkDesc { self.initialMatMul = initialMatMul self.sgfMetadataEncoder = sgfMetadataEncoder self.blockDescriptors = blockDescriptors + self.trunkNormKind = trunkNormKind self.trunkTipBN = trunkTipBN + self.trunkTipRMSNorm = trunkTipRMSNorm self.trunkTipActivation = trunkTipActivation } } @@ -1523,7 +2075,9 @@ public func createSWTrunkDesc( initialMatMul: SWMatMulLayerDesc, sgfMetadataEncoder: SWSGFMetadataEncoderDesc?, blockDescriptors: [BlockDescriptor], + trunkNormKind: Int32, trunkTipBN: SWBatchNormLayerDesc, + trunkTipRMSNorm: SWRMSNormLayerDesc, trunkTipActivation: ActivationKind ) -> SWTrunkDesc { return SWTrunkDesc( @@ -1536,7 +2090,9 @@ public func createSWTrunkDesc( initialMatMul: initialMatMul, sgfMetadataEncoder: sgfMetadataEncoder, blockDescriptors: blockDescriptors, + trunkNormKind: Int(trunkNormKind), trunkTipBN: trunkTipBN, + trunkTipRMSNorm: trunkTipRMSNorm, trunkTipActivation: trunkTipActivation) } @@ -1632,21 +2188,33 @@ struct Trunk { nnYLen: nnYLen, optimizeIdentityMask: optimizeIdentityMask) - let trunkTipBN = BatchNormLayer( - graph: graph, - sourceTensor: blocks.resultTensor, - maskTensor: maskTensor, - descriptor: descriptor.trunkTipBN, - nnXLen: nnXLen, - nnYLen: nnYLen, - optimizeIdentityMask: optimizeIdentityMask) + // RMSNorm trunk tip uses a fused activation; standard uses BatchNorm followed by a separate activation. + if descriptor.trunkNormKind == TRUNK_NORM_KIND_RMSNORM { + let trunkTipRMSNorm = TrunkRMSNormLayer( + graph: graph, + sourceTensor: blocks.resultTensor, + maskTensor: maskTensor, + descriptor: descriptor.trunkTipRMSNorm, + activationKind: descriptor.trunkTipActivation) - let trunkTipActivation = ActivationLayer( - graph: graph, - sourceTensor: trunkTipBN.resultTensor, - activationKind: descriptor.trunkTipActivation) + resultTensor = trunkTipRMSNorm.resultTensor + } else { + let trunkTipBN = BatchNormLayer( + graph: graph, + sourceTensor: blocks.resultTensor, + maskTensor: maskTensor, + descriptor: descriptor.trunkTipBN, + nnXLen: nnXLen, + nnYLen: nnYLen, + optimizeIdentityMask: optimizeIdentityMask) + + let trunkTipActivation = ActivationLayer( + graph: graph, + sourceTensor: trunkTipBN.resultTensor, + activationKind: descriptor.trunkTipActivation) - resultTensor = trunkTipActivation.resultTensor + resultTensor = trunkTipActivation.resultTensor + } assert(resultTensor.shape?.count == 4) }