forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathop_registry.py
More file actions
665 lines (553 loc) · 19 KB
/
op_registry.py
File metadata and controls
665 lines (553 loc) · 19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
import operator
from typing import Any, Callable, Dict, List, Optional, Union
import executorch.backends.vulkan.custom_ops_lib # noqa
import executorch.backends.vulkan.utils as utils
import torch
from executorch.backends.vulkan.serialization.vulkan_graph_schema import VkMemoryLayout
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from torch._subclasses.fake_tensor import FakeTensor
######################
## OpFeatures class ##
######################


def allow_node(node: torch.fx.Node) -> bool:
    """Accept every node unconditionally.

    This is the default ``are_node_inputs_supported_fn`` of ``OpFeatures``, used
    for operators that need no per-node input check.
    """
    del node  # the decision does not depend on the node at all
    return True
class OpFeatures:
    """Capabilities of one operator in the Vulkan delegate.

    Records the (storage type, memory layout) sets accepted for the operator's
    input and output tensors, whether it supports resizing and self-prepacking,
    and an optional per-node predicate consulted during partitioning.
    """

    __slots__ = [
        # Sets of possible (storage types, memory layouts) to use for the input tensor(s)
        "inputs_storage",
        # Sets of possible (storage types, memory layouts) to use for the output tensor(s)
        "outputs_storage",
        # bool indicating if the operator has a resize function, which allows it to
        # support models with dynamic shape
        "supports_resize",
        # bool indicating if the operator handles its own prepacking. If this is True,
        # then the insert_prepack_nodes pass will not insert prepack nodes for the args
        # of the op.
        "supports_prepacking",
        # Optional check function used during partitioning to determine if a node's
        # inputs are supported by the operator implementation.
        "are_node_inputs_supported_fn",
    ]

    def __init__(
        self,
        inputs_storage: Optional[
            Union[utils.TensorRepSet, List[utils.TensorRepSet]]
        ] = None,
        outputs_storage: Optional[
            Union[utils.TensorRepSet, List[utils.TensorRepSet]]
        ] = None,
        supports_resize: bool = False,
        supports_prepacking: bool = False,
        are_node_inputs_supported_fn: Optional[Callable] = allow_node,
    ):
        """Create the feature description.

        Args:
            inputs_storage: One repset, or one repset per input argument. None is
                treated as an empty list.
            outputs_storage: Same as ``inputs_storage`` but for the outputs. If the
                resulting list reports an empty entry, the first input repset is
                used instead.
            supports_resize: Whether the op has a resize function (dynamic shapes).
            supports_prepacking: Whether the op prepacks its own arguments.
            are_node_inputs_supported_fn: Per-node predicate; defaults to
                ``allow_node`` (accept everything).
        """
        # Normalize both storage descriptions into TensorRepSetList objects.
        self.inputs_storage: utils.TensorRepSetList = utils.TensorRepSetList(
            [] if inputs_storage is None else inputs_storage
        )
        output_reprs = utils.TensorRepSetList(
            [] if outputs_storage is None else outputs_storage
        )
        # Output storage that is not set is assumed to be derived from the first
        # input.
        if output_reprs.any_is_empty():
            output_reprs = utils.TensorRepSetList(self.inputs_storage[0])
        self.outputs_storage: utils.TensorRepSetList = output_reprs

        self.supports_resize = supports_resize
        self.supports_prepacking = supports_prepacking
        self.are_node_inputs_supported_fn = are_node_inputs_supported_fn

    def make_op_repsets(
        self,
        op_node: torch.fx.Node,
        texture_limits: utils.ImageExtents = utils.DEFAULT_TEXTURE_LIMITS,
    ) -> utils.OpRepSets:
        """Build ``utils.OpRepSets`` for ``op_node`` from this op's input/output repsets."""
        return utils.OpRepSets(
            self.inputs_storage,
            self.outputs_storage,
            op_node,
            texture_limits,
        )
#######################
## Operator Registry ##
#######################

# Registry key: an op's qualified name string (e.g. "llama::update_cache") or an
# op overload object.
OpKey = Union[str, torch._ops.OpOverload, EdgeOpOverload]

