Skip to content

Commit 9b80420

Browse files
authored
Merge branch 'main' into moe-no-float
2 parents 08777af + f61350b commit 9b80420

32 files changed

Lines changed: 2051 additions & 522 deletions

backends/nxp/backend/neutron_converter_manager.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,8 @@ def convert(
9292
)
9393
cctx.compilationOpts.fetchConstantsToSRAM = fetch_constants_to_sram
9494
cctx.compilationOpts.dumpKernelSelectionCode = self.dump_kernel_selection_code
95-
cctx.compilationOpts.useNewFlowNeutronC = use_new_flow_neutron_c
95+
if hasattr(cctx.compilationOpts, "useNewFlowNeutronC"):
96+
cctx.compilationOpts.useNewFlowNeutronC = use_new_flow_neutron_c
9697

9798
# Try to use multiprocessing for isolation, but fall back to direct execution
9899
# if the environment doesn't support it (e.g., in sandcastle/build environments)

backends/nxp/tests/test_neutron_converter_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
import torch
99
from eiq_neutron_sdk.neutron_converter.neutron_converter import CompilationContext
10-
1110
from executorch import exir
1211
from executorch.backends.nxp.backend.edge_program_converter import (
1312
EdgeProgramToIRConverter,
@@ -72,4 +71,5 @@ def test_neutron_converter_with_experimental_mlir_flow(mocker):
7271

7372
compilation_context = process_spy.call_args.kwargs["args"][2]
7473
assert isinstance(compilation_context, CompilationContext)
75-
assert compilation_context.compilationOpts.useNewFlowNeutronC
74+
if hasattr(compilation_context.compilationOpts, "useNewFlowNeutronC"):
75+
assert compilation_context.compilationOpts.useNewFlowNeutronC

backends/vulkan/_passes/tag_memory_meta_pass.py

Lines changed: 81 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
import logging
88
import operator
9+
10+
from collections import deque
911
from typing import Any
1012

1113
import executorch.backends.vulkan.utils as utils
@@ -332,81 +334,111 @@ def trace_node_users_to_constrain_repset( # noqa: C901
332334
search_depth: list[int] | None = None,
333335
) -> utils.TensorRepSet:
334336
"""
335-
For an ambiguous repset, try to constrain the repset by tracing the required
336-
repsets of the users of `origin_node`. The idea is to try to find a representation
337-
that can be used the longest without needing user nodes to insert a transition
338-
for its arguments.
337+
BFS over downstream users to constrain an ambiguous repset. Explores all
338+
immediate users at each level before going deeper, so that nearby constrained
339+
ops (e.g. linear requiring width_packed) are discovered before the search
340+
budget is spent on a single deep branch.
339341
"""
340-
# Optionally limit the total number of nodes explored to improve export
341-
# time. search_depth is a mutable list so that all branches of a fan-out
342-
# share a single counter, preventing exponential blowup.
343342
if self.max_trace_search_depth is not None:
344343
if search_depth is None:
345344
search_depth = [self.max_trace_search_depth]
346-
search_depth[0] -= 1
347-
if search_depth[0] <= 0:
345+
346+
queue: deque[torch.fx.Node] = deque()
347+
queue.append(origin_node)
348+
349+
while queue:
350+
if repset.is_constrained():
348351
return repset
349352

350-
users_to_trace = origin_node.users
353+
if self.max_trace_search_depth is not None:
354+
search_depth[0] -= 1
355+
if search_depth[0] <= 0:
356+
return repset
357+
358+
node = queue.popleft()
359+
360+
users_to_trace = node.users
361+
362+
sync_outs_repr = True
363+
if self.is_valid_op_node(node):
364+
sync_outs_repr = self.get_node_cached_repsets(node).sync_outs_repr
351365

352-
sync_outs_repr = True
353-
if self.is_valid_op_node(origin_node):
354-
sync_outs_repr = self.get_node_cached_repsets(origin_node).sync_outs_repr
366+
if utils.num_tensors_in_node(node) > 1 and not sync_outs_repr:
367+
users_to_trace = []
368+
for usage_node in node.users:
369+
if (
370+
usage_node.target == operator.getitem
371+
and usage_node.args[1] == 1
372+
):
373+
users_to_trace.append(usage_node)
355374

356-
if utils.num_tensors_in_node(origin_node) > 1 and not sync_outs_repr:
357-
users_to_trace = []
358-
for usage_node in origin_node.users:
359-
if usage_node.target == operator.getitem and usage_node.args[1] == 1:
360-
users_to_trace.append(usage_node)
375+
for usage_node in users_to_trace:
376+
if repset.is_constrained():
377+
return repset
361378

362-
for usage_node in users_to_trace:
363-
arg_i_in_user = None
364-
for i in range(len(usage_node.args)):
365-
if origin_node == usage_node.args[i]:
366-
arg_i_in_user = i
367-
break
379+
arg_i_in_user = None
380+
for i in range(len(usage_node.args)):
381+
if node == usage_node.args[i]:
382+
arg_i_in_user = i
383+
break
368384

369-
if arg_i_in_user is not None:
370-
repset = self.constrain_repset_with_user(
371-
usage_node, arg_i_in_user, repset, search_depth
385+
if arg_i_in_user is None:
386+
continue
387+
388+
if not self.is_valid_op_node(usage_node):
389+
continue
390+
391+
cur_node_repsets = self.get_node_cached_repsets(usage_node)
392+
req_arg_repset = cur_node_repsets.get_arg_repset(arg_i_in_user)
393+
394+
if not req_arg_repset.any_in_common(repset):
395+
continue
396+
397+
repset = repset.make_intersect(req_arg_repset)
398+
399+
repset_propagates_to_output = (
400+
cur_node_repsets.sync_primary_io_repr
401+
and (
402+
cur_node_repsets.sync_args_repr
403+
or arg_i_in_user == cur_node_repsets.primary_arg_idx
404+
)
372405
)
373406

374-
if repset.is_constrained():
375-
return repset
407+
if repset_propagates_to_output:
408+
queue.append(usage_node)
376409

377410
return repset
378411

379412
def constrain_op_arg_repset(self, arg_i: int, op_repsets: utils.OpRepSets) -> None:
380413
"""
381414
Attempts to constrain the repset of the argument at index `arg_i` of the op
382-
associated with `op_repsets`. Does this with two stages:
383-
384-
1. First, account for any existing representation that has already been determined
385-
for the argument. If no existing representation has been determined, then use
386-
the output repset of the operator that produces the argument.
387-
2. Then, try to trace through the users of the argument to find a representation
388-
that can be used for as long as possible without needing a transition.
415+
associated with `op_repsets`. Prefers downstream consumers' layout requirements
416+
over the upstream source's existing layout, falling back to the source only when
417+
downstream tracing does not fully constrain the repset.
389418
"""
390-
# If forcing fp16, then try to use texture storage whenever possible. This is
391-
# a temporary stopgap measure until all buffer implementations properly account
392-
# for potential overflow of fp16 representation range when doing math in fp16.
393419
if self.force_fp16:
394420
op_repsets.try_constrain_with_arg_repset(arg_i, utils.ANY_TEXTURE)
395421

396-
arg_source_repset = self.get_arg_tensor_source_repset(op_repsets.op_node, arg_i)
397-
op_repsets.try_constrain_with_arg_repset(arg_i, arg_source_repset)
398-
399-
arg_repset = op_repsets.get_arg_repset(arg_i)
400-
if arg_repset.is_constrained():
401-
return
402-
422+
# First, trace downstream users to discover what layout they prefer.
403423
arg_node = op_repsets.op_node.args[arg_i]
404-
405424
if isinstance(arg_node, list):
406425
arg_node = arg_node[0]
407426

408-
arg_repset = self.trace_node_users_to_constrain_repset(arg_node, arg_repset)
409-
op_repsets.try_constrain_with_arg_repset(arg_i, arg_repset)
427+
arg_repset = op_repsets.get_arg_repset(arg_i)
428+
if not arg_repset.is_constrained():
429+
downstream_repset = self.trace_node_users_to_constrain_repset(
430+
arg_node, arg_repset
431+
)
432+
op_repsets.try_constrain_with_arg_repset(arg_i, downstream_repset)
433+
434+
# Fall back to the upstream source's existing layout only if downstream
435+
# tracing did not fully constrain the repset.
436+
arg_repset = op_repsets.get_arg_repset(arg_i)
437+
if not arg_repset.is_constrained():
438+
arg_source_repset = self.get_arg_tensor_source_repset(
439+
op_repsets.op_node, arg_i
440+
)
441+
op_repsets.try_constrain_with_arg_repset(arg_i, arg_source_repset)
410442

411443
def constrain_op_out_repset(self, op_repsets: utils.OpRepSets) -> None:
412444
"""

backends/vulkan/custom_ops_lib.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -828,6 +828,55 @@ def apply_rotary_emb_hf_impl(
828828
lib.impl(name, apply_rotary_emb_hf_impl, "CompositeExplicitAutograd")
829829
apply_rotary_emb_hf_op = getattr(getattr(torch.ops, namespace), name)
830830

831+
##################################
832+
## apply_rotary_emb_interleaved ##
833+
##################################
834+
835+
836+
def apply_rotary_emb_interleaved_impl(
    x: torch.Tensor, freqs_cis: torch.Tensor
) -> torch.Tensor:
    """
    Reference kernel for EdgeTAM's pair-interleaved complex-number RoPE.

    Args:
        x: [B, N, C] tensor with (real, imag) pairs interleaved along C.
        freqs_cis: any rank whose flattened layout is [N, C]. Commonly 2D
            [N, C] or 4D [1, N, C/2, 2] from
            `torch.view_as_real(...).unsqueeze(0)`. The (cos, sin) pairs are
            interleaved along the innermost axis in the flattened view.

    Returns:
        A [B, N, C] tensor with the same dtype as `x`, where for every pair k:
            out[2k]   = x[2k] * cos[k] - x[2k+1] * sin[k]
            out[2k+1] = x[2k] * sin[k] + x[2k+1] * cos[k]
        with (cos, sin) = freqs_cis.reshape(N, C // 2, 2).
    """
    B, N, C = x.shape
    a_real, a_imag = x.view(B, N, C // 2, 2).unbind(-1)
    # Use reshape so callers may pass freqs_cis at any rank.
    cs = freqs_cis.reshape(N, C // 2, 2)
    b_real, b_imag = cs.unbind(-1)
    out = torch.stack(
        (a_real * b_real - a_imag * b_imag, a_real * b_imag + a_imag * b_real),
        dim=-1,
    )
    # Type promotion may widen the result when x and freqs_cis have different
    # dtypes (e.g. fp16 x with fp32 freqs_cis). Cast back so the output dtype
    # is always x.dtype; this is a no-op when the dtypes already agree.
    return out.view(B, N, C).to(x.dtype)
860+
861+
862+
def apply_rotary_emb_interleaved_meta(
    x: torch.Tensor, freqs_cis: torch.Tensor
) -> torch.Tensor:
    """
    Shape-only Meta kernel: the result is an uninitialized tensor like `x`.

    Keeps the op opaque during torch.export (no inlining of view/reshape calls
    into the exported graph) and does not constrain the rank of freqs_cis —
    any shape with N * C elements is accepted by the Vulkan dispatcher.
    """
    # freqs_cis is intentionally not inspected here.
    result = torch.empty_like(x)
    return result
870+
871+
872+
name = "apply_rotary_emb_interleaved"
873+
lib.define(f"{name}(Tensor x, Tensor freqs_cis) -> Tensor")
874+
# CPU kernel preserves eager-mode reference semantics.
875+
lib.impl(name, apply_rotary_emb_interleaved_impl, "CPU")
876+
# Meta kernel keeps the op opaque in the exported graph.
877+
lib.impl(name, apply_rotary_emb_interleaved_meta, "Meta")
878+
apply_rotary_emb_interleaved_op = getattr(getattr(torch.ops, namespace), name)
879+
831880
########################
832881
## q8ta_add ##
833882
########################
@@ -960,6 +1009,34 @@ def select_as_symint_impl(x: torch.Tensor, dim: int, index: int):
9601009
lib.impl(name, select_as_symint_impl, "Meta")
9611010
select_as_symint_op = getattr(getattr(torch.ops, namespace), name)
9621011

1012+
##########
1013+
## sdpa ##
1014+
##########
1015+
1016+
1017+
def sdpa_impl(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
):
    """
    Reference scaled dot-product attention.

    Computes softmax(q @ k^T * scale + mask, dim=-1) @ v, where k^T swaps the
    last two dims of k.

    Args:
        q: query tensor; its last dim is contracted against the last dim of k.
        k: key tensor.
        v: value tensor.
        attn_mask: optional mask broadcastable to the attention scores. A
            floating-point mask is added to the scores. A boolean mask follows
            the torch.nn.functional.scaled_dot_product_attention convention:
            True keeps a position, False excludes it (score set to -inf).
        scale: multiplier applied to the raw scores. Defaults to
            1 / sqrt(q.size(-1)).

    Returns:
        The attention output, softmax(scores) @ v.
    """
    if scale is None:
        scale = 1.0 / (q.size(-1) ** 0.5)
    attn = torch.matmul(q, k.transpose(-2, -1)) * scale
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # Adding a bool mask would shift kept positions by +1.0 and leave
            # excluded positions untouched; mask them out explicitly instead.
            attn = attn.masked_fill(attn_mask.logical_not(), float("-inf"))
        else:
            attn = attn + attn_mask
    attn = torch.softmax(attn, dim=-1)
    return torch.matmul(attn, v)
1031+
1032+
1033+
name = "sdpa"
1034+
lib.define(
1035+
f"{name}(Tensor q, Tensor k, Tensor v, Tensor? attn_mask = None, float? scale = None) -> Tensor"
1036+
)
1037+
lib.impl(name, sdpa_impl, "CompositeExplicitAutograd")
1038+
sdpa_op = getattr(getattr(torch.ops, namespace), name)
1039+
9631040
################
9641041
## rms_norm ##
9651042
################

backends/vulkan/op_registry.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1071,6 +1071,20 @@ def register_sdpa_cpp_ops():
10711071
)
10721072

10731073

1074+
# =============================================================================
1075+
# SDPA.cpp (fused SDPA entry point)
1076+
# =============================================================================
1077+
1078+
1079+
@update_features("et_vk::sdpa")
def register_general_sdpa():
    """
    Build the feature description registered for the `et_vk::sdpa` op:
    contiguous inputs of any storage type, floating-point dtypes, resizable.
    """
    features = OpFeatures(
        inputs_storage=utils.CONTIGUOUS_ANY,
        inputs_dtypes=utils.FP_T,
        supports_resize=True,
    )
    return features
1086+
1087+
10741088
# =============================================================================
10751089
# RotaryEmbedding.cpp
10761090
# =============================================================================
@@ -1096,6 +1110,22 @@ def register_apply_rotary_emb_hf():
10961110
)
10971111

10981112

1113+
@update_features(exir_ops.edge.et_vk.apply_rotary_emb_interleaved.default)
def register_apply_rotary_emb_interleaved():
    """
    Build the feature description registered for
    `et_vk.apply_rotary_emb_interleaved`.
    """
    # freqs_cis is pinned to buffer storage so the shader can compute a
    # flat [N, C] linear address regardless of the tensor's declared rank
    # (callers commonly pass 4D [1, N, C/2, 2] without a preceding view).
    per_input_storage = [
        utils.CONTIGUOUS_ANY,  # x
        utils.CONTIGUOUS_BUFFER,  # freqs_cis
    ]
    features = OpFeatures(
        inputs_storage=per_input_storage,
        inputs_dtypes=utils.FP_T,
        supports_resize=True,
        supports_highdim=True,
    )
    return features
1127+
1128+
10991129
# =============================================================================
11001130
# Permute.cpp
11011131
# =============================================================================

0 commit comments

Comments
 (0)