39 changes: 39 additions & 0 deletions exir/emit/_emitter.py
@@ -1259,6 +1259,42 @@ def _emit_view(self, args: Tuple[_Argument, ...]) -> _EmitterValue:
self.chain.instructions.append(kernel)
return out_arg

def _emit_subview(self, args: Tuple[_Argument, ...]) -> _EmitterValue:
spec = self.node.meta["spec"]
is_static = spec.is_static_shape_tensor
is_memory_planned = (spec.mem_id is not None) and (spec.mem_offset is not None)
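        # For a static, memory-planned sub-view the byte offset has already
        # been folded into mem_offset by _ViewSpec, so emitting the spec is
        # enough to place the output; no instruction is needed.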
if is_static and is_memory_planned:
return self._emit_spec(spec)

        out_arg = self._emit_argument(
            self._emit_spec(spec), torch.TensorType  # pyre-ignore[6]
        )

if self.node.target == memory.select:
self_arg = self._emit_argument(args[0], torch.TensorType) # pyre-ignore[6]
dim_arg = self._emit_argument(args[1], torch.IntType)
index_arg = self._emit_argument(args[2], torch.IntType)
op_idx, op = self._get_operator(
name="executorch_prim::et_select",
overload="default",
)
kernel = Instruction(
KernelCall(
op_idx,
args=[self_arg.id, dim_arg.id, index_arg.id, out_arg.id],
)
)
else:
raise InternalError(
self._emit_node_specific_error(
self.node,
f"Unsupported subview target: {self.node.target}",
)
)

self.chain.instructions.append(kernel)
return out_arg

def _add_debug_handle(
self,
emitter_id: int,
@@ -1758,6 +1794,9 @@ def call_function( # pyre-fixme[14]
elif target == memory.view:
return self._emit_view(args)

elif target == memory.select:
return self._emit_subview(args)

elif target == memory.free:
assert len(args) == 1
# pyre-ignore
9 changes: 9 additions & 0 deletions exir/memory.py
@@ -48,3 +48,12 @@ def view(base: torch.Tensor, size: List[int]) -> torch.Tensor:
It is used to elide view_copy nodes.
"""
return base.view(size)


def select(base: torch.Tensor, dim: int, index: int) -> torch.Tensor:
"""
This function mimics torch.ops.aten.select.int.

It is used to elide select_copy nodes when the result is contiguous.
"""
return base.select(dim, index)
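
The contiguity requirement mentioned in the docstring can be checked in eager PyTorch. A minimal illustration (not part of this diff): a select aliases a single contiguous block of its base only when every dimension before `dim` has size 1.

import torch

# dim 0 has size 1, so select along dim 1 stays contiguous.
base = torch.arange(12, dtype=torch.float32).reshape(1, 4, 3)
print(base.select(1, 2).is_contiguous())  # True

# dim 0 has size 2, so the select result is strided, not one block.
base2 = torch.arange(24, dtype=torch.float32).reshape(2, 4, 3)
print(base2.select(1, 2).is_contiguous())  # False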
6 changes: 5 additions & 1 deletion exir/memory_planning.py
@@ -486,6 +486,7 @@ def collect_specs_from_nodes( # noqa: C901
in [
memory.alloc,
memory.view,
memory.select,
operator.getitem,
torch.ops.higher_order.cond,
exir_while,
@@ -763,9 +764,12 @@ def get_node_tensor_specs(
has no tensor specs.
"""
# get tensor specs
if node.target == memory.view:
if node.target in (memory.view, memory.select):
base = node.args[0]
assert isinstance(base, torch.fx.Node)
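        # Aliasing nodes may nest (e.g. a memory.select of a memory.view),
        # so walk to the root node that owns the actual allocation.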
while base.target in (memory.view, memory.select):
base = base.args[0]
assert isinstance(base, torch.fx.Node)
specs = base.meta.get("spec")
else:
specs = node.meta.get("spec")
1 change: 1 addition & 0 deletions exir/passes/__init__.py
@@ -261,6 +261,7 @@ def callWithLoggerEnabled(self, graph_module: torch.fx.GraphModule) -> None:
# it's retraced after running to_out_variant with the first trace.
memory.alloc,
memory.view,
memory.select,
executorch_call_delegate,
}
to_out_var_skiplist.update(_EXECUTORCH_SYM_OPS)
137 changes: 105 additions & 32 deletions exir/passes/replace_view_copy_with_view_pass.py
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
@@ -33,7 +33,16 @@
)


def _is_select_copy(node: torch.fx.Node) -> bool:
return node.op == "call_function" and node.target in (
torch.ops.aten.select_copy.int,
ops.edge.aten.select_copy.int,
)


_VIEW_OP = memory.view
_SELECT_OP = memory.select



class _Guard:
@@ -54,7 +63,12 @@


class _ViewSpec(TensorSpec):
def __init__(self, base: TensorSpec, shape: List[int]) -> None:
def __init__(
self,
base: TensorSpec,
shape: List[int],
byte_offset: int = 0,
) -> None:
"""
        A _ViewSpec is a TensorSpec that shares non-size-related fields with its base.
The size-related fields are: shape, stride, dim_order, and shape_dynamism.
@@ -65,7 +79,11 @@

A _ViewSpec can only be created from a non-sparse, strided TensorSpec.
On creation, a _ViewSpec must be compatible with its base with respect to
shape_dynamism, dtype, and nbytes.
shape_dynamism, dtype, and nbytes (when byte_offset is 0).

When byte_offset is non-zero (used for select/slice sub-views), the view
describes a contiguous sub-region of the base at the given byte offset.
In this case, nbytes may differ from the base and rank may change.

A _ViewSpec contains _guards that are evaluated on every __getattribute__ call.
The purpose of the guards is to make sure the _ViewSpec is still compatible
@@ -119,6 +137,7 @@

self._guards: List[_Guard] = []
self._unguarded_access = False
self._byte_offset: int = byte_offset

# Make sure base is not sparse and add a guard
if base.is_sparse:
@@ -154,27 +173,6 @@
torch.Size(self.shape)
)

# Check compatibility with base on creation
if self.shape_dynamism != base.shape_dynamism:
raise Exception(
f"_ViewSpec is incompatible with its base on creation. It has shape_dynamism={self.shape_dynamism}, but its base has shape_dynamism={base.shape_dynamism}."
)
self._guards.append(
_Guard(
"shape_dynamism_init",
lambda view_spec: view_spec.shape_dynamism,
base.shape_dynamism,
)
)
self._guards.append(
_Guard(
"shape_dynamism_eq_base",
lambda view_spec: view_spec.shape_dynamism
== view_spec._base.shape_dynamism,
True,
)
)

if self.dtype != base.dtype:
raise Exception(
f"_ViewSpec is incompatible with its base on creation. It has dtype={self.dtype}, but its base has dtype={base.dtype}."
@@ -183,15 +181,32 @@
_Guard("dtype", lambda view_spec: view_spec.dtype, base.dtype)
)

# We do not guard nbytes because dynamic symints are replaced by upper bounds.
# We do guard on rank, though
if self.nbytes() != base.nbytes():
raise Exception(
f"_ViewSpec is incompatible with its base on creation. It has nbytes={self.nbytes()}, but its base has nbytes={base.nbytes()}."
# For full views (same nbytes, zero offset), dynamism and rank must match.
# For sub-views (select/slice), the output is a contiguous subset so
# nbytes will differ, rank may change, and dynamism may differ
# (e.g., selecting away a dynamic dim produces a static output).
is_full_view = byte_offset == 0 and self.nbytes() == base.nbytes()
if is_full_view:
if self.shape_dynamism != base.shape_dynamism:
raise Exception(
f"_ViewSpec is incompatible with its base on creation. It has shape_dynamism={self.shape_dynamism}, but its base has shape_dynamism={base.shape_dynamism}."
)
self._guards.append(
_Guard(
"shape_dynamism_eq_base",
lambda view_spec: view_spec.shape_dynamism
== view_spec._base.shape_dynamism,
True,
)
)
self._guards.append(
_Guard("rank", lambda view_spec: len(view_spec.shape), len(shape))
)
self._guards.append(
_Guard("rank", lambda view_spec: len(view_spec.shape), len(shape))
)
else:
if self.nbytes() + byte_offset > base.nbytes():
raise Exception(
f"_ViewSpec sub-view extends beyond base. Sub-view needs {self.nbytes()} bytes at offset {byte_offset}, but base has {base.nbytes()} bytes."
)

def _run_guards(self) -> None:
unguarded_access = self._unguarded_access
@@ -211,14 +226,27 @@
"_guards",
"_unguarded_access",
"_run_guards",
"_byte_offset",
]:
return object.__getattribute__(self, name)

# Get some attributes from self
if name in self._self_fields:
val = object.__getattribute__(self, name)
elif name in self._base_fields:
val = object.__getattribute__(self._base, name)
val = getattr(self._base, name)
# For static sub-views, adjust mem_offset by byte_offset so the
# emitter can elide the op. For non-static sub-views, return
# None for mem_id/mem_offset — et_select runs at runtime and
# the output needs no allocation info.
if name in ("mem_id", "mem_offset") and val is not None:
byte_offset = object.__getattribute__(self, "_byte_offset")
if byte_offset != 0:
shape_dynamism = object.__getattribute__(self, "shape_dynamism")
if shape_dynamism != TensorShapeDynamism.STATIC:
val = None
elif name == "mem_offset":
val = val + byte_offset
else:
if len(name) > 0 and name[0] != "_":
logger.warning(
@@ -239,6 +267,7 @@
"_guards",
"_unguarded_access",
"_run_guards",
"_byte_offset",
]:
object.__setattr__(self, name, val)
return
@@ -293,11 +322,53 @@

n_replaced += 1

elif _is_select_copy(node) and all(
u.op != "output" for u in node.users
):
replaced = self._try_replace_select(node)
if replaced:
n_replaced += 1

module.recompile()

logger.debug(f"Replaced {n_replaced} view_copy nodes with {_VIEW_OP} nodes.")
logger.debug(
f"Replaced {n_replaced} select_copy nodes with {_SELECT_OP} nodes."
)
return PassResult(graph_module, n_replaced > 0)

def _try_replace_select(self, node: torch.fx.Node) -> bool:
base = node.args[0]
assert isinstance(base, torch.fx.Node)
base_spec = base.meta["spec"]

        dim = node.args[1]
        index = node.args[2]
        assert isinstance(dim, int)
        assert isinstance(index, int)
base_shape = list(base_spec.shape)

if dim < 0:
dim += len(base_shape)

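        # A select aliases a single contiguous block of the base only when
        # every dimension before `dim` has size 1; otherwise the result is
        # strided and cannot be expressed as base-plus-offset.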
if any(base_shape[i] != 1 for i in range(dim)):
return False

if index < 0:
index += base_shape[dim]

        # Constant tensors are not memory planned, so there is no planned
        # base allocation to alias into.
        if base_spec.const:
            return False

        base_stride = contiguous_stride_from_shape(torch.Size(base_shape))
        element_size = torch._utils._element_size(base_spec.dtype)
        byte_offset = index * base_stride[dim] * element_size

node.target = _SELECT_OP
        view_spec = _ViewSpec(base_spec, list(node.meta["val"].shape), byte_offset)
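        # If the base is dynamic, keep its dynamism on the sub-view so that
        # mem_id/mem_offset resolve to None and et_select runs at runtime,
        # instead of eliding with an offset computed from upper-bound shapes.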
if base_spec.shape_dynamism != TensorShapeDynamism.STATIC:
view_spec.shape_dynamism = base_spec.shape_dynamism
node.meta["spec"] = view_spec
return True

def ensures(self, graph_module: torch.fx.GraphModule) -> None:
for module in graph_module.modules():
if not isinstance(module, torch.fx.GraphModule):
@@ -310,6 +381,8 @@
)
if node.op == "call_function" and node.target == _VIEW_OP:
assert isinstance(node.meta["spec"], _ViewSpec)
if node.op == "call_function" and node.target == _SELECT_OP:
assert isinstance(node.meta["spec"], _ViewSpec)

def requires(self, graph_module: torch.fx.GraphModule) -> None:
"""
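
To make the offset arithmetic in `_try_replace_select` concrete, here is a standalone sketch (illustrative only; it recomputes contiguous strides by hand rather than calling `contiguous_stride_from_shape`, and hardcodes the 4-byte element size of float32): for a contiguous float32 base of shape [1, 4, 3], select(dim=1, index=2) starts index * stride[dim] * element_size bytes into the base.

import torch

base_shape = [1, 4, 3]
dim, index = 1, 2

# Contiguous strides for [1, 4, 3] are (12, 3, 1) elements.
stride = [1] * len(base_shape)
for i in range(len(base_shape) - 2, -1, -1):
    stride[i] = stride[i + 1] * base_shape[i + 1]

element_size = 4  # float32
byte_offset = index * stride[dim] * element_size
print(byte_offset)  # 2 * 3 * 4 = 24

# Cross-check against eager PyTorch.
base = torch.zeros(base_shape, dtype=torch.float32)
assert base.select(dim, index).data_ptr() - base.data_ptr() == byte_offset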
1 change: 1 addition & 0 deletions exir/serde/export_serialize.py
@@ -209,6 +209,7 @@ def _reverse_map(d: Dict[Any, Enum]):

_KNOWN_FUNCTIONS = {
exir.memory.view,
exir.memory.select,
}


Expand Down
1 change: 1 addition & 0 deletions exir/serde/serialize.py
@@ -376,6 +376,7 @@ def serialize(

_KNOWN_FUNCTIONS_MAP = {
"executorch.exir.memory.view": exir.memory.view,
"executorch.exir.memory.select": exir.memory.select,
}


Expand Down