# Operator -> OpFeatures mapping, populated by the update_features decorator.
vulkan_supported_ops: Dict[OpKey, OpFeatures] = dict()
def update_features(aten_op):
    """Return a decorator that registers ``OpFeatures`` for one or more operators.

    The decorated function takes no arguments and must return an ``OpFeatures``.
    It is called once per operator, so every operator gets its own instance.
    Each result is stored in ``vulkan_supported_ops``, and the decorated function
    is returned unchanged.

    Args:
        aten_op: A single operator key, or a list/tuple of operator keys.

    Raises:
        RuntimeError: If an operator is already present in ``vulkan_supported_ops``.
    """
    # Accept tuples as well as lists. Previously a tuple was registered as one
    # single (meaningless) key instead of one entry per operator.
    ops = list(aten_op) if isinstance(aten_op, (list, tuple)) else [aten_op]

    def features_decorator(fn: Callable):
        for op in ops:
            if op in vulkan_supported_ops:
                raise RuntimeError(f"[Vulkan delegate] duplicate registration of {op}!")
            vulkan_supported_ops[op] = fn()
        return fn

    return features_decorator
@update_features(
    [
        operator.getitem,
        # Symbolic integer ops
        torch.ops.aten.sym_size.int,
        operator.add,
        operator.lt,
        operator.gt,
        operator.ge,
        operator.le,
        operator.eq,
        # Guard and assert ops
        torch.ops.aten._assert_scalar.default,
        torch.ops.aten.sym_constrain_range_for_size.default,
    ]
)
def register_ephemeral_op():
    """Features for getitem, symbolic-int arithmetic/comparisons and guard/assert ops.

    Any storage is accepted for inputs and resizing is supported.
    """
    features = OpFeatures(
        supports_resize=True,
        inputs_storage=utils.ANY_STORAGE,
    )
    return features
@update_features(
    [
        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
        exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
        exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
        exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
        exir_ops.edge.quantized_decomposed.quantize_per_token.default,
        exir_ops.edge.quantized_decomposed.dequantize_per_token.default,
    ]
)
def register_quantization_op():
    """Features for the quantized_decomposed quantize/dequantize ops.

    Inputs must use contiguous buffer storage; resizing is supported.
    """
    storage = utils.CONTIGUOUS_BUFFER
    return OpFeatures(supports_resize=True, inputs_storage=storage)
@update_features(
    [
        exir_ops.edge.torchao.quantize_affine.default,
        exir_ops.edge.torchao.dequantize_affine.default,
    ]
)
def register_affine_quantization_op():
    """Features for torchao's affine quantize/dequantize ops.

    Inputs must use contiguous buffer storage; resizing is supported.
    """
    features = OpFeatures(
        supports_resize=True,
        inputs_storage=utils.CONTIGUOUS_BUFFER,
    )
    return features
@update_features(
    [
        exir_ops.edge.torchao.choose_qparams_affine.default,
        exir_ops.edge.quantized_decomposed.choose_qparams.tensor,
        exir_ops.edge.quantized_decomposed.choose_qparams_per_token_asymmetric.default,
    ]
)
def register_torchao_quantization_op():
    """Features for the choose_qparams family of ops.

    Inputs must use contiguous buffer storage; resizing is supported.
    """
    storage = utils.CONTIGUOUS_BUFFER
    return OpFeatures(supports_resize=True, inputs_storage=storage)
@update_features(
    [
        exir_ops.edge.aten.add.Tensor,
        exir_ops.edge.aten.sub.Tensor,
        exir_ops.edge.aten.minimum.default,
        exir_ops.edge.aten.mul.Tensor,
        exir_ops.edge.aten.div.Tensor,
        exir_ops.edge.aten.div.Tensor_mode,
        exir_ops.edge.aten.pow.Tensor_Tensor,
        exir_ops.edge.aten.eq.Tensor,
        exir_ops.edge.aten.lt.Tensor,
        exir_ops.edge.aten.le.Tensor,
        exir_ops.edge.aten.gt.Tensor,
        exir_ops.edge.aten.ge.Tensor,
    ]
)
def register_binary_op():
    """Features for elementwise binary arithmetic and comparison ops.

    Any storage is accepted for inputs and resizing is supported.
    """
    features = OpFeatures(
        supports_resize=True,
        inputs_storage=utils.ANY_STORAGE,
    )
    return features
