Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
914 changes: 895 additions & 19 deletions cpp/external/katagocoreml/src/builder/MILBuilder.cpp

Large diffs are not rendered by default.

65 changes: 65 additions & 0 deletions cpp/external/katagocoreml/src/builder/MILBuilder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,16 @@ class MILBuilder {
bool m_optimize_identity_mask;
bool m_use_fp16;
bool m_use_fp16_io;
// FP32-in-FP16-mode escalations all run off the FP16-only ANE, so they apply ONLY to transformer
// trunks (attention widens activation range, overflowing FP16 conv/matmul/pooling accumulation).
// Plain convnets run pure FP16 on the ANE -- the long-standing pre-tier path, verified to pass
// testgpuerror (b18c384nbt) and ~2.3x faster than forcing their per-block global pooling to FP32.
// For transformers: narrow trunks (<256) build fully FP32; wider ones use non-spatial FP32 (matmuls +
// pooling) plus, for very wide trunks (>=320), conv FP32. RMSNorm reductions: FP32 when m_use_fp16.
static constexpr int CONV_FP32_MIN_TRUNK_CHANNELS = 320; // transformer convs run FP32 at/above this width
static constexpr int FULL_FP32_MAX_TRUNK_CHANNELS = 256; // transformer trunks below this build fully FP32
bool m_nonspatial_fp32 = false; // = m_use_fp16 && hasTransformer (matmuls + global pooling)
bool m_conv_fp32 = false; // = m_use_fp16 && hasTransformer && trunk_channels >= CONV_FP32_MIN_...
int m_min_batch_size;
int m_max_batch_size;
CoreML::Specification::MILSpec::DataType m_weight_dtype;
Expand Down Expand Up @@ -102,6 +112,23 @@ class MILBuilder {
const std::string& dtype,
const std::vector<int64_t>& shape);

// Cast to a tensor with FULLY-specified dims (no forced batch dim like addCastOp). Use for
// weight tensors (fixed [in,out] dims) when running an otherwise-FP16 op in FP32. Returns the
// new tensor name. dims use -1 for an unknown/batch dim, >=0 for a constant dim.
std::string castFixed(CoreML::Specification::MILSpec::Block* block,
const std::string& input,
const std::string& dtype,
const std::vector<int64_t>& dims);

// Emit global pooling, running it in FP32 when m_nonspatial_fp32 (cast input/mask up, pool,
// cast the pooled features back to FP16). valueVariant selects the value-head pooling variant.
void addGlobalPoolingFp32(CoreML::Specification::MILSpec::Block* block,
const std::string& input,
const std::string& mask,
int channels,
const std::string& output,
bool valueVariant);

void addConvOp(CoreML::Specification::MILSpec::Block* block,
const std::string& input,
const ConvLayerDesc& layer,
Expand All @@ -120,6 +147,44 @@ class MILBuilder {
int rank,
int channels);

void addSiluOps(CoreML::Specification::MILSpec::Block* block,
const std::string& input,
const std::string& output,
int rank,
int channels);

// Generic output-shape setter: dims with -1 entries become unknown/dynamic dimensions.
void setShape(CoreML::Specification::MILSpec::Operation* op,
const std::string& name,
const std::vector<int64_t>& dims);

// Lightweight transformer RMSNorm (weight only, per-position over channels). NCHW in/out.
std::string addTransformerRMSNorm(CoreML::Specification::MILSpec::Block* block,
const std::string& input,
const TransformerRMSNormDesc& desc,
const std::string& mask,
const std::string& prefix);

// Full RMSNorm at trunk tip: gamma/beta, spatial or per-position, fused activation. NCHW in/out.
std::string addTrunkRMSNorm(CoreML::Specification::MILSpec::Block* block,
const std::string& input,
const RMSNormLayerDesc& desc,
const ActivationLayerDesc& act,
const std::string& mask,
const std::string& prefix);

std::string buildTransformerAttentionBlock(CoreML::Specification::MILSpec::Block* block,
const std::string& input,
const TransformerAttentionBlockDesc& block_desc,
const std::string& mask,
const std::string& prefix);

std::string buildTransformerFFNBlock(CoreML::Specification::MILSpec::Block* block,
const std::string& input,
const TransformerFFNBlockDesc& block_desc,
const std::string& mask,
const std::string& prefix);

void addGlobalPoolingOps(CoreML::Specification::MILSpec::Block* block,
const std::string& input,
const std::string& mask,
Expand Down
4 changes: 3 additions & 1 deletion cpp/external/katagocoreml/src/builder/Operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@ KataGoOps::KataGoOps(int board_x_size, int board_y_size, bool optimize_identity_

std::string KataGoOps::registerWeight(const std::string& name,
const std::vector<float>& data,
const std::vector<int64_t>& shape) {
const std::vector<int64_t>& shape,
bool is_fp32) {
WeightEntry entry;
entry.name = name;
entry.data = data;
entry.shape = shape;
entry.blob_offset = 0; // Will be set during serialization
entry.is_fp32 = is_fp32;
m_weights.push_back(std::move(entry));
return name;
}
Expand Down
7 changes: 5 additions & 2 deletions cpp/external/katagocoreml/src/builder/Operations.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ struct WeightEntry {
std::vector<float> data;
std::vector<int64_t> shape;
uint64_t blob_offset = 0; // Set during serialization
bool is_fp32 = false; // Store as FP32 (set when the const was declared FP32, e.g. inside an
// FP32 sub-region of an otherwise-FP16 model). Else stored per global mode.
};

/// Precomputed constants for identity mask optimization
Expand Down Expand Up @@ -51,10 +53,11 @@ class KataGoOps {
/// Get precomputed mask constants
const MaskConstants& getMaskConstants() const { return m_mask_constants; }

/// Register a weight tensor and return its reference name
/// Register a weight tensor and return its reference name. is_fp32 marks it for FP32 storage.
std::string registerWeight(const std::string& name,
const std::vector<float>& data,
const std::vector<int64_t>& shape);
const std::vector<int64_t>& shape,
bool is_fp32 = false);

/// Get all registered weights
const std::vector<WeightEntry>& getWeights() const { return m_weights; }
Expand Down
142 changes: 127 additions & 15 deletions cpp/external/katagocoreml/src/parser/KataGoParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,8 @@ ActivationLayerDesc KataGoParser::parseActivationLayer(int model_version) {
layer.activation_type = ActivationType::ReLU;
} else if (activation_str == "ACTIVATION_MISH") {
layer.activation_type = ActivationType::Mish;
} else if (activation_str == "ACTIVATION_SILU") {
layer.activation_type = ActivationType::Silu;
} else {
throw std::runtime_error("Unknown activation type: " + activation_str);
}
Expand Down Expand Up @@ -420,6 +422,98 @@ static void checkBlockChannels(const std::string& block_name, const std::string&
}
}

/// Parses a lightweight (weight-only) transformer RMSNorm layer from the model stream.
///
/// Fields are consumed in stream order: layer name, channel count, epsilon, then
/// `num_channels` weight values.
///
/// @return The populated descriptor.
/// @throws std::runtime_error if the channel count is smaller than 1.
TransformerRMSNormDesc KataGoParser::parseTransformerRMSNorm() {
    TransformerRMSNormDesc desc;

    // Fixed-size header.
    desc.name = readString();
    desc.num_channels = readInt();
    desc.epsilon = readFloat();

    // Validate before using the channel count as the weight-array length.
    const bool has_channels = desc.num_channels >= 1;
    if (!has_channels) {
        throw std::runtime_error(desc.name + ": transformer rmsnorm numChannels must be >= 1");
    }

    // One scale value per channel.
    desc.weight = readFloats(desc.num_channels, desc.name + "/weight");
    return desc;
}

/// Parses a full RMSNorm layer (gamma and beta) from the model stream.
///
/// Fields are consumed in stream order: layer name, channel count, epsilon, the spatial
/// flag (stored as an int, non-zero = true), the channel-group size, then `num_channels`
/// gamma values followed by `num_channels` beta values.
///
/// @return The populated descriptor.
/// @throws std::runtime_error if the channel count is smaller than 1, or if the
///         channel-group size is non-zero (grouped RMSNorm is rejected here).
RMSNormLayerDesc KataGoParser::parseRMSNormLayer() {
    RMSNormLayerDesc desc;

    // Fixed-size header; every field is read before any validation happens.
    desc.name = readString();
    desc.num_channels = readInt();
    desc.epsilon = readFloat();
    const int spatial_flag = readInt();
    desc.spatial = (spatial_flag != 0);
    desc.cgroup_size = readInt();

    // The channel-count check comes first so its error takes precedence.
    const bool has_channels = desc.num_channels >= 1;
    if (!has_channels) {
        throw std::runtime_error(desc.name + ": rmsnorm numChannels must be >= 1");
    }
    const bool is_grouped = desc.cgroup_size != 0;
    if (is_grouped) {
        throw std::runtime_error(desc.name + ": grouped spatial RMSNorm is not supported");
    }

    // Per-channel scale, then per-channel shift.
    desc.gamma = readFloats(desc.num_channels, desc.name + "/gamma");
    desc.beta = readFloats(desc.num_channels, desc.name + "/beta");
    return desc;
}

/// Parses one transformer attention block from the model stream.
///
/// Fields are consumed in stream order: block name, numHeads, numKVHeads, qHeadDim,
/// vHeadDim, the useRope and learnableRope flags (stored as ints, non-zero = true), the
/// pre-attention RMSNorm, the q/k/v/out projection matmuls, and, only when RoPE is
/// enabled, the RoPE parameters (a learnable frequency table or a single theta value).
///
/// @param model_version  Not used by this parser; the parameter mirrors the other block
///                       parsers' signatures.
/// @return The populated descriptor.
/// @throws std::runtime_error if the head counts are invalid (either < 1, or numHeads not
///         a multiple of numKVHeads), if either head dimension is < 1, if qHeadDim is odd
///         while RoPE is enabled, or if the learnable-RoPE header disagrees with the
///         block header.
TransformerAttentionBlockDesc KataGoParser::parseTransformerAttentionBlock(int model_version) {
    (void)model_version;  // Intentionally unused; silences -Wunused-parameter.

    TransformerAttentionBlockDesc attn;
    attn.name = readString();
    attn.num_heads = readInt();
    attn.num_kv_heads = readInt();
    attn.q_head_dim = readInt();
    attn.v_head_dim = readInt();
    attn.use_rope = (readInt() != 0);
    attn.learnable_rope = (readInt() != 0);

    if (attn.num_heads < 1 || attn.num_kv_heads < 1 || (attn.num_heads % attn.num_kv_heads != 0)) {
        throw std::runtime_error(attn.name + ": invalid numHeads/numKVHeads");
    }
    // Head dimensions come from the (untrusted) model file. Without this check a negative
    // even qHeadDim would pass the parity test below, and the learnable-RoPE table size
    // derived from it would wrap around to a huge unsigned element count.
    if (attn.q_head_dim < 1 || attn.v_head_dim < 1) {
        throw std::runtime_error(attn.name + ": qHeadDim and vHeadDim must be >= 1");
    }
    if (attn.use_rope && (attn.q_head_dim % 2 != 0)) {
        throw std::runtime_error(attn.name + ": qHeadDim must be even when RoPE is used");
    }

    attn.pre_ln = parseTransformerRMSNorm();
    attn.q_proj = parseMatMulLayer();
    attn.k_proj = parseMatMulLayer();
    attn.v_proj = parseMatMulLayer();
    attn.out_proj = parseMatMulLayer();

    if (attn.use_rope) {
        if (attn.learnable_rope) {
            readString();  // ropeFreqs name (consumed, not stored)
            attn.rope_num_kv_heads = readInt();
            attn.rope_num_pairs = readInt();
            const int rope_dim2 = readInt();
            // The table header must agree with the block header: one row per KV head,
            // qHeadDim/2 rotation pairs, and a trailing dimension of exactly 2.
            if (attn.rope_num_kv_heads != attn.num_kv_heads ||
                attn.rope_num_pairs != attn.q_head_dim / 2 || rope_dim2 != 2) {
                throw std::runtime_error(attn.name + ": invalid learnable rope header");
            }
            // Both factors are known non-negative here; widen each one explicitly so the
            // product is computed entirely in size_t.
            const size_t num_freqs = static_cast<size_t>(attn.rope_num_kv_heads) *
                                     static_cast<size_t>(attn.rope_num_pairs) * 2;
            attn.rope_freqs = readFloats(num_freqs, attn.name + "/rope_freqs");
        } else {
            readString();  // ropeTheta name (consumed, not stored)
            attn.rope_theta = readFloat();
        }
    }
    return attn;
}

/// Parses one transformer feed-forward (FFN) block from the model stream.
///
/// Fields are consumed in stream order: block name, channel count, FFN hidden channel
/// count, the SwiGLU flag (stored as an int, non-zero = true), the pre-FFN RMSNorm, the
/// first linear layer, the gate linear layer (present only when SwiGLU is enabled), and
/// the second linear layer.
///
/// @param model_version  Not referenced by this parser.
/// @return The populated descriptor.
/// @throws std::runtime_error if either channel count is smaller than 1.
TransformerFFNBlockDesc KataGoParser::parseTransformerFFNBlock(int model_version) {
    TransformerFFNBlockDesc ffn;

    // Fixed-size header.
    ffn.name = readString();
    ffn.num_channels = readInt();
    ffn.ffn_channels = readInt();
    const int swiglu_flag = readInt();
    ffn.use_swiglu = (swiglu_flag != 0);

    const bool channels_positive = ffn.num_channels >= 1 && ffn.ffn_channels >= 1;
    if (!channels_positive) {
        throw std::runtime_error(ffn.name + ": transformer ffn channels must be positive");
    }

    // Sub-layers follow in a fixed order; only the gate projection is optional.
    ffn.pre_ln = parseTransformerRMSNorm();
    ffn.linear1 = parseMatMulLayer();
    if (ffn.use_swiglu) {
        ffn.linear_gate = parseMatMulLayer();
    }
    ffn.linear2 = parseMatMulLayer();

    return ffn;
}

std::vector<BlockEntry> KataGoParser::parseBlockStack(int model_version, int num_blocks, int trunk_num_channels) {
std::vector<BlockEntry> blocks;
blocks.reserve(num_blocks);
Expand Down Expand Up @@ -449,6 +543,14 @@ std::vector<BlockEntry> KataGoParser::parseBlockStack(int model_version, int num
desc.pre_bn.num_channels,
desc.post_conv.out_channels, trunk_num_channels);
entry.block = std::make_shared<BlockDesc>(std::move(desc));
} else if (block_kind_name == "transformer_attention_block") {
entry.block_kind = TRANSFORMER_ATTENTION_BLOCK_KIND;
auto desc = parseTransformerAttentionBlock(model_version);
entry.block = std::make_shared<BlockDesc>(std::move(desc));
} else if (block_kind_name == "transformer_ffn_block") {
entry.block_kind = TRANSFORMER_FFN_BLOCK_KIND;
auto desc = parseTransformerFFNBlock(model_version);
entry.block = std::make_shared<BlockDesc>(std::move(desc));
} else {
throw std::runtime_error("Unknown block kind: " + block_kind_name);
}
Expand Down Expand Up @@ -506,15 +608,15 @@ TrunkDesc KataGoParser::parseTrunk(int model_version, int meta_encoder_version)
}

// Version >= 15 writes the trunk norm kind followed by 5 unused int parameters.
// This CoreML parser only supports the standard trunk norm kind (0 = BatchNorm/BiasMask);
// RMSNorm (used by transformer/rmsnorm models, kind != 0) is not implemented here, so reject it
// defensively rather than silently parsing it as standard norm and producing wrong outputs.
// Unlike upstream's CoreML parser (which rejects any non-standard norm), this fork
// implements RMSNorm, so we capture the kind here instead of throwing. The 5 trailing
// ints are reserved and still expected to be zero.
if (model_version >= 15) {
int trunk_norm_kind = readInt();
if (trunk_norm_kind != 0) {
throw std::runtime_error(trunk.name + ": unsupported trunk norm kind " +
std::to_string(trunk_norm_kind) +
" (this CoreML parser only supports standard trunk norm, not RMSNorm)");
trunk.trunk_norm_kind = readInt();
if (trunk.trunk_norm_kind != TRUNK_NORM_KIND_STANDARD &&
trunk.trunk_norm_kind != TRUNK_NORM_KIND_RMSNORM) {
throw std::runtime_error(trunk.name + ": unknown/unsupported trunk norm kind " +
std::to_string(trunk.trunk_norm_kind));
}
for (int i = 0; i < 5; i++) {
int unused = readInt();
Expand Down Expand Up @@ -561,14 +663,24 @@ TrunkDesc KataGoParser::parseTrunk(int model_version, int meta_encoder_version)
// Parse residual blocks
trunk.blocks = parseBlockStack(model_version, trunk.num_blocks, trunk.trunk_num_channels);

trunk.trunk_tip_bn = parseBatchNormLayer();
trunk.trunk_tip_activation = parseActivationLayer(model_version);
if (trunk.trunk_tip_bn.num_channels != trunk.trunk_num_channels) {
throw std::runtime_error(trunk.name + ": trunkTipBN.numChannels (" +
std::to_string(trunk.trunk_tip_bn.num_channels) +
") != trunkNumChannels (" +
std::to_string(trunk.trunk_num_channels) + ")");
if (trunk.trunk_norm_kind == TRUNK_NORM_KIND_STANDARD) {
trunk.trunk_tip_bn = parseBatchNormLayer();
if (trunk.trunk_tip_bn.num_channels != trunk.trunk_num_channels) {
throw std::runtime_error(trunk.name + ": trunkTipBN.numChannels (" +
std::to_string(trunk.trunk_tip_bn.num_channels) +
") != trunkNumChannels (" +
std::to_string(trunk.trunk_num_channels) + ")");
}
} else {
trunk.trunk_tip_rms_norm = parseRMSNormLayer();
if (trunk.trunk_tip_rms_norm.num_channels != trunk.trunk_num_channels) {
throw std::runtime_error(trunk.name + ": trunkTipRMSNorm.numChannels (" +
std::to_string(trunk.trunk_tip_rms_norm.num_channels) +
") != trunkNumChannels (" +
std::to_string(trunk.trunk_num_channels) + ")");
}
}
trunk.trunk_tip_activation = parseActivationLayer(model_version);

return trunk;
}
Expand Down
4 changes: 4 additions & 0 deletions cpp/external/katagocoreml/src/parser/KataGoParser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,15 @@ class KataGoParser {
ActivationLayerDesc parseActivationLayer(int model_version);
MatMulLayerDesc parseMatMulLayer();
MatBiasLayerDesc parseMatBiasLayer();
TransformerRMSNormDesc parseTransformerRMSNorm();
RMSNormLayerDesc parseRMSNormLayer();

// Block parsing functions
ResidualBlockDesc parseResidualBlock(int model_version);
GlobalPoolingResidualBlockDesc parseGlobalPoolingResidualBlock(int model_version);
NestedBottleneckResidualBlockDesc parseNestedBottleneckBlock(int model_version, int trunk_num_channels);
TransformerAttentionBlockDesc parseTransformerAttentionBlock(int model_version);
TransformerFFNBlockDesc parseTransformerFFNBlock(int model_version);
std::vector<BlockEntry> parseBlockStack(int model_version, int num_blocks, int trunk_num_channels);

// Component parsing functions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@ size_t WeightSerializer::serialize(std::vector<WeightEntry>& weights,
size_t total_bytes = 0;

for (auto& entry : weights) {
if (use_fp16) {
// Per-weight precision: store FP16 only when the global mode is FP16 AND this weight was not
// declared FP32 (entry.is_fp32 marks consts inside an FP32 sub-region of an FP16 model), so
// stored bytes stay consistent with each const's declared dtype.
const bool store_fp16 = use_fp16 && !entry.is_fp32;
if (store_fp16) {
// Convert FP32 weights to FP16
std::vector<MILBlob::Fp16> fp16_data(entry.data.size());
for (size_t i = 0; i < entry.data.size(); ++i) {
Expand Down
Loading
Loading