Commit 3dd3158

Update on "Use caching allocator for runner (#15730)"
Summary: We observed that on iOS this improves perf by 6%, because the SDPA op makes temporary allocations. There is no significant difference on Android.

ghstack-source-id: 328001114
exported-using-ghexport

Reviewed By: navsud, derekdixu

Differential Revision: D86120038

[ghstack-poisoned]
2 parents 5b9bf5e + 42830ac commit 3dd3158

14 files changed

Lines changed: 226 additions & 65 deletions
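
Illustration only (not part of the commit, and not the ExecuTorch API): the idea behind a caching allocator is to keep freed blocks keyed by size so that repeated temporary allocations, such as the scratch buffers SDPA requests on every forward call, are served from a cache instead of hitting malloc/free each time. A minimal Python sketch of that concept:

from collections import defaultdict

class CachingAllocator:
    """Hypothetical sketch: reuse freed buffers of the same size."""

    def __init__(self):
        self._free = defaultdict(list)  # size -> list of reusable buffers

    def allocate(self, size: int) -> bytearray:
        # Reuse a cached buffer if one of this size was freed earlier.
        if self._free[size]:
            return self._free[size].pop()
        return bytearray(size)

    def free(self, buf: bytearray) -> None:
        # Keep the buffer for the next allocation of the same size.
        self._free[len(buf)].append(buf)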

backends/cadence/aot/memory_constraints.py

Lines changed: 44 additions & 9 deletions
@@ -452,6 +452,45 @@ def is_cat_along_outermost_dim(
                 return False
         return True
 
+    def _has_duplicate_resolved_sources(
+        self, cat_tensors: Sequence[torch.fx.Node]
+    ) -> bool:
+        """Return True if two cat inputs resolve to the same underlying tensor."""
+        if len(cat_tensors) != len(set(cat_tensors)):
+            return True
+        resolved_sources = set()
+        for arg in cat_tensors:
+            resolved = arg
+            while (
+                info := self.constraint.get_relative_placement_source(resolved)
+            ) is not None:
+                if self.constraint.is_alias_of(info.source, resolved):
+                    resolved = info.source
+                else:
+                    break
+            if id(resolved) in resolved_sources:
+                return True
+            resolved_sources.add(id(resolved))
+        return False
+
+    def _has_unaligned_cat_tensors(
+        self,
+        graph: torch.fx.Graph,
+        node: torch.fx.Node,
+        cat_tensors: Sequence[torch.fx.Node],
+    ) -> bool:
+        """Return True if any non-placeholder cat tensor has a misaligned offset."""
+        if is_node_in_flattened_output(graph, node):
+            return False
+        expected_alignment = 8
+        relative_offsets = get_relative_offsets_of_cat_tensors(cat_tensors)
+        for idx, arg in enumerate(cat_tensors):
+            if not (arg.op == "placeholder") and (
+                relative_offsets[idx] & (expected_alignment - 1) != 0
+            ):
+                return True
+        return False
+
     # If A = cat(B, C), and the concatenation is along the outermost dimension, then
     # we can optimize away this cat operation if (1) B and C are placed contiguously,
     # and (2) the absolute memory location of tensor A is the same as B. This function
@@ -486,21 +525,17 @@ def is_removable_cat_op(
             return False
         # If the same tensor appears multiple times in the cat inputs,
         # we cannot place it at multiple different offsets relative to the output.
-        if len(cat_tensors) != len(set(cat_tensors)):
+        # Also check resolved sources: two different alias nodes may resolve to
+        # the same underlying tensor, which can't be at two offsets.
+        if self._has_duplicate_resolved_sources(cat_tensors):
             return False
 
         # Many ops in HiFi require the input to be aligned to 8-byte boundary.
         # If the cat is not the graph's output, then ensure that the relative
        # offset of any concatenated non-placeholder tensor is a multiple of
         # 8 bytes,
-        if not is_node_in_flattened_output(graph_module.graph, node):
-            expected_alignment = 8
-            relative_offsets = get_relative_offsets_of_cat_tensors(cat_tensors)
-            for idx, arg in enumerate(cat_tensors):
-                if not (arg.op == "placeholder") and (
-                    relative_offsets[idx] & (expected_alignment - 1) != 0
-                ):
-                    return False
+        if self._has_unaligned_cat_tensors(graph_module.graph, node, cat_tensors):
+            return False
 
         return True

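A small sketch (not part of the commit) of the power-of-two alignment test that _has_unaligned_cat_tensors applies to each relative offset: offset & (alignment - 1) is nonzero exactly when the offset is not a multiple of the alignment.

def is_misaligned(offset: int, alignment: int = 8) -> bool:
    # Valid only for power-of-two alignments: the low bits must all be zero.
    return offset & (alignment - 1) != 0

assert is_misaligned(12)        # 12 is not a multiple of 8 -> cat is not removable
assert not is_misaligned(16)    # 16 is a multiple of 8 -> still removable
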
backends/cadence/hifi/operators/op_permute_copy.cpp

Lines changed: 2 additions & 18 deletions
@@ -73,8 +73,7 @@ Tensor& permute_copy_out(
 
   bool optimized = false;
 
-  if (out.scalar_type() == ScalarType::Float ||
-      out.scalar_type() == ScalarType::Char ||
+  if (out.scalar_type() == ScalarType::Char ||
       out.scalar_type() == ScalarType::Byte)
     optimized = true;
 
@@ -101,22 +100,7 @@ Tensor& permute_copy_out(
     p_permute_vec[i] = dims[i];
   }
 
-  if (in_type == ScalarType::Float) {
-    WORD32* p_inp = (WORD32*)in.const_data_ptr<float>();
-    WORD32* p_out = (WORD32*)out.mutable_data_ptr<float>();
-
-    WORD32 ret_val = xa_nn_transpose_32_32(
-        p_out,
-        p_out_shape,
-        p_inp,
-        p_inp_shape,
-        p_permute_vec,
-        num_out_dims,
-        num_inp_dims);
-
-    ET_KERNEL_CHECK(ctx, ret_val == 0, Internal, out);
-
-  } else if (in_type == ScalarType::Char) {
+  if (in_type == ScalarType::Char) {
     WORD8* p_inp = (WORD8*)in.const_data_ptr<char>();
     WORD8* p_out = (WORD8*)out.mutable_data_ptr<char>();

backends/cadence/hifi/operators/op_softmax.cpp

Lines changed: 3 additions & 0 deletions
@@ -68,6 +68,9 @@ Tensor& _softmax_out(
   if (in.dim() > kNnlibMaxDim)
     optimized = false;
 
+  if (dim < in.dim() - 1)
+    optimized = false;
+
   if (optimized) {
     int* p_inp = (int*)in.const_data_ptr<float>();
     int* out_data = (int*)out.mutable_data_ptr<float>();

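For context, an illustration only (assuming dim has already been normalized to a non-negative index): the NNLib path handles softmax over the innermost dimension only, so any other dim now falls through to the portable kernel.

import torch

x = torch.randn(2, 3, 4)
torch.nn.functional.softmax(x, dim=2)  # dim == in.dim() - 1: eligible for the fast path
torch.nn.functional.softmax(x, dim=1)  # dim < in.dim() - 1: portable fallback
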
backends/cadence/hifi/operators/op_where.cpp

Lines changed: 3 additions & 0 deletions
@@ -81,6 +81,9 @@ Tensor& where_self_out(
   if ((broadcast == 1) && (max_dim > kNnlibMaxDim))
     optimized = 0;
 
+  if (cond_is_broadcasted)
+    optimized = 0;
+
   if (optimized) {
     const float* a_data = a.const_data_ptr<float>();
     const float* b_data = b.const_data_ptr<float>();

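Illustration only: a condition tensor that must be broadcast against the inputs now disables the optimized path; the result is still produced, just via the portable kernel.

import torch

cond = torch.tensor([True, False])  # shape (2,), broadcast against (3, 2)
a = torch.randn(3, 2)
b = torch.randn(3, 2)
torch.where(cond, a, b)  # condition is broadcasted -> portable kernel path
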
backends/cadence/utils/facto_util.py

Lines changed: 4 additions & 3 deletions
@@ -249,7 +249,7 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
         case "permute_copy.default":
             tensor_constraints.extend(
                 [
-                    cp.Dtype.In(lambda deps: [torch.float32, torch.int8, torch.uint8]),
+                    cp.Dtype.In(lambda deps: [torch.float32, torch.int32]),
                     cp.Rank.Le(
                         lambda deps: 5
                     ), # xa_nn_transpose only supports up to 5D
@@ -391,12 +391,13 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
             tensor_constraints.extend(
                 [
                     cp.Dtype.In(lambda deps: [torch.float32, torch.int32]),
+                    cp.Value.Ge(lambda deps, dtype, struct: 0),
                 ]
             )
         case "div.Tensor_mode" | "minimum.default":
             if index == 0:
                 tensor_constraints = [
-                    cp.Dtype.In(lambda deps: [torch.int64, torch.int32, torch.float32]),
+                    cp.Dtype.In(lambda deps: [torch.int32, torch.float32]),
                     cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
                     cp.Value.Le(lambda deps, dtype, struct: 2**4),
                     cp.Rank.Ge(lambda deps: 1),
@@ -405,7 +406,7 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
                 ]
             else:
                 tensor_constraints = [
-                    cp.Dtype.In(lambda deps: [torch.int64, torch.int32, torch.float32]),
+                    cp.Dtype.In(lambda deps: [torch.int32, torch.float32]),
                     cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
                     cp.Value.Le(lambda deps, dtype, struct: 2**4),
                     cp.Value.Ne(

backends/mlx/ops.py

Lines changed: 26 additions & 0 deletions
@@ -418,6 +418,32 @@ def handler(P: MLXProgramBuilder, n: Node) -> Slot:
     REGISTRY.register(target=[_target])(_make_unary_handler(_node_cls, _op_name))
 
 
+# ---------------------------------------------------------------------------
+# Numerical checks
+# ---------------------------------------------------------------------------
+
+
+@REGISTRY.register(target=[torch.ops.aten.isnan.default])
+def _isnan_handler(P: MLXProgramBuilder, n: Node) -> Slot:
+    """Handle aten.isnan - check for NaN values element-wise.
+
+    isnan(x) is equivalent to x != x (NaN is the only value not equal to itself).
+    """
+    args = P.args(n)
+    require_args(args, 1, 1, "aten.isnan")
+    require_kwargs(P.kwargs(n), set(), "aten.isnan")
+    x = args[0]
+    out = P.make_or_get_slot(n)
+    P.emit(
+        NotEqualNode(
+            a=P.slot_to_tid(x),
+            b=P.slot_to_tid(x),
+            out=P.slot_to_tid(out),
+        )
+    )
+    return out
+
+
 _BINARY_OPS: List[Tuple[List[Any], Any, str, bool]] = [
     (
         [torch.ops.aten.mul.Tensor, torch.ops.aten.mul.Scalar],

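A quick check (illustration only) of the self-inequality identity the handler relies on: NaN is the only floating-point value that compares unequal to itself, so x != x reproduces torch.isnan(x).

import torch

x = torch.tensor([1.0, float("nan"), float("inf")])
assert torch.equal(torch.isnan(x), x != x)  # tensor([False, True, False])
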
backends/mlx/test/test_ops.py

Lines changed: 17 additions & 0 deletions
@@ -4004,6 +4004,22 @@ def fn(shape, dtype):
     return fn
 
 
+def _nan_input_fn(nan_frac: float = 0.3):
+    """Return a callable(shape, dtype) that generates inputs with some NaN values.
+
+    Args:
+        nan_frac: Fraction of elements to set to NaN (default 0.3 = 30%).
+    """
+
+    def fn(shape, dtype):
+        x = torch.randn(shape, dtype=dtype)
+        mask = torch.rand(shape) > (1.0 - nan_frac)
+        x[mask] = float("nan")
+        return (x,)
+
+    return fn
+
+
 # Standard shape and dtype configs used by unary tests.
 _SHAPES_3 = [(16,), (4, 4), (2, 3, 4)]
 _SHAPES_2 = [(16,), (4, 4)]
@@ -4095,6 +4111,7 @@ def create_model(self) -> nn.Module:
     {"op_name": "abs", "op_fn": torch.abs},
     {"op_name": "neg", "op_fn": torch.neg},
     {"op_name": "logical_not","op_fn": torch.logical_not, "shapes": [(2, 3, 4), (10,), (4, 8)], "dtypes": [torch.bool], "input_fn": _bool_input_fn()},
+    {"op_name": "isnan", "op_fn": torch.isnan, "shapes": _SHAPES_3, "dtypes": [torch.float32, torch.float16, torch.bfloat16], "input_fn": _nan_input_fn()},
     # activations
     {"op_name": "relu", "op_fn": torch.relu, "shapes": [(2, 3, 4), (10,), (4, 8), (2, 8, 16), (1, 128, 64)], "dtypes": [torch.float32], "input_fn": _input_fn(scale=2, offset=-1)},
     {"op_name": "sigmoid", "op_fn": torch.sigmoid, "shapes": [(2, 3, 4), (10,), (4, 8), (2, 8, 16), (1, 1, 128)], "dtypes": [torch.float32], "input_fn": _input_fn(scale=2)},

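Usage sketch (illustration only, with the helper above in scope): calling the returned function yields a single-element input tuple with roughly nan_frac of its entries set to NaN.

import torch

(x,) = _nan_input_fn(nan_frac=0.5)((4, 4), torch.float32)
print(torch.isnan(x).float().mean())  # ~0.5 on average
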
backends/nxp/backend/ir/converter/conversion/translator.py

Lines changed: 2 additions & 2 deletions
@@ -601,7 +601,7 @@ def numpy_type_to_tf_lite(numpy_type: np.dtype) -> TensorType: # noqa C901
     elif numpy_type == np.int64:
         return TensorType.INT64
 
-    elif numpy_type == np.string_:
+    elif numpy_type == np.bytes_:
         return TensorType.STRING
 
     elif numpy_type == np.bool_:
@@ -659,7 +659,7 @@ def tf_lite_type_to_numpy(tfl_type: TensorType) -> np.ScalarType: # noqa C901
         return np.dtype(np.int64)
 
     elif tfl_type == TensorType.STRING:
-        return np.dtype(np.string_)
+        return np.dtype(np.bytes_)
 
     elif tfl_type == TensorType.BOOL:
         return np.dtype(np.bool_)

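Background (illustration only): np.string_ was an alias of np.bytes_ in NumPy 1.x and was removed in NumPy 2.0, so the converter now uses the surviving name; both map to the "S" (bytes) dtype used here for TFLite STRING tensors.

import numpy as np

assert np.dtype(np.bytes_) == np.dtype("S")  # the same dtype np.string_ used to alias
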
devtools/scripts/BUCK

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+load("@fbcode_macros//build_defs:build_file_migration.bzl", "non_fbcode_target")
+
+oncall("executorch")
+
+non_fbcode_target(
+    _kind = native.sh_binary,
+    name = "_benchmark_android_sh",
+    main = "benchmark_android.sh",
+)
+
+non_fbcode_target(
+    _kind = native.command_alias,
+    name = "benchmark_android",
+    exe = ":_benchmark_android_sh",
+    args = ["--build-tool", "buck"],
+)
