Skip to content

Commit 13fbb40

Browse files
committed
move llama.fallback out from model_sharding
1 parent 66fafc5 commit 13fbb40

3 files changed

Lines changed: 33 additions & 23 deletions

File tree

exir/passes/spec_prop_pass.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
import operator
1010
from typing import Optional
1111

12-
import executorch.extension.llm.custom_ops.model_sharding # noqa: F401
12+
# register llama.fallback
13+
import executorch.extension.llm.custom_ops.op_fallback # noqa: F401
1314

1415
import torch
1516
from executorch.exir.delegate import executorch_call_delegate

extension/llm/custom_ops/model_sharding.py

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77
import re
88
from typing import List
99

10-
import torch
10+
import executorch.extension.llm.custom_ops.op_fallback # noqa: F401
1111

12+
import torch
1213
from executorch.backends.qualcomm.utils.constants import (
1314
QCOM_PASS_ACTIVATE_KEY,
1415
QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY,
@@ -17,27 +18,6 @@
1718
from executorch.exir.dialects._ops import ops as exir_ops
1819
from executorch.exir.pass_base import ExportPass, PassResult
1920
from torch.export.exported_program import ExportedProgram
20-
from torch.library import impl, Library
21-
22-
23-
fallback_op_lib = Library("llama", "DEF")
24-
# registering an operator.
25-
fallback_op_lib.define("fallback(Tensor input) -> Tensor")
26-
27-
28-
@impl(fallback_op_lib, "fallback")
29-
def fallback_impl(a: torch.Tensor) -> torch.Tensor:
30-
return a
31-
32-
33-
# registering the out variant.
34-
fallback_op_lib.define("fallback.out(Tensor input, *, Tensor(a!) output) -> Tensor(a!)")
35-
36-
37-
@impl(fallback_op_lib, "fallback.out")
38-
def fallback_out_impl(a: torch.Tensor, *, out: torch.Tensor) -> torch.Tensor:
39-
out.copy_(a)
40-
return out
4121

4222

4323
class SplitGraph(ExportPass):
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
# pyre-ignore-all-errors
7+
8+
import torch
9+
10+
from torch.library import impl, Library
11+
12+
fallback_op_lib = Library("llama", "DEF")
13+
# registering an operator.
14+
fallback_op_lib.define("fallback(Tensor input) -> Tensor")
15+
16+
17+
@impl(fallback_op_lib, "fallback")
18+
def fallback_impl(a: torch.Tensor) -> torch.Tensor:
19+
return a
20+
21+
22+
# registering the out variant.
23+
fallback_op_lib.define("fallback.out(Tensor input, *, Tensor(a!) output) -> Tensor(a!)")
24+
25+
26+
@impl(fallback_op_lib, "fallback.out")
27+
def fallback_out_impl(a: torch.Tensor, *, out: torch.Tensor) -> torch.Tensor:
28+
out.copy_(a)
29+
return out

0 commit comments

Comments
 (0)