2 changes: 2 additions & 0 deletions xla/backends/gpu/runtime/BUILD
@@ -3896,6 +3896,8 @@ cc_library(
"//xla:util",
"//xla/ffi:ffi_api",
"//xla/hlo/ir:hlo",
"//xla/runtime:buffer_use",
"//xla/service:buffer_assignment",
"//xla/stream_executor:device_description",
"//xla/stream_executor:semantic_version",
"//xla/tsl/platform:errors",
126 changes: 126 additions & 0 deletions xla/backends/gpu/runtime/command_buffer_conversion_pass.cc
@@ -49,6 +49,8 @@ limitations under the License.
#include "xla/backends/gpu/runtime/while_thunk.h"
#include "xla/ffi/ffi_api.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/runtime/buffer_use.h"
#include "xla/service/buffer_assignment.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/semantic_version.h"
#include "xla/tsl/platform/errors.h"
@@ -477,6 +479,71 @@ std::string CommandBufferConversionPass::CommandBufferConfig::ToString() const {
return absl::StrCat("enabled_commands: [", cmd_names, "]");
}

// Configuration for the live-out alias-split logic in
// CommandBufferConversionPass::Run().
//
// XLA_GPU_CMDBUF_SPLIT_ON_LIVE_OUT_ALIAS = "1" (default) | "0"
//
// When enabled, the conversion pass tracks the set of `maybe_live_out`
// BufferAllocation::Index values that the current command-buffer chunk
// has already written. Before adding a new thunk to the chunk, if any
// of the thunk's write targets would re-write a `maybe_live_out` index
// already in the set, the current chunk is flushed first and the new
// thunk starts a fresh chunk. Within each command buffer, any given
// `maybe_live_out` allocation is then written by at most one thunk,
// preserving HLO last-writer-wins semantics for output allocations
// across command-buffer boundaries.
//
// Background: HLO buffer assignment may alias multiple writers to the
// same allocation when that allocation is `maybe_live_out`, because
// under the serial HLO schedule only the last writer's value is
// observable. When all such writers are captured into a single
// command buffer, the buffer holds the aliased virtual address for
// the entire graph duration, which can let consumers observe
// intermediate alias content instead of the HLO-defined last-writer
// value. Splitting the chunk inserts a graph boundary between
// writers so each output allocation has at most one writer per graph,
// restoring the expected publication semantics through the standard
// inter-graph stream barrier.
class LiveOutAliasSplit {
public:
static const LiveOutAliasSplit& Get() {
static LiveOutAliasSplit* f = new LiveOutAliasSplit();
return *f;
}
bool enabled() const { return enabled_; }

private:
LiveOutAliasSplit() {
const char* e = std::getenv("XLA_GPU_CMDBUF_SPLIT_ON_LIVE_OUT_ALIAS");

nit: std::getenv requires <cstdlib>, which doesn't appear in this file's includes. It likely compiles through transitive headers, but an explicit include would be more correct and resilient.

Also — the rest of the command buffer configuration in this file uses DebugOptions proto fields (xla_gpu_enable_command_buffer, xla_gpu_command_buffer_scheduling_mode, etc.). Using a raw env var instead makes this flag invisible to XLA's flag introspection, documentation, and testing infrastructure. If this is a temporary escape hatch before upstreaming, that's understandable, but ideally this would become a DebugOptions field for consistency.
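
The parsing rule used by the constructor can be isolated and tested on its own, with the `<cstdlib>` include made explicit. This is a minimal sketch, not the PR's code: the helper names `ParseDefaultOnFlag` and `LiveOutAliasSplitEnabledFromEnv` are hypothetical, and only the environment variable name comes from the diff.

```cpp
#include <cstdlib>  // std::getenv

// Hypothetical helper, not part of the PR: applies the same rule as the
// LiveOutAliasSplit constructor to a raw environment value.
//   nullptr or ""        -> enabled (the default)
//   first character '0'  -> disabled
//   anything else        -> enabled
bool ParseDefaultOnFlag(const char* value) {
  if (value == nullptr || value[0] == '\0') return true;
  return value[0] != '0';
}

// Reads the environment variable named in the PR and applies the rule above.
bool LiveOutAliasSplitEnabledFromEnv() {
  return ParseDefaultOnFlag(
      std::getenv("XLA_GPU_CMDBUF_SPLIT_ON_LIVE_OUT_ALIAS"));
}
```

Because only the first character is inspected, a value such as "01" disables the feature while "false" or "off" leaves it enabled. A DebugOptions boolean field would avoid that ambiguity.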

if (e == nullptr || e[0] == '\0') {
// Default: enabled. Set the env to "0" to disable.
enabled_ = true;
return;
}
enabled_ = (e[0] != '0');
}
bool enabled_ = true;
};

// Returns the set of `maybe_live_out` BufferAllocation::Index values
// that `thunk` writes to. Returns an empty set when none of the
// thunk's writes target a `maybe_live_out` allocation. Used by the
// live-out alias-split logic in Run() to decide whether two thunks in
// the same command-buffer chunk would rewrite the same allocation.
absl::flat_hash_set<BufferAllocation::Index> CollectLiveOutWriteIndices(
const Thunk& thunk) {
absl::flat_hash_set<BufferAllocation::Index> out;
for (const BufferUse& use : thunk.buffer_uses()) {
if (use.access() == BufferUse::MemoryAccess::kRead) continue;
const BufferAllocation* alloc = use.slice().allocation();
if (alloc == nullptr) continue;
if (!alloc->maybe_live_out()) continue;
out.insert(use.slice().index());
}
return out;
}

absl::StatusOr<bool> CommandBufferConversionPass::Run(
SequentialThunk* root_thunk, const DebugOptions& debug_options,
const HloModule* absl_nullable hlo_module,
Comment on lines +509 to 549

nit: Both LiveOutAliasSplit and CollectLiveOutWriteIndices are defined outside the anonymous namespace (which closes at line 471). All other file-local helpers in this file (GetCommandBufferConfig, GetCommandBufferCmdType, IsConvertible, CheckAsyncRegion, FlushCommandBuffer) are inside the anonymous namespace. These two new declarations should be placed in the anonymous namespace to follow the file's convention and avoid giving CollectLiveOutWriteIndices external linkage in the xla::gpu namespace.
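
The placement being asked for can be sketched in isolation. The namespace `demo` and both function names are hypothetical and stand in for `xla::gpu` and the PR's helpers. The sketch only shows where a file-local helper sits relative to a function that needs external linkage.

```cpp
#include <set>
#include <vector>

namespace demo {
namespace {

// File-local helper: inside the anonymous namespace it has internal
// linkage, so no other translation unit can link against it or
// collide with it by name.
std::set<int> CollectEvenValues(const std::vector<int>& values) {
  std::set<int> out;
  for (int v : values) {
    if (v % 2 == 0) out.insert(v);
  }
  return out;
}

}  // namespace

// Public entry point with external linkage. It can call the file-local
// helper because both live in the same translation unit.
int CountDistinctEven(const std::vector<int>& values) {
  return static_cast<int>(CollectEvenValues(values).size());
}

}  // namespace demo
```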

@@ -494,16 +561,60 @@ absl::StatusOr<bool> CommandBufferConversionPass::Run(
debug_options.xla_gpu_command_buffer_scheduling_mode()));

bool changed = false;
const bool live_out_split_enabled = LiveOutAliasSplit::Get().enabled();

std::vector<std::unique_ptr<Thunk>> current_command_buffer_thunks;
// Tracks the `BufferAllocation::Index` of every `maybe_live_out`
// allocation written by some thunk already added to
// `current_command_buffer_thunks`. Reset whenever a chunk is
// flushed. See LiveOutAliasSplit for the rationale.
absl::flat_hash_set<BufferAllocation::Index> chunk_live_out_writes;
std::vector<std::unique_ptr<Thunk>> new_thunks;

auto flush_command_buffer = [&]() -> absl::Status {
chunk_live_out_writes.clear();
return FlushCommandBuffer(synchronization_mode, debug_options,
current_command_buffer_thunks, new_thunks,
changed);
};

// Returns true iff adding `incoming` (one thunk or a whole async
// region) to the current chunk would cause a `maybe_live_out`
// allocation to be written by both an earlier thunk in the chunk
// and one of these incoming thunks. When that would happen,
// callers must flush the current chunk first so the new thunks land
// in a fresh chunk; the publication of the last writer's value is
// then provided by the inter-graph stream barrier rather than by
// intra-graph ordering of aliased writes.
auto would_alias_live_out =
[&](absl::Span<const std::unique_ptr<Thunk>> incoming) {
if (!live_out_split_enabled) return false;
for (const auto& th : incoming) {
auto idxs = CollectLiveOutWriteIndices(*th);
for (BufferAllocation::Index i : idxs) {
if (chunk_live_out_writes.contains(i)) {
VLOG(2) << "Splitting command buffer at thunk '"
<< th->thunk_info().profile_annotation
<< "': would re-write maybe_live_out allocation #"
<< i << " already written earlier in this chunk";
return true;
}
}
}
return false;
};
Comment on lines +589 to +605

Note: CollectLiveOutWriteIndices is called twice per thunk — once in would_alias_live_out and again in record_chunk_writes. Each call iterates all buffer_uses() and allocates a new flat_hash_set. The result could be computed once and reused:

Suggested change
auto would_alias_live_out =
[&](absl::Span<const std::unique_ptr<Thunk>> incoming,
std::vector<absl::flat_hash_set<BufferAllocation::Index>>* out_idxs) {
out_idxs->clear();
if (!live_out_split_enabled) return false;
for (const auto& th : incoming) {
out_idxs->push_back(CollectLiveOutWriteIndices(*th));
for (BufferAllocation::Index i : out_idxs->back()) {
if (chunk_live_out_writes.contains(i)) {
VLOG(2) << "Splitting command buffer at thunk '"
<< th->thunk_info().profile_annotation
<< "': would re-write maybe_live_out allocation #"
<< i << " already written earlier in this chunk";
return true;
}
}
}
return false;
};

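Then record_chunk_writes can consume the already-computed sets instead of recomputing them. Note that on the early `return true` path, `out_idxs` holds sets only for the thunks examined so far, so the caller would still have to compute the rest, or the loop would have to finish collecting before returning. This is a minor efficiency point, unlikely to be a bottleneck since command buffer conversion isn't on the hot path.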
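
The compute-once idea can be sketched on a simplified model. Everything here is a hypothetical stand-in rather than XLA code: `FakeThunk` carries only the live-out indices it writes, `Collect` plays the role of `CollectLiveOutWriteIndices`, and clearing the set models `flush_command_buffer()`. The sketch collects every write set before checking for a conflict, so the same sets can be reused for recording even when a split is needed.

```cpp
#include <unordered_set>
#include <vector>

using AllocIndex = long;  // stand-in for BufferAllocation::Index
using WriteSet = std::unordered_set<AllocIndex>;

// Hypothetical stand-in for a thunk: only the live-out indices it writes.
struct FakeThunk {
  std::vector<AllocIndex> live_out_writes;
};

struct ChunkTracker {
  WriteSet chunk_writes;  // live-out indices written by the current chunk
  int collect_calls = 0;  // how many times a write set was computed
  int flush_count = 0;    // how many chunk boundaries were inserted

  // Stand-in for CollectLiveOutWriteIndices.
  WriteSet Collect(const FakeThunk& thunk) {
    ++collect_calls;
    return WriteSet(thunk.live_out_writes.begin(),
                    thunk.live_out_writes.end());
  }

  // Adds `incoming` to the current chunk as a unit. Returns true if the
  // chunk had to be flushed first.
  bool Add(const std::vector<FakeThunk>& incoming) {
    // Compute every write set exactly once, before any conflict check.
    std::vector<WriteSet> sets;
    sets.reserve(incoming.size());
    for (const FakeThunk& thunk : incoming) sets.push_back(Collect(thunk));

    bool conflict = false;
    for (const WriteSet& writes : sets) {
      for (AllocIndex idx : writes) {
        if (chunk_writes.count(idx) > 0) conflict = true;
      }
    }
    if (conflict) {
      chunk_writes.clear();  // models flush_command_buffer()
      ++flush_count;
    }
    // Reuse the same sets to record the writes of the (possibly new) chunk.
    for (const WriteSet& writes : sets) {
      chunk_writes.insert(writes.begin(), writes.end());
    }
    return conflict;
  }
};
```

The trade-off is that the conflict check no longer exits early. For a pass that runs once per module that is likely acceptable, but it is a design choice rather than something the PR prescribes.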


auto record_chunk_writes =
[&](absl::Span<const std::unique_ptr<Thunk>> incoming) {
if (!live_out_split_enabled) return;
for (const auto& th : incoming) {
auto idxs = CollectLiveOutWriteIndices(*th);
for (BufferAllocation::Index i : idxs) {
chunk_live_out_writes.insert(i);
}
}
};

auto& original_thunks = root_thunk->thunks();

for (size_t i = 0; i < original_thunks.size(); ++i) {
@@ -517,16 +628,31 @@ absl::StatusOr<bool> CommandBufferConversionPass::Run(
absl::MakeSpan(original_thunks).subspan(i), config);

if (!region.empty()) {
// The live-out alias split is applied to async regions atomically:
// either the whole region triggers a split, or none of it does.
// Async start/done pairs must remain in the same command buffer.
if (would_alias_live_out(region)) {
TF_RETURN_IF_ERROR(flush_command_buffer());
}
// If a valid region is found, add the whole region to the current
// sequence and continue processing.
i += region.size() - 1;
record_chunk_writes(region);
Comment on lines +635 to +640

The would_alias_live_out lambda checks incoming thunks only against the existing chunk_live_out_writes set, not against each other within the incoming span. If two thunks within the same async region both write to the same maybe_live_out allocation, that wouldn't be detected. I believe this is intentional (async start/done pairs must stay in the same command buffer, so splitting within a region isn't possible), but a brief comment here would help future readers understand this is by design rather than an oversight.
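
The behaviour described above can be reproduced on a simplified model. This is a sketch under stated assumptions: write sets are plain `std::unordered_set<long>` values rather than results of `CollectLiveOutWriteIndices`, and `HasIntraGroupAlias` is a hypothetical diagnostic that does not exist in the PR.

```cpp
#include <unordered_set>
#include <vector>

using AllocIndex = long;  // stand-in for BufferAllocation::Index
using WriteSet = std::unordered_set<AllocIndex>;

// Mirrors the shape of the PR's check: each incoming write set is compared
// only against what the current chunk has already written.
bool WouldAliasLiveOut(const WriteSet& chunk_writes,
                       const std::vector<WriteSet>& incoming) {
  for (const WriteSet& writes : incoming) {
    for (AllocIndex idx : writes) {
      if (chunk_writes.count(idx) > 0) return true;
    }
  }
  return false;
}

// Hypothetical diagnostic, not in the PR: reports whether two different
// members of `incoming` write the same index. Each WriteSet holds unique
// values, so a failed insert means an earlier member already wrote it.
bool HasIntraGroupAlias(const std::vector<WriteSet>& incoming) {
  WriteSet seen;
  for (const WriteSet& writes : incoming) {
    for (AllocIndex idx : writes) {
      if (!seen.insert(idx).second) return true;
    }
  }
  return false;
}
```

With an empty chunk and a region whose two members both write index 5, `WouldAliasLiveOut` returns false while `HasIntraGroupAlias` returns true. That is the case the comment asks to document as intentional.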

absl::c_move(region, std::back_inserter(current_command_buffer_thunks));
continue;
}
} else if (IsConvertible(*thunk.get(), config) && !thunk->IsAsyncDone()) {
// If this thunk would re-write a `maybe_live_out` allocation
// already written earlier in the current chunk, close the chunk
// first. See LiveOutAliasSplit for the rationale.
auto incoming_view = absl::MakeConstSpan(&thunk, 1);

nit: absl::MakeConstSpan(&thunk, 1) works because thunk is a reference to an element stored in original_thunks, so &thunk is the address of a live object for the duration of the call. A one-element span does not itself require contiguous storage, but it does require that thunk remain a reference to the stored element rather than a copy or proxy. A comment noting that assumption would help guard against future refactors of the loop or the container type.
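
A one-element view over an element accessed by reference can be sketched without Abseil. `ConstView` is a hypothetical, minimal stand-in for `absl::Span<const T>`, and `MakeSingleElementView` and `SumPointees` are illustrative names. The view is valid only while the referenced element is alive and has not been relocated.

```cpp
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical minimal stand-in for absl::Span<const T>.
template <typename T>
struct ConstView {
  const T* data;
  std::size_t size;
};

// Builds a one-element view over `element`. The caller must pass a
// reference to the stored object, not a temporary copy, and must keep that
// object alive and in place while the view is in use.
template <typename T>
ConstView<T> MakeSingleElementView(const T& element) {
  return ConstView<T>{&element, 1};
}

// Example consumer: sums the ints owned by the viewed unique_ptrs.
int SumPointees(ConstView<std::unique_ptr<int>> view) {
  int total = 0;
  for (std::size_t i = 0; i < view.size; ++i) total += *view.data[i];
  return total;
}
```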

if (would_alias_live_out(incoming_view)) {
TF_RETURN_IF_ERROR(flush_command_buffer());
}
// Check if thunk is convertible and not an async done: async done thunks
// can only be added to the current_command_buffer_thunks as part of a
// valid async region.
record_chunk_writes(incoming_view);
current_command_buffer_thunks.push_back(std::move(thunk));
continue;
}