@update_features(
    [
        exir_ops.edge.aten.abs.default,
        exir_ops.edge.aten.clamp.default,
        exir_ops.edge.aten.cos.default,
        exir_ops.edge.aten.exp.default,
        exir_ops.edge.aten.gelu.default,
        exir_ops.edge.aten.hardshrink.default,
        exir_ops.edge.aten.hardtanh.default,
        exir_ops.edge.aten.neg.default,
        exir_ops.edge.aten.relu.default,
        exir_ops.edge.aten.sigmoid.default,
        exir_ops.edge.aten.sin.default,
        exir_ops.edge.aten.sqrt.default,
        exir_ops.edge.aten.rsqrt.default,
        exir_ops.edge.aten.tanh.default,
        exir_ops.edge.aten.round.default,
        exir_ops.edge.aten.leaky_relu.default,
    ]
)
def register_unary_op():
    """Features for elementwise unary ops and activations.

    Any storage is accepted for inputs and resizing is supported.
    """
    storage = utils.ANY_STORAGE
    return OpFeatures(supports_resize=True, inputs_storage=storage)
@update_features(exir_ops.edge.aten._to_copy.default)
def register_to_copy_op():
    """Features for ``aten._to_copy`` restricted to float16/float32 tensors.

    A node is accepted only if all of the following hold:
    - It has exactly one positional argument, and that argument is an fx Node.
    - The argument's and the node's ``meta["val"]`` are both FakeTensors.
    - Both FakeTensors have dtype float16 or float32.
    """

    def check_to_copy_node(node: torch.fx.Node) -> bool:
        supported_dtypes = (torch.float16, torch.float32)
        if len(node.args) != 1:
            return False
        (src,) = node.args
        if not isinstance(src, torch.fx.Node):
            return False
        src_val = src.meta.get("val", None)
        dst_val = node.meta.get("val", None)
        both_fake = isinstance(src_val, FakeTensor) and isinstance(dst_val, FakeTensor)
        return (
            both_fake
            and dst_val.dtype in supported_dtypes
            and src_val.dtype in supported_dtypes
        )

    return OpFeatures(
        are_node_inputs_supported_fn=check_to_copy_node,
        supports_resize=True,
        inputs_storage=utils.ANY_STORAGE,
    )
@update_features(exir_ops.edge.dim_order_ops._to_dim_order_copy.default)
def register_to_copy_dim_order_op():
    """Features for ``dim_order_ops._to_dim_order_copy``.

    Currently there is no "real" implementation for to_dim_order_copy. It can be
    removed as long as the operator is not changing the dtype, i.e. the call only
    modifies the dim order. The node check therefore accepts a node only when the
    input and output dtypes are known and equal.
    """

    def check_dim_order_copy_node(node: torch.fx.Node) -> bool:
        if not node.args:
            return False
        in_arg = node.args[0]
        if not isinstance(in_arg, torch.fx.Node):
            return False
        # A missing "val" entry (or a value without a dtype) previously caused an
        # AttributeError on None; treat it as "cannot verify" and reject the node.
        in_dtype = getattr(in_arg.meta.get("val", None), "dtype", None)
        out_dtype = getattr(node.meta.get("val", None), "dtype", None)
        if in_dtype is None or out_dtype is None:
            return False
        return in_dtype == out_dtype

    return OpFeatures(
        inputs_storage=utils.ANY_STORAGE,
        supports_resize=True,
        are_node_inputs_supported_fn=check_dim_order_copy_node,
    )
@update_features(
    [
        exir_ops.edge.aten.bmm.default,
        exir_ops.edge.aten.mm.default,
        exir_ops.edge.aten.addmm.default,
        exir_ops.edge.aten.linear.default,
    ]
)
def register_mm_op():
    """Features for the matrix multiply / linear family.

    Inputs use ``utils.CONTIGUOUS_ANY``; the op resizes and prepacks its own args.
    """
    features = OpFeatures(
        supports_prepacking=True,
        supports_resize=True,
        inputs_storage=utils.CONTIGUOUS_ANY,
    )
    return features
@update_features(
    [
        exir_ops.edge.aten._weight_int8pack_mm.default,
        exir_ops.edge.et_vk.linear_qcs4w.default,
    ]
)
def register_int8_mm_op():
    """Features for int8 weight-packed matmul and the et_vk ``linear_qcs4w`` op.

    Inputs use ``utils.CONTIGUOUS_ANY``; the op resizes and prepacks its own args.
    """
    storage = utils.CONTIGUOUS_ANY
    return OpFeatures(
        supports_prepacking=True,
        supports_resize=True,
        inputs_storage=storage,
    )
