diff --git a/xla/backends/gpu/runtime/command_buffer_cmd.cc b/xla/backends/gpu/runtime/command_buffer_cmd.cc index 228d40ed935a4..0ede07cb6a30b 100644 --- a/xla/backends/gpu/runtime/command_buffer_cmd.cc +++ b/xla/backends/gpu/runtime/command_buffer_cmd.cc @@ -64,7 +64,6 @@ limitations under the License. #include "xla/core/collectives/communicator.h" #include "xla/core/collectives/reduction_kind.h" #include "xla/debug_options_flags.h" -#include "xla/service/rendezvous.h" #include "xla/executable_run_options.h" #include "xla/ffi/call_frame.h" #include "xla/ffi/ffi_api.h" @@ -996,23 +995,6 @@ absl::StatusOr TracedCommandBuffer::GetOrTraceCommandBuffer( return shift_right(capacity_ - 1).command_buffer.get(); } -bool TracedCommandBuffer::HasEntry( - const BufferAllocations* buffer_allocation) const { - absl::InlinedVector allocs; - allocs.reserve(allocs_indices_.size()); - for (auto& index : allocs_indices_) { - allocs.emplace_back(buffer_allocation->GetDeviceAddress(index)); - } - - for (size_t i = 0; i < capacity_; ++i) { - if (absl::c_equal(entries_[i].recorded_allocs, allocs) && - entries_[i].command_buffer) { - return true; - } - } - return false; -} - //===----------------------------------------------------------------------===// // TracedCommandBufferCmd //===----------------------------------------------------------------------===// @@ -2140,101 +2122,31 @@ absl::Status CollectiveCmd::Prepare(const Thunk::PrepareParams& params) { return params.clique_requests->RequestClique(clique_key); } -namespace { - -struct CollectiveTraceCacheKey { - GpuCliqueKey clique_key; - const CollectiveCmd* cmd; - - template - friend H AbslHashValue(H h, const CollectiveTraceCacheKey& k) { - return H::combine(std::move(h), k.clique_key, k.cmd); - } - - friend bool operator==(const CollectiveTraceCacheKey& a, - const CollectiveTraceCacheKey& b) { - return a.clique_key == b.clique_key && a.cmd == b.cmd; - } -}; - -} // namespace - absl::StatusOr CollectiveCmd::RecordTracedCommand( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer, - absl::FunctionRef trace, - const GpuCliqueKey& clique_key) { - - se::CommandBuffer* nested_cmd_ptr = nullptr; - std::unique_ptr nested_cmd_owned; - - if (clique_key.is_local()) { - auto traced_cmd = record_params.state.GetOrCreate( - this, command_buffer, [&] { - const auto& debug_options = xla::GetDebugOptionsFromFlags(); - return std::make_unique( - this, buffers(), - debug_options.xla_cmd_buffer_trace_cache_size()); - }); - - bool local_hit = traced_cmd->HasEntry(execute_params.buffer_allocations); - - CollectiveTraceCacheKey rendezvous_key{clique_key, this}; - TF_ASSIGN_OR_RETURN( - std::shared_ptr all_hit, - xla::Rendezvous( - "collective_trace_cache", rendezvous_key, local_hit, - clique_key.num_local_participants(), - [](absl::Span votes) { - return std::all_of(votes.begin(), votes.end(), - [](const bool* v) { return *v; }); - }, - /*warn_stuck_timeout=*/absl::Seconds(10), - /*terminate_timeout=*/absl::Seconds(30))); - - if (*all_hit) { - VLOG(5) << "Collective trace cache: all ranks hit, using cached graph"; - TF_ASSIGN_OR_RETURN( - nested_cmd_ptr, - traced_cmd->GetOrTraceCommandBuffer( - execute_params.buffer_allocations, - execute_params.stream->parent(), - execute_params.command_buffer_trace_stream, trace, priority())); - } else { - VLOG(5) << "Collective trace cache: not all ranks hit, all retracing"; - TF_ASSIGN_OR_RETURN( - nested_cmd_owned, - se::TraceCommandBufferFactory::Create( - execute_params.stream->parent(), - execute_params.command_buffer_trace_stream, trace)); - nested_cmd_ptr = nested_cmd_owned.get(); - } - } else { - TF_ASSIGN_OR_RETURN( - nested_cmd_owned, - se::TraceCommandBufferFactory::Create( - execute_params.stream->parent(), - execute_params.command_buffer_trace_stream, trace)); - nested_cmd_ptr = nested_cmd_owned.get(); - } + absl::FunctionRef trace) { + TF_ASSIGN_OR_RETURN(std::unique_ptr nested_cmd, + se::TraceCommandBufferFactory::Create( + execute_params.stream->parent(), + execute_params.command_buffer_trace_stream, trace)); if (priority() != se::StreamPriority::Default) { - TF_RETURN_IF_ERROR(nested_cmd_ptr->SetPriority(priority())); + TF_RETURN_IF_ERROR(nested_cmd->SetPriority(priority())); } return Handle( std::move(record_action), [&](absl::Span dependencies) { return command_buffer->CreateChildCommand( - se::CommandBuffer::ChildCommandType::kCloned, *nested_cmd_ptr, + se::CommandBuffer::ChildCommandType::kCloned, *nested_cmd, dependencies); }, [&](const se::CommandBuffer::Command* command) { return command_buffer->UpdateChildCommand( - se::CommandBuffer::ChildCommandType::kCloned, command, - *nested_cmd_ptr); + se::CommandBuffer::ChildCommandType::kCloned, command, *nested_cmd); }); } @@ -2294,8 +2206,7 @@ absl::StatusOr AllReduceCmd::Record( [&](se::Stream* stream) { return RunAllReduce(reduction_kind_, device_buffers, *stream, *comm, config().use_symmetric_buffer); - }, - clique_key); + }); } CommandBufferCmd::BufferUseVector AllReduceCmd::buffers() const { @@ -2363,8 +2274,7 @@ absl::StatusOr ReduceScatterCmd::Record( return RunReduceScatter( reduction_kind_, device_buffers, *stream, *comm, config().use_symmetric_buffer); - }, - clique_key); + }); } CommandBufferCmd::BufferUseVector ReduceScatterCmd::buffers() const { @@ -2433,8 +2343,7 @@ absl::StatusOr AllToAllCmd::Record( [&](se::Stream* stream) { return RunAllToAll(has_split_dimension_, device_buffers, *stream, *comm, config().use_symmetric_buffer); - }, - clique_key); + }); } CommandBufferCmd::BufferUseVector AllToAllCmd::buffers() const { @@ -2499,8 +2408,7 @@ absl::StatusOr AllGatherCmd::Record( [&](se::Stream* stream) { return RunAllGather(device_buffers, *stream, *comm, config().use_symmetric_buffer); - }, - clique_key); + }); } CommandBufferCmd::BufferUseVector AllGatherCmd::buffers() const { @@ -2565,8 +2473,7 @@ CollectiveBroadcastCmd::Record(const Thunk::ExecuteParams& execute_params, execute_params, record_params, std::move(record_action), command_buffer, [&](se::Stream* stream) { return RunCollectiveBroadcast(device_buffers, *stream, *comm); - }, - clique_key); + }); } CommandBufferCmd::BufferUseVector CollectiveBroadcastCmd::buffers() const { @@ -2648,8 +2555,7 @@ absl::StatusOr CollectivePermuteCmd::Record( /*use_memcpy=*/false, /*recv_ptr_map=*/nullptr, use_symmetric_buffer); - }, - clique_key); + }); } CommandBufferCmd::BufferUseVector CollectivePermuteCmd::buffers() const { diff --git a/xla/backends/gpu/runtime/command_buffer_cmd.h b/xla/backends/gpu/runtime/command_buffer_cmd.h index ad8e9a3aab066..0b666803dcf05 100644 --- a/xla/backends/gpu/runtime/command_buffer_cmd.h +++ b/xla/backends/gpu/runtime/command_buffer_cmd.h @@ -37,7 +37,6 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" -#include "xla/backends/gpu/collectives/gpu_clique_key.h" #include "xla/backends/gpu/runtime/collective_permute_thunk.h" #include "xla/backends/gpu/runtime/collective_thunk.h" #include "xla/backends/gpu/runtime/copy_thunk.h" @@ -612,8 +611,6 @@ class TracedCommandBuffer : public CommandBufferCmd::State { se::Stream* stream, absl::FunctionRef trace, se::StreamPriority priority = se::StreamPriority::Default); - bool HasEntry(const BufferAllocations* buffer_allocation) const; - private: std::vector allocs_indices_; @@ -1120,16 +1117,13 @@ class CollectiveCmd : public CommandBufferCmd { bool requires_initialization() override { return true; } - bool force_update() override { return true; } - bool IsNestedCommandBuffer() const final { return true; } absl::StatusOr RecordTracedCommand( const Thunk::ExecuteParams& execute_params, const RecordParams& record_params, RecordAction record_action, se::CommandBuffer* command_buffer, - absl::FunctionRef trace, - const GpuCliqueKey& clique_key); + absl::FunctionRef trace); bool IsAsync() const { return async_events_ != nullptr; } std::shared_ptr async_events() const { diff --git a/xla/backends/gpu/runtime/command_buffer_cmd_test.cc b/xla/backends/gpu/runtime/command_buffer_cmd_test.cc index 4a60b448fbdb5..c995a45d181ca 100644 --- a/xla/backends/gpu/runtime/command_buffer_cmd_test.cc +++ b/xla/backends/gpu/runtime/command_buffer_cmd_test.cc @@ -903,120 +903,4 @@ static void BM_GetOrTraceCommandBuffer(benchmark::State& state) { BENCHMARK(BM_GetOrTraceCommandBuffer); -TEST(TracedCommandBuffer, HasEntry) { - se::StreamExecutor* executor = GpuExecutor(); - auto stream = executor->CreateStream().value(); - auto traced_cmd = FakeCmd(); - - BufferAllocation alloc0(/*index=*/0, /*size=*/1024, /*color=*/0); - BufferAllocation alloc1(/*index=*/1, /*size=*/1024, /*color=*/0); - - CommandBufferCmd::BufferUseVector buffers = { - BufferUse::Read(BufferAllocation::Slice(&alloc0, 0, 1024)), - BufferUse::Write(BufferAllocation::Slice(&alloc1, 0, 1024))}; - - TracedCommandBuffer traced_cmd_buffer(&traced_cmd, buffers, - /*capacity=*/4); - - se::DeviceAddressBase mem0(reinterpret_cast(0x01234567)); - se::DeviceAddressBase mem1(reinterpret_cast(0x12345670)); - se::DeviceAddressBase mem2(reinterpret_cast(0x23456701)); - - se::StreamExecutorMemoryAllocator allocator(executor); - BufferAllocations allocations({mem0, mem1}, 0, &allocator); - - // Empty cache should report no entry. - EXPECT_FALSE(traced_cmd_buffer.HasEntry(&allocations)); - - // Trace a command buffer for {mem0, mem1}. - se::DeviceAddress mem = executor->AllocateArray(16, 0); - auto trace = [&](se::Stream* stream) -> absl::Status { - TF_RETURN_IF_ERROR(stream->Memset32(&mem, 42, 16)); - return absl::OkStatus(); - }; - - TF_ASSERT_OK(traced_cmd_buffer - .GetOrTraceCommandBuffer(&allocations, executor, - stream.get(), trace) - .status()); - - // Now HasEntry should find {mem0, mem1}. - EXPECT_TRUE(traced_cmd_buffer.HasEntry(&allocations)); - - // Different addresses should not be found. - BufferAllocations different_allocs({mem0, mem2}, 0, &allocator); - EXPECT_FALSE(traced_cmd_buffer.HasEntry(&different_allocs)); - - // Trace for {mem0, mem2} and verify both entries exist. - TF_ASSERT_OK(traced_cmd_buffer - .GetOrTraceCommandBuffer(&different_allocs, executor, - stream.get(), trace) - .status()); - - EXPECT_TRUE(traced_cmd_buffer.HasEntry(&allocations)); - EXPECT_TRUE(traced_cmd_buffer.HasEntry(&different_allocs)); -} - -TEST(TracedCommandBuffer, HasEntryDoesNotModifyCache) { - se::StreamExecutor* executor = GpuExecutor(); - auto stream = executor->CreateStream().value(); - auto traced_cmd = FakeCmd(); - - BufferAllocation alloc0(/*index=*/0, /*size=*/1024, /*color=*/0); - - CommandBufferCmd::BufferUseVector buffers = { - BufferUse::Read(BufferAllocation::Slice(&alloc0, 0, 1024))}; - - TracedCommandBuffer traced_cmd_buffer(&traced_cmd, buffers, - /*capacity=*/2); - - se::DeviceAddressBase mem0(reinterpret_cast(0x01234567)); - se::DeviceAddressBase mem1(reinterpret_cast(0x12345670)); - se::DeviceAddressBase mem2(reinterpret_cast(0x23456701)); - - se::StreamExecutorMemoryAllocator allocator(executor); - - se::DeviceAddress mem = executor->AllocateArray(16, 0); - int64_t num_traces = 0; - auto trace = [&](se::Stream* stream) -> absl::Status { - TF_RETURN_IF_ERROR(stream->Memset32(&mem, 42, 16)); - num_traces++; - return absl::OkStatus(); - }; - - // Fill cache with {mem0} and {mem1}. - BufferAllocations allocs0({mem0}, 0, &allocator); - BufferAllocations allocs1({mem1}, 0, &allocator); - BufferAllocations allocs2({mem2}, 0, &allocator); - - TF_ASSERT_OK(traced_cmd_buffer - .GetOrTraceCommandBuffer(&allocs0, executor, stream.get(), - trace) - .status()); - TF_ASSERT_OK(traced_cmd_buffer - .GetOrTraceCommandBuffer(&allocs1, executor, stream.get(), - trace) - .status()); - EXPECT_EQ(num_traces, 2); - - // HasEntry should not affect the LRU order -- calling it many times - // for allocs0 should not evict allocs1. - for (int i = 0; i < 10; i++) { - EXPECT_TRUE(traced_cmd_buffer.HasEntry(&allocs0)); - EXPECT_TRUE(traced_cmd_buffer.HasEntry(&allocs1)); - EXPECT_FALSE(traced_cmd_buffer.HasEntry(&allocs2)); - } - - // Both should still be in cache (no re-trace needed). - TF_ASSERT_OK(traced_cmd_buffer - .GetOrTraceCommandBuffer(&allocs0, executor, stream.get(), - trace) - .status()); - TF_ASSERT_OK(traced_cmd_buffer - .GetOrTraceCommandBuffer(&allocs1, executor, stream.get(), - trace) - .status()); - EXPECT_EQ(num_traces, 2); -} - } // namespace xla::gpu