Skip to content

Commit c79fd4e

Browse files
committed
sync upstream and fix TTNN compat
1 parent 99cb4ad commit c79fd4e

3 files changed

Lines changed: 46 additions & 697 deletions

File tree

ggml/src/ggml-metalium/ggml-metalium.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
#include "ttnn/operations/eltwise/binary/binary_composite.hpp"
1313
#include "ttnn/operations/eltwise/unary/unary.hpp"
1414
#include "ttnn/operations/moreh/moreh_group_norm/moreh_group_norm.hpp"
15-
#include "ttnn/operations/normalization/softmax/device/softmax_op.hpp"
1615
#include "ttnn/tensor/shape/shape.hpp"
1716
#include "ttnn/tensor/storage.hpp"
1817
#include "ttnn/tensor/tensor.hpp"
@@ -1361,8 +1360,7 @@ static void ggml_backend_metalium_softmax(ggml_backend_metalium_context * ctx, s
13611360
x = ttnn::add(x, ttnn::multiply(*mask, positional_bias));
13621361
}
13631362
}
1364-
ttnn::DeviceComputeKernelConfig cfg = make_compute_kernel_config(x.device());
1365-
x = ttnn::operations::normalization::softmax(x, tt::tt_metal::operation::DEFAULT_OUTPUT_MEMORY_CONFIG, cfg, true);
1363+
x = ttnn::softmax(x, 3);
13661364
*dst_meta = {
13671365
.tensor = std::make_shared<tt::tt_metal::Tensor>(std::move(x)),
13681366
.ggtype = dst->type,
@@ -1620,7 +1618,7 @@ static bool ggml_backend_metalium_can_glu(const struct ggml_tensor * dst)
16201618
if(split) {
16211619
return true;
16221620
}
1623-
return dst->src[0]->ne[0] % 2 == 0;
1621+
return dst->src[0]->ne[0] % 2 == 0 && ggml_get_glu_op(dst) != GGML_GLU_OP_SWIGLU_OAI;
16241622
}
16251623

16261624
static void ggml_backend_metalium_glu(ggml_backend_metalium_context * ctx, struct ggml_tensor * dst)

ggml/src/ggml-metalium/metalium-pch.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
#include "ttnn/operations/eltwise/binary/binary_composite.hpp"
1313
#include "ttnn/operations/eltwise/unary/unary.hpp"
1414
#include "ttnn/operations/moreh/moreh_group_norm/moreh_group_norm.hpp"
15-
#include "ttnn/operations/normalization/softmax/device/softmax_op.hpp"
1615
#include "ttnn/tensor/shape/shape.hpp"
1716
#include "ttnn/tensor/storage.hpp"
1817
#include "ttnn/tensor/tensor.hpp"

0 commit comments

Comments
 (0)