File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 99import operator
1010from typing import Optional
1111
12- import executorch .extension .llm .custom_ops .model_sharding # noqa: F401
12+ # register llama.fallback
13+ import executorch .extension .llm .custom_ops .op_fallback # noqa: F401
1314
1415import torch
1516from executorch .exir .delegate import executorch_call_delegate
Original file line number Diff line number Diff line change 77import re
88from typing import List
99
10- import torch
10+ import executorch . extension . llm . custom_ops . op_fallback # noqa: F401
1111
12+ import torch
1213from executorch .backends .qualcomm .utils .constants import (
1314 QCOM_PASS_ACTIVATE_KEY ,
1415 QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY ,
1718from executorch .exir .dialects ._ops import ops as exir_ops
1819from executorch .exir .pass_base import ExportPass , PassResult
1920from torch .export .exported_program import ExportedProgram
20- from torch .library import impl , Library
21-
22-
23- fallback_op_lib = Library ("llama" , "DEF" )
24- # registering an operator.
25- fallback_op_lib .define ("fallback(Tensor input) -> Tensor" )
26-
27-
28- @impl (fallback_op_lib , "fallback" )
29- def fallback_impl (a : torch .Tensor ) -> torch .Tensor :
30- return a
31-
32-
33- # registering the out variant.
34- fallback_op_lib .define ("fallback.out(Tensor input, *, Tensor(a!) output) -> Tensor(a!)" )
35-
36-
37- @impl (fallback_op_lib , "fallback.out" )
38- def fallback_out_impl (a : torch .Tensor , * , out : torch .Tensor ) -> torch .Tensor :
39- out .copy_ (a )
40- return out
4121
4222
4323class SplitGraph (ExportPass ):
Original file line number Diff line number Diff line change 1+ # Copyright (c) Qualcomm Innovation Center, Inc.
2+ # All rights reserved
3+ #
4+ # This source code is licensed under the BSD-style license found in the
5+ # LICENSE file in the root directory of this source tree.
6+ # pyre-ignore-all-errors
7+
8+ import torch
9+
10+ from torch .library import impl , Library
11+
12+ fallback_op_lib = Library ("llama" , "DEF" )
13+ # registering an operator.
14+ fallback_op_lib .define ("fallback(Tensor input) -> Tensor" )
15+
16+
17+ @impl (fallback_op_lib , "fallback" )
18+ def fallback_impl (a : torch .Tensor ) -> torch .Tensor :
19+ return a
20+
21+
22+ # registering the out variant.
23+ fallback_op_lib .define ("fallback.out(Tensor input, *, Tensor(a!) output) -> Tensor(a!)" )
24+
25+
26+ @impl (fallback_op_lib , "fallback.out" )
27+ def fallback_out_impl (a : torch .Tensor , * , out : torch .Tensor ) -> torch .Tensor :
28+ out .copy_ (a )
29+ return out
You can’t perform that action at this time.
0 commit comments