Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 2 additions & 34 deletions openxla/patches/20240901-001-Various-macOS-QOL-enchancements.patch
Original file line number Diff line number Diff line change
Expand Up @@ -14,40 +14,9 @@ PR: https://github.com/openxla/xla/pull/16696

Co-authored-by: Steeve Morin <steeve@zml.ai>
---
tensorflow.bazelrc | 5 ++---
xla/pjrt/c/BUILD | 18 ++++++++++--------
2 files changed, 12 insertions(+), 11 deletions(-)
xla/pjrt/c/BUILD | 18 ++++++++++--------
1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/tensorflow.bazelrc b/tensorflow.bazelrc
index f2ad3f6169..032cca5657 100644
--- a/tensorflow.bazelrc
+++ b/tensorflow.bazelrc
@@ -688,7 +688,6 @@ test:release_arm64_linux --flaky_test_attempts=3
build:release_cpu_macos --config=avx_linux

# Base build configs for macOS
-build:release_macos_base --action_env DEVELOPER_DIR=/Applications/Xcode.app/Contents/Developer
build:release_macos_base --define=no_nccl_support=true --output_filter=^$

# Ensure release_base is set on mac
@@ -701,7 +700,7 @@ build:release_macos_x86 --config=avx_linux
build:release_macos_x86 --cpu=darwin
# Target Catalina as the minimum compatible OS version
build:release_macos_x86 --macos_minimum_os=10.15
-build:release_macos_x86 --action_env MACOSX_DEPLOYMENT_TARGET=10.15
+build:release_macos_x86 --macos_sdk_version=10.15

# Build configs for macOS Arm64
build:release_macos_arm64 --config=release_macos_base
@@ -709,7 +708,7 @@ build:release_macos_arm64 --cpu=darwin_arm64
build:release_macos_arm64 --define=tensorflow_mkldnn_contraction_kernel=0
# Target Moneterey as the minimum compatible OS version
build:release_macos_arm64 --macos_minimum_os=12.0
-build:release_macos_arm64 --action_env MACOSX_DEPLOYMENT_TARGET=12.0
+build:release_macos_arm64 --macos_sdk_version=12.0

# Base test configs for macOS
test:release_macos_base --verbose_failures=true --local_test_jobs=HOST_CPUS
diff --git a/xla/pjrt/c/BUILD b/xla/pjrt/c/BUILD
index a0485b6a43..6f67ee6b78 100644
--- a/xla/pjrt/c/BUILD
Expand Down Expand Up @@ -101,4 +70,3 @@ index a0485b6a43..6f67ee6b78 100644
"//xla/stream_executor:cuda_platform",
--
2.39.5 (Apple Git-154)

152 changes: 152 additions & 0 deletions openxla/patches/20250225-001-Patch-cudnn-sdpa.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
From 3039cbe576b79c489920c165e12c1fdc08a5321a Mon Sep 17 00:00:00 2001
From: Hugo Mano <hugo@zml.ai>
Date: Tue, 25 Feb 2025 10:44:44 +0000
Subject: [PATCH] pagged sdpa

---
xla/service/gpu/backend_configs.proto | 2 +
.../transforms/cudnn_custom_call_compiler.cc | 23 ++++++++++-
xla/stream_executor/cuda/cuda_dnn.cc | 39 +++++++++++++++++++
xla/stream_executor/cuda/cuda_dnn.h | 5 +++
4 files changed, 67 insertions(+), 2 deletions(-)

diff --git a/xla/service/gpu/backend_configs.proto b/xla/service/gpu/backend_configs.proto
index a7ff5bfba2..6ddbb5424e 100644
--- a/xla/service/gpu/backend_configs.proto
+++ b/xla/service/gpu/backend_configs.proto
@@ -283,6 +283,8 @@ message CudnnfMHABackendConfig {
// Only used with packed layout
// ignored if the valued <= 1
int32 max_seg_per_batch = 25;
+
+ optional int32 max_sequence_length_kv = 26;
}

// Backend config for a general custom call instruction, e.g. XLA FFI.
diff --git a/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc b/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc
index 4c15388a0a..2a15f3304c 100644
--- a/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc
+++ b/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc
@@ -139,6 +139,22 @@ absl::StatusOr<se::gpu::CudnnGraph> BuildGraphForCustomCallToForwardFMHA(
TF_ASSIGN_OR_RETURN(bias, TensorDescriptorFor(bias_hlo.shape()));
}