@update_features(
    [
        exir_ops.edge.et_vk.linear_weight_int4.default,
    ]
)
def register_int4_mm_op():
    """Features for the et_vk int4-weight linear op.

    Inputs use ``utils.CONTIGUOUS_ANY``; the op resizes and prepacks its own args.
    """
    features = OpFeatures(
        supports_prepacking=True,
        supports_resize=True,
        inputs_storage=utils.CONTIGUOUS_ANY,
    )
    return features
@update_features(
    [
        exir_ops.edge.et_vk.linear_qta8a_qga4w.default,
    ]
)
def register_dqlinear_op():
    """Features for the et_vk ``linear_qta8a_qga4w`` op, with per-argument storage."""
    per_arg_storage = [
        utils.CONTIGUOUS_ANY,  # input
        utils.CONTIGUOUS_BUFFER,  # mat1 scales
        utils.CONTIGUOUS_BUFFER,  # mat1 zeros
        utils.NO_STORAGE,  # weight (prepacked)
        utils.NO_STORAGE,  # group size (non tensor)
        utils.CONTIGUOUS_BUFFER,  # mat2 scales
        utils.CONTIGUOUS_BUFFER,  # mat2 zeros
    ]
    return OpFeatures(
        inputs_storage=per_arg_storage,
        supports_prepacking=True,
        supports_resize=True,
    )
@update_features(
    [
        exir_ops.edge.aten._log_softmax.default,
        exir_ops.edge.aten._softmax.default,
    ]
)
def register_softmax_op():
    """Features for softmax and log-softmax: any texture storage, resizable."""
    storage = utils.ANY_TEXTURE
    return OpFeatures(supports_resize=True, inputs_storage=storage)
@update_features(
    [
        exir_ops.edge.aten.mean.dim,
        exir_ops.edge.aten.sum.dim_IntList,
        exir_ops.edge.aten.amax.default,
        exir_ops.edge.aten.amin.default,
    ]
)
def register_reduce_op():
    """Features for mean/sum/amax/amin reductions.

    Inputs must be texture-backed. A node is rejected if any of these hold:
    - Its positional dim list has more than two entries.
    - It is a two-dim reduction with a known non-default memory layout.
    - Its first positional bool argument (taken as keepdim) is absent or False.
    """

    def check_reduce_node(node: torch.fx.Node) -> bool:
        # The reduction dims are the second positional argument. A node that
        # omits it has no positional bool either, so it would fail the keepdim
        # check below; reject it here instead of raising IndexError.
        if len(node.args) < 2:
            return False
        dim_list = node.args[1]

        if isinstance(dim_list, list) and len(dim_list) > 2:
            return False

        if isinstance(dim_list, list) and len(dim_list) == 2:
            # Try to get the memory layout for this node
            try:
                memory_layout = utils.get_node_memory_layout(node)
                if (
                    memory_layout is not None
                    and memory_layout != VkMemoryLayout.DEFAULT_LAYOUT
                ):
                    # For now only default layout is supported for 2D reduction.
                    # Because we can't determine if the input is NCHW or NHWC here,
                    # assume the reduction dimension is packed so we cannot support it.
                    return False
            except (AssertionError, KeyError, AttributeError):
                # If we can't get memory layout information, we'll assume the dims aren't packed
                pass

        # keepdim is taken to be the first positional bool argument; assume False
        # when none is present. Only keepdim=True reductions are supported.
        keepdim = next((arg for arg in node.args if isinstance(arg, bool)), False)
        return keepdim

    return OpFeatures(
        inputs_storage=utils.ANY_TEXTURE,
        supports_resize=True,
        are_node_inputs_supported_fn=check_reduce_node,
    )
@update_features(
    [
        exir_ops.edge.aten.avg_pool2d.default,
        exir_ops.edge.aten.max_pool2d_with_indices.default,
    ]
)
def register_2d_pool_op():
    """Features for 2D average/max pooling: channels-packed textures, resizable."""
    features = OpFeatures(
        supports_resize=True,
        inputs_storage=utils.CHANNELS_PACKED_TEXTURE,
    )
    return features