+ std::optional<se::dnn::TensorDescriptor> sequence_length_q;
+ std::optional<se::dnn::TensorDescriptor> sequence_length_kv;
+ std::optional<se::dnn::TensorDescriptor> page_table_k;
+ std::optional<se::dnn::TensorDescriptor> page_table_v;
+
+ if (custom_call->operand_count() == 7) {
+ TF_ASSIGN_OR_RETURN(sequence_length_q,
+ TensorDescriptorFor(custom_call->operand(4)->shape()));
+ TF_ASSIGN_OR_RETURN(sequence_length_kv,
+ TensorDescriptorFor(custom_call->operand(5)->shape()));
+ TF_ASSIGN_OR_RETURN(page_table_k,
+ TensorDescriptorFor(custom_call->operand(6)->shape()));
+ TF_ASSIGN_OR_RETURN(page_table_v,
+ TensorDescriptorFor(custom_call->operand(7)->shape()));
+ }
+
const double dropout_rate = config.dropout_rate();

TF_ASSIGN_OR_RETURN(CudnnfMHAMaskKind cudnn_mask_type,
@@ -148,13 +164,16 @@ absl::StatusOr<se::gpu::CudnnGraph> BuildGraphForCustomCallToForwardFMHA(

const int sliding_window_length = config.sliding_window_length();
const int max_seg_per_batch = config.max_seg_per_batch();
+ std::optional<int> max_sequence_length_kv = config.max_sequence_length_kv();
TF_ASSIGN_OR_RETURN(
se::gpu::CudnnGraph graph,
se::gpu::GetCudnnFlashAttentionOperationGraph(
- dnn_support, lhs_bmm1, rhs_bmm1, rhs_bmm2, output, bias, activation,
+ dnn_support, lhs_bmm1, rhs_bmm1, rhs_bmm2, output, bias, sequence_length_q,
+ sequence_length_kv, activation,
static_cast<float>(config.fmha_scale()), dropout_rate > 0.0,
dropout_rate, dnn_mask_type, sliding_window_length,
- max_seg_per_batch));
+ page_table_k, page_table_v,
+ max_sequence_length_kv, max_seg_per_batch));
return graph;
}

diff --git a/xla/stream_executor/cuda/cuda_dnn.cc b/xla/stream_executor/cuda/cuda_dnn.cc
index 808870837f..922477beb8 100644
--- a/xla/stream_executor/cuda/cuda_dnn.cc
+++ b/xla/stream_executor/cuda/cuda_dnn.cc
@@ -4981,9 +4981,14 @@ absl::StatusOr<CudnnGraph> GetCudnnFlashAttentionOperationGraph(
const dnn::MatmulTensorDescriptor& v_descriptor,
const dnn::TensorDescriptor& o_descriptor,
const std::optional<dnn::TensorDescriptor> bias_descriptor,
+ const std::optional<dnn::TensorDescriptor> sequence_length_q,
+ const std::optional<dnn::TensorDescriptor> sequence_length_kv,
const std::optional<dnn::TensorDescriptor> stats_descriptor, double scale,
const bool use_dropout, const std::optional<double> dropout_rate,
const dnn::FMHAMaskKind mask_type, const int sliding_window_length,
+ const std::optional<dnn::TensorDescriptor> page_table_k,
+ const std::optional<dnn::TensorDescriptor> page_table_v,
+ const std::optional<int> max_sequence_length_kv,
const int max_seg_per_batch) {
using cudnn_frontend::graph::Tensor_attributes;

@@ -5139,6 +5144,40 @@ absl::StatusOr<CudnnGraph> GetCudnnFlashAttentionOperationGraph(
if (sliding_window_length > 0) {
sdpa_options.set_sliding_window_length(sliding_window_length);
}
+
+ if (sequence_length_q && sequence_length_kv && page_table_k && page_table_v && max_sequence_length_kv) {
+ auto seq_q = graph.tensor(Tensor_attributes()
+ .set_name("seq_q")
+ .set_uid(next_uid())
+ .set_dim(sequence_length_q->dimensions())
+ .set_stride(sequence_length_q->GetLogicalStrides())
+ .set_data_type(cudnn_frontend::DataType_t::INT32));
+ auto seq_kv = graph.tensor(Tensor_attributes()
+ .set_name("seq_kv")
+ .set_uid(next_uid())
+ .set_dim(sequence_length_kv->dimensions())
+ .set_stride(sequence_length_kv->GetLogicalStrides())
+ .set_data_type(cudnn_frontend::DataType_t::INT32));
+ sdpa_options.set_padding_mask(true).set_seq_len_q(seq_q).set_seq_len_kv(seq_kv);
+
+ auto page_table_k_ = graph.tensor(Tensor_attributes()
+ .set_name("page_table_k")
+ .set_uid(next_uid())
+ .set_dim(page_table_k->dimensions())
+ .set_stride(page_table_k->GetLogicalStrides())
+ .set_data_type(cudnn_frontend::DataType_t::INT32));
+ auto page_table_v_ = graph.tensor(Tensor_attributes()
+ .set_name("page_table_v")
+ .set_uid(next_uid())
+ .set_dim(page_table_v->dimensions())
+ .set_stride(page_table_v->GetLogicalStrides())
+ .set_data_type(cudnn_frontend::DataType_t::INT32));
+
+ sdpa_options.set_paged_attention_k_table(page_table_k_);
+ sdpa_options.set_paged_attention_v_table(page_table_v_);
+ sdpa_options.set_paged_attention_max_seq_len_kv(max_sequence_length_kv.value());
+ }
+
// Add SDPA to the graph.
auto [o_tensor, stats_tensor] =
graph.sdpa(q_tensor, k_tensor, v_tensor, sdpa_options);
diff --git a/xla/stream_executor/cuda/cuda_dnn.h b/xla/stream_executor/cuda/cuda_dnn.h
index 946e419311..9146eaa785 100644
--- a/xla/stream_executor/cuda/cuda_dnn.h
+++ b/xla/stream_executor/cuda/cuda_dnn.h
@@ -714,9 +714,14 @@ absl::StatusOr<CudnnGraph> GetCudnnFlashAttentionOperationGraph(
const dnn::MatmulTensorDescriptor& v_descriptor,
const dnn::TensorDescriptor& o_descriptor,
const std::optional<dnn::TensorDescriptor> bias_descriptor,
+ const std::optional<dnn::TensorDescriptor> sequence_length_q,
+ const std::optional<dnn::TensorDescriptor> sequence_length_kv,
const std::optional<dnn::TensorDescriptor> stats_descriptor, double scale,
const bool use_dropout, const std::optional<double> dropout_rate,
const dnn::FMHAMaskKind mask_type, const int sliding_window_length,
+ const std::optional<dnn::TensorDescriptor> page_table_k,
+ const std::optional<dnn::TensorDescriptor> page_table_v,
+ const std::optional<int> max_sequence_length_kv,
const int max_seg_per_batch);

absl::StatusOr<CudnnGraph> GetCudnnFlashAttentionF8OperationGraph(
--
2.34.1
52 changes: 52 additions & 0 deletions openxla/patches/20250225-002-Patch-cudnn-sdpa.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
From c4c568754c53674072317f322d511fcba93d7871 Mon Sep 17 00:00:00 2001
From: Hugo Mano <hugo@zml.ai>
Date: Tue, 25 Feb 2025 15:54:15 +0000
Subject: [PATCH] patche

---
xla/service/gpu/transforms/cudnn_custom_call_compiler.cc | 2 +-
xla/stream_executor/cuda/cuda_dnn.cc | 9 ++++++++-
2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc b/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc
index 2a15f3304c..1604494003 100644
--- a/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc
+++ b/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc
@@ -144,7 +144,7 @@ absl::StatusOr<se::gpu::CudnnGraph> BuildGraphForCustomCallToForwardFMHA(
std::optional<se::dnn::TensorDescriptor> page_table_k;
std::optional<se::dnn::TensorDescriptor> page_table_v;

- if (custom_call->operand_count() == 7) {
+ if (custom_call->operand_count() == 8) {
TF_ASSIGN_OR_RETURN(sequence_length_q,
TensorDescriptorFor(custom_call->operand(4)->shape()));
TF_ASSIGN_OR_RETURN(sequence_length_kv,
diff --git a/xla/stream_executor/cuda/cuda_dnn.cc b/xla/stream_executor/cuda/cuda_dnn.cc
index 922477beb8..1038d4b43f 100644
--- a/xla/stream_executor/cuda/cuda_dnn.cc
+++ b/xla/stream_executor/cuda/cuda_dnn.cc
@@ -5025,6 +5025,11 @@ absl::StatusOr<CudnnGraph> GetCudnnFlashAttentionOperationGraph(
std::vector<int64_t> v_dims =
v_descriptor.GetCudnnCompatibleDimensions(false);

+
+ VLOG(4) << "\n GetCudnnCompatibleDimensions: q_dims: " << absl::StrJoin(q_dims, ",");
+ VLOG(4) << "\n GetCudnnCompatibleDimensions: k_dims: " << absl::StrJoin(k_dims, ",");
+ VLOG(4) << "\n GetCudnnCompatibleDimensions: v_dims: " << absl::StrJoin(v_dims, ",");
+
if (max_seg_per_batch > 1) {
FixDimsForRaggedOffset(q_dims, max_seg_per_batch);
FixDimsForRaggedOffset(k_dims, max_seg_per_batch);
@@ -5037,7 +5042,9 @@ absl::StatusOr<CudnnGraph> GetCudnnFlashAttentionOperationGraph(
.set_dim(q_dims)
.set_stride(q_descriptor.GetCudnnCompatibleStrides(true))
.set_uid(next_uid()));
-
+ VLOG(4) << "\n q_strides: " << absl::StrJoin(q_descriptor.GetCudnnCompatibleStrides(true), ",");
+ VLOG(4) << "\n k_strides: " << absl::StrJoin(k_descriptor.GetCudnnCompatibleStrides(true), ",");
+ VLOG(4) << "\n v_strides: " << absl::StrJoin(v_descriptor.GetCudnnCompatibleStrides(false), ",");
std::shared_ptr<Tensor_attributes> k_tensor =
graph.tensor(Tensor_attributes()
.set_name("K")
--
2.34.1
25 changes: 25 additions & 0 deletions openxla/patches/20250225-003-Patch-cudnn-sdpa.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
From b81051b2d45f486652a72cf88c355b4db9c9cd6b Mon Sep 17 00:00:00 2001
From: Hugo Mano <hugo@zml.ai>
Date: Tue, 25 Feb 2025 16:50:40 +0000
Subject: [PATCH] remote attn scale

---
xla/stream_executor/cuda/cuda_dnn.cc | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/xla/stream_executor/cuda/cuda_dnn.cc b/xla/stream_executor/cuda/cuda_dnn.cc
index 1038d4b43f..d87c8acc43 100644
--- a/xla/stream_executor/cuda/cuda_dnn.cc
+++ b/xla/stream_executor/cuda/cuda_dnn.cc
@@ -5064,8 +5064,7 @@ absl::StatusOr<CudnnGraph> GetCudnnFlashAttentionOperationGraph(
cudnn_frontend::graph::SDPA_attributes sdpa_options;
sdpa_options.set_name("flash_attention")
.set_is_inference(stats_descriptor == std::nullopt)
- .set_causal_mask(is_causal)
- .set_attn_scale(scale);
+ .set_causal_mask(is_causal);

// Setting bias
if (bias_descriptor.has_value()) {
--
2.34.1
94 changes: 94 additions & 0 deletions openxla/patches/20250225-004-Patch-cudnn-sdpa.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
diff --git a/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc b/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc
index 6199a84562..6a325e7433 100644
--- a/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc
+++ b/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc
@@ -133,26 +133,26 @@ absl::StatusOr<se::gpu::CudnnGraph> BuildGraphForCustomCallToForwardFMHA(
}

std::optional<se::dnn::TensorDescriptor> bias;
- if (kind == CudnnfMHAKind::kScaleBiasSoftmax ||
- kind == CudnnfMHAKind::kScaleBiasSoftmaxDropout) {
- const HloInstruction &bias_hlo = *custom_call->operand(3);
- TF_ASSIGN_OR_RETURN(bias, TensorDescriptorFor(bias_hlo.shape()));
- }
+ //if (kind == CudnnfMHAKind::kScaleBiasSoftmax ||
+ // kind == CudnnfMHAKind::kScaleBiasSoftmaxDropout) {
+ // const HloInstruction &bias_hlo = *custom_call->operand(3);
+ // TF_ASSIGN_OR_RETURN(bias, TensorDescriptorFor(bias_hlo.shape()));
+ //}

std::optional<se::dnn::TensorDescriptor> sequence_length_q;
std::optional<se::dnn::TensorDescriptor> sequence_length_kv;
std::optional<se::dnn::TensorDescriptor> page_table_k;
std::optional<se::dnn::TensorDescriptor> page_table_v;

- if (custom_call->operand_count() == 8) {
+ if (custom_call->operand_count() == 7) {
TF_ASSIGN_OR_RETURN(sequence_length_q,
- TensorDescriptorFor(custom_call->operand(4)->shape()));
+ TensorDescriptorFor(custom_call->operand(3)->shape()));
TF_ASSIGN_OR_RETURN(sequence_length_kv,
- TensorDescriptorFor(custom_call->operand(5)->shape()));
+ TensorDescriptorFor(custom_call->operand(4)->shape()));
TF_ASSIGN_OR_RETURN(page_table_k,
- TensorDescriptorFor(custom_call->operand(6)->shape()));
+ TensorDescriptorFor(custom_call->operand(5)->shape()));
TF_ASSIGN_OR_RETURN(page_table_v,
- TensorDescriptorFor(custom_call->operand(7)->shape()));
+ TensorDescriptorFor(custom_call->operand(6)->shape()));
}

const double dropout_rate = config.dropout_rate();
diff --git a/xla/stream_executor/cuda/cuda_dnn.cc b/xla/stream_executor/cuda/cuda_dnn.cc
index 2d6bd67570..0720df385d 100644
--- a/xla/stream_executor/cuda/cuda_dnn.cc
+++ b/xla/stream_executor/cuda/cuda_dnn.cc
@@ -5079,27 +5079,27 @@ absl::StatusOr<CudnnGraph> GetCudnnFlashAttentionOperationGraph(
// Setting actual seqlen
bool is_padding = mask_type == dnn::FMHAMaskKind::PADDING ||
mask_type == dnn::FMHAMaskKind::PADDING_CAUSAL;
- if (is_padding || max_seg_per_batch > 1) {
- // Get batch size
- auto b = q_dims[0];
- auto seq_q_tensor =
- graph.tensor(Tensor_attributes()
- .set_name("seq_q")
- .set_dim({b, 1, 1, 1})
- .set_stride({1, 1, 1, 1})
- .set_uid(next_uid())
- .set_data_type(cudnn_frontend::DataType_t::INT32));
- auto seq_kv_tensor =
- graph.tensor(Tensor_attributes()
- .set_name("seq_kv")
- .set_dim({b, 1, 1, 1})
- .set_stride({1, 1, 1, 1})
- .set_uid(next_uid())
- .set_data_type(cudnn_frontend::DataType_t::INT32));
- sdpa_options.set_padding_mask(true);
- sdpa_options.set_seq_len_q(seq_q_tensor);
- sdpa_options.set_seq_len_kv(seq_kv_tensor);
- }
+ //if (is_padding || max_seg_per_batch > 1) {
+ // // Get batch size
+ // auto b = q_dims[0];
+ // auto seq_q_tensor =
+ // graph.tensor(Tensor_attributes()
+ // .set_name("seq_q")
+ // .set_dim({b, 1, 1, 1})
+ // .set_stride({1, 1, 1, 1})
+ // .set_uid(next_uid())
+ // .set_data_type(cudnn_frontend::DataType_t::INT32));
+ // auto seq_kv_tensor =
+ // graph.tensor(Tensor_attributes()
+ // .set_name("seq_kv")
+ // .set_dim({b, 1, 1, 1})
+ // .set_stride({1, 1, 1, 1})
+ // .set_uid(next_uid())
+ // .set_data_type(cudnn_frontend::DataType_t::INT32));
+ // sdpa_options.set_padding_mask(true);
+ // sdpa_options.set_seq_len_q(seq_q_tensor);
+ // sdpa_options.set_seq_len_kv(seq_kv_tensor);
+ //}

std::shared_ptr<Tensor_attributes> offset_q;
if (max_seg_per_batch > 1) {
Loading