@update_features(
    [
        exir_ops.edge.aten.convolution.default,
        exir_ops.edge.et_vk.conv_with_clamp.default,
    ]
)
def register_convolution_op():
    """Features for convolution (optionally fused with clamp).

    Only the activation input needs a storage type (channels-packed texture).
    Every other argument is prepacked or a non-tensor value.
    """
    per_arg_storage = [
        utils.CHANNELS_PACKED_TEXTURE,  # input
        utils.NO_STORAGE,  # weight (prepacked)
        utils.NO_STORAGE,  # bias (prepacked)
        utils.NO_STORAGE,  # stride (non tensor)
        utils.NO_STORAGE,  # padding (non tensor)
        utils.NO_STORAGE,  # dilation (non tensor)
        utils.NO_STORAGE,  # transposed (non tensor)
        utils.NO_STORAGE,  # output_padding (non tensor)
        utils.NO_STORAGE,  # groups (non tensor)
        utils.NO_STORAGE,  # output_min (non tensor)
        utils.NO_STORAGE,  # output_max (non tensor)
    ]
    return OpFeatures(
        inputs_storage=per_arg_storage,
        supports_prepacking=True,
        supports_resize=True,
    )
@update_features("llama::sdpa_with_kv_cache")
def register_sdpa_with_kv_cache_op():
    """Features for the llama SDPA-with-KV-cache custom op.

    Uses width-packed textures, is resizable and prepacks its own args.
    """
    features = OpFeatures(
        supports_prepacking=True,
        supports_resize=True,
        inputs_storage=utils.WIDTH_PACKED_TEXTURE,
    )
    return features
@update_features(
    [
        "llama::update_cache",
        "llama::custom_sdpa",
    ]
)
def register_sdpa_ops():
    """Features for the llama update_cache / custom_sdpa ops: width-packed textures, resizable."""
    storage = utils.WIDTH_PACKED_TEXTURE
    return OpFeatures(supports_resize=True, inputs_storage=storage)
@update_features(exir_ops.edge.et_vk.apply_rotary_emb.default)
def register_rotary_emb_op():
    """Features for the et_vk rotary embedding op: width-packed textures, resizable."""
    features = OpFeatures(
        supports_resize=True,
        inputs_storage=utils.WIDTH_PACKED_TEXTURE,
    )
    return features
@update_features(
    [
        exir_ops.edge.aten.permute.default,
    ]
)
def register_view_ops():
    """Features for ``aten.permute``: any texture storage, resizable."""
    storage = utils.ANY_TEXTURE
    return OpFeatures(supports_resize=True, inputs_storage=storage)
@update_features(
    [
        exir_ops.edge.aten.view_copy.default,
        exir_ops.edge.aten.squeeze_copy.dims,
        exir_ops.edge.aten.unsqueeze_copy.default,
        exir_ops.edge.aten.clone.default,
        exir_ops.edge.aten.permute_copy.default,
    ]
)
def register_view_ops_with_buffer_meta():
    """Features for view/squeeze/unsqueeze/clone/permute copy ops.

    Any storage is accepted for inputs and resizing is supported.
    """
    features = OpFeatures(
        supports_resize=True,
        inputs_storage=utils.ANY_STORAGE,
    )
    return features
@update_features(exir_ops.edge.aten.expand_copy.default)
def register_expand():
    """Features for ``aten.expand_copy``: buffer storage only, no resize support."""
    storage = utils.ANY_BUFFER
    return OpFeatures(
        supports_resize=False,
        inputs_storage=storage,
    )
# Concatenation is a fully featured transfer operator: it copies data from the
# input tensors to the output tensor. Its implementations are memory layout
# agnostic for both texture and buffer storage types.
@update_features(exir_ops.edge.aten.cat.default)
def register_cat_op():
    """Features for ``aten.cat``: any storage, resizable."""
    features = OpFeatures(
        supports_resize=True,
        inputs_storage=utils.ANY_STORAGE,
    )
    return features
# Fully featured transfer operators (operators that copy data from the input
# tensor(s) to the output tensor(s)). Their implementations are memory layout
# agnostic for both texture and buffer storage types.
@update_features(
    [
        exir_ops.edge.aten.select_copy.int,
        exir_ops.edge.aten.slice_copy.Tensor,
    ]
)
def register_transfer_ops():
    """Features for select/slice copy ops: any storage, resizable."""
    storage = utils.ANY_STORAGE
    return OpFeatures(supports_resize=True, inputs_storage=storage)
# Ops ported from the PyTorch Vulkan backend. These ops commonly support only
# channels-packed tensors and have no resize function.
@update_features(
    [
        # Shape Manipulation
        exir_ops.edge.aten.t_copy.default,
        # Indexing and lookup
        exir_ops.edge.aten.flip.default,
        exir_ops.edge.aten.index_select.default,
        # Tensor creation
        exir_ops.edge.aten.arange.start_step,
        exir_ops.edge.aten.constant_pad_nd.default,
        exir_ops.edge.aten.full.default,
        exir_ops.edge.aten.full_like.default,
        exir_ops.edge.aten.ones.default,
        exir_ops.edge.aten.ones_like.default,
        exir_ops.edge.aten.scalar_tensor.default,
        exir_ops.edge.aten.upsample_nearest2d.vec,
        exir_ops.edge.aten.upsample_bilinear2d.vec,
        exir_ops.edge.aten.zeros.default,
        exir_ops.edge.aten.zeros_like.default,
        exir_ops.edge.et_vk.grid_priors.default,
    ]
)
def register_ported_op():
    """Features for ported ops: channels-packed textures, default flags otherwise."""
    features = OpFeatures(inputs_storage=utils.CHANNELS_PACKED_TEXTURE)
    return features
# Ops ported from the PyTorch Vulkan backend. They are registered separately
# because they support all packed dimensions.
@update_features(
    [
        # Tensor combination
        exir_ops.edge.aten.repeat.default,
        exir_ops.edge.aten.split_with_sizes_copy.default,
        exir_ops.edge.aten.split.Tensor,
    ]
)
def register_ported_op_all_packed_dims():
    """Features for ported repeat/split ops: any texture storage, default flags otherwise."""
    storage = utils.ANY_TEXTURE
    return OpFeatures(inputs_storage=storage)
# Ported ops that do their own prepacking.
@update_features(
    [
        exir_ops.edge.aten.embedding.default,
        exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
    ]
)
def register_ported_ops_with_prepacking():
    """Features for embedding and inference batch norm: channels-packed, self-prepacking."""
    features = OpFeatures(
        supports_prepacking=True,
        inputs_storage=utils.CHANNELS_PACKED_TEXTURE,
    )
    return features
@update_features(
    [
        exir_ops.edge.aten.native_group_norm.default,
    ]
)
def register_native_group_norm():
    """Features for ``aten.native_group_norm``.

    Input is a channels-packed texture. The three outputs are one channels-packed
    texture followed by two contiguous buffers. The op prepacks its own args.
    """
    output_storage = [
        utils.CHANNELS_PACKED_TEXTURE,
        utils.CONTIGUOUS_BUFFER,
        utils.CONTIGUOUS_BUFFER,
    ]
    return OpFeatures(
        inputs_storage=utils.CHANNELS_PACKED_TEXTURE,
        outputs_storage=output_storage,
        supports_prepacking=True,
    )
# Ported ops that do their own prepacking and accept any packed dimension.
@update_features(
    [
        exir_ops.edge.aten.native_layer_norm.default,
    ]
)
def register_ported_ops_with_prepacking_all_dims():
    """Features for ``aten.native_layer_norm``: any texture storage, self-prepacking."""
    features = OpFeatures(
        supports_prepacking=True,
        inputs_storage=utils.ANY_TEXTURE,
    )
    return features
#######################
## Utility functions ##
#######################


def has_impl(target: Any) -> bool:
    """Return True if ``target`` has registered Vulkan features.

    The target is looked up directly first. For non-string targets that expose a
    callable ``name()`` (e.g. op overloads), the returned name is also tried, so
    ops registered by string name are found.

    Targets that are neither registered nor have a ``name()`` method (e.g. plain
    Python callables such as ``operator.mul``) now return False. Previously they
    raised AttributeError.
    """
    if target in vulkan_supported_ops:
        return True
    if isinstance(target, str):
        return False
    name_fn = getattr(target, "name", None)
    return callable(name_fn) and name_fn() in vulkan_supported_ops
def get_op_features(target: Any) -> OpFeatures:
    """Return the ``OpFeatures`` registered for ``target``.

    The target is looked up directly first. For non-string targets, the name
    returned by their ``name()`` method is tried as a fallback.

    Raises:
        KeyError: If no features are registered for ``target`` (or its name).
            Targets without a ``name()`` method previously raised AttributeError.
    """
    if target in vulkan_supported_ops:
        return vulkan_supported_ops[target]
    if not isinstance(target, str):
        name_fn = getattr(target, "name", None)
        if callable(name_fn):
            # Try the op's name
            return vulkan_supported_ops[name_fn()]
    raise KeyError(target)
def handles_own_prepacking(target: OpKey) -> bool:
    """Return whether the operator registered for ``target`` prepacks its own args.

    Lookup failures propagate from ``get_op_features``.
    """
    features = get_op_features(target)
    return features.supports_prepacking