-
Notifications
You must be signed in to change notification settings - Fork 108
Expand file tree
/
Copy pathremove_vmap.py
More file actions
86 lines (56 loc) · 2.15 KB
/
remove_vmap.py
File metadata and controls
86 lines (56 loc) · 2.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# -*- coding: utf-8 -*-
import jax.numpy as jnp
import jax
# Prefer the ``jax.extend.core`` location of ``Primitive`` and fall back to the
# legacy ``jax.core`` one.  Probing the import instead of comparing
# ``jax.__version__`` avoids lexicographic string comparison, which misorders
# versions (e.g. ``'0.10.0' >= '0.5.0'`` is False).
try:
    from jax.extend.core import Primitive
except ImportError:
    from jax.core import Primitive
from jax.core import ShapedArray
from jax.interpreters import batching, mlir, xla
from .ndarray import Array
__all__ = [
'remove_vmap'
]
def remove_vmap(x, op='any'):
    """Reduce ``x`` with ``any``/``all`` through this module's no-vmap primitives.

    Args:
        x: Value to reduce. An ``Array`` instance is unwrapped to its
            ``.value`` before being passed to the primitive.
        op: ``'any'`` (default) or ``'all'``, selecting which reduction
            primitive is bound.

    Returns:
        The result of ``_any_without_vmap(x)`` or ``_all_without_vmap(x)``.

    Raises:
        ValueError: If ``op`` is neither ``'any'`` nor ``'all'``.
    """
    # Reject an unsupported ``op`` up front, before ``x`` is inspected.
    if op not in ('any', 'all'):
        raise ValueError(f'Do not support type: {op}')
    if isinstance(x, Array):
        x = x.value
    if op == 'any':
        return _any_without_vmap(x)
    return _all_without_vmap(x)
# Primitive behind the 'any' reduction; its impl, abstract-eval, batching and
# lowering rules are registered further down in this module.
_any_no_vmap_prim = Primitive('any_no_vmap')
def _any_without_vmap(x):
    """Bind ``x`` to the ``any_no_vmap`` primitive and return its result."""
    return _any_no_vmap_prim.bind(x)
def _any_without_vmap_imp(x):
    """Concrete implementation of ``any_no_vmap``: ``jnp.any`` over all of ``x``."""
    return jnp.any(x)
def _any_without_vmap_abs(x):
    """Abstract evaluation of ``any_no_vmap``.

    The input ``x`` is not inspected: the result is always described as a
    scalar (shape ``()``) boolean ``ShapedArray``.
    """
    return ShapedArray(shape=(), dtype=jnp.bool_)
def _any_without_vmap_batch(x, batch_axes):
    """Batching rule for ``any_no_vmap``.

    Args:
        x: Sequence holding exactly one operand; unpacking fails otherwise.
        batch_axes: Batch axes of the operands. Unused: the whole operand is
            bound to the primitive again, whatever axis is batched.

    Returns:
        ``(result, batching.not_mapped)`` -- the result is reported as
        carrying no batch dimension.
    """
    (operand,) = x
    return _any_without_vmap(operand), batching.not_mapped
# Attach the rules of the ``any_no_vmap`` primitive.
_any_no_vmap_prim.def_impl(_any_without_vmap_imp)  # concrete evaluation
_any_no_vmap_prim.def_abstract_eval(_any_without_vmap_abs)  # result shape/dtype
batching.primitive_batchers[_any_no_vmap_prim] = _any_without_vmap_batch  # batching rule
# XLA translation rule: registered only when this JAX exposes ``xla.lower_fun``.
if hasattr(xla, "lower_fun"):
    xla.register_translation(
        _any_no_vmap_prim,
        xla.lower_fun(_any_without_vmap_imp, multiple_results=False, new_style=True),
    )
# Kept outside the guard so a lowering is registered even when
# ``xla.lower_fun`` is absent.
mlir.register_lowering(
    _any_no_vmap_prim,
    mlir.lower_fun(_any_without_vmap_imp, multiple_results=False),
)
# Primitive behind the 'all' reduction; its impl, abstract-eval, batching and
# lowering rules are registered further down in this module.
_all_no_vmap_prim = Primitive('all_no_vmap')
def _all_without_vmap(x):
    """Bind ``x`` to the ``all_no_vmap`` primitive and return its result."""
    return _all_no_vmap_prim.bind(x)
def _all_without_vmap_imp(x):
    """Concrete implementation of ``all_no_vmap``: ``jnp.all`` over all of ``x``."""
    return jnp.all(x)
def _all_without_vmap_abs(x):
    """Abstract evaluation of ``all_no_vmap``.

    The input ``x`` is not inspected: the result is always described as a
    scalar (shape ``()``) boolean ``ShapedArray``.
    """
    return ShapedArray(shape=(), dtype=jnp.bool_)
def _all_without_vmap_batch(x, batch_axes):
    """Batching rule for ``all_no_vmap``.

    Args:
        x: Sequence holding exactly one operand; unpacking fails otherwise.
        batch_axes: Batch axes of the operands. Unused: the whole operand is
            bound to the primitive again, whatever axis is batched.

    Returns:
        ``(result, batching.not_mapped)`` -- the result is reported as
        carrying no batch dimension.
    """
    (operand,) = x
    return _all_without_vmap(operand), batching.not_mapped
# Attach the rules of the ``all_no_vmap`` primitive.
_all_no_vmap_prim.def_impl(_all_without_vmap_imp)  # concrete evaluation
_all_no_vmap_prim.def_abstract_eval(_all_without_vmap_abs)  # result shape/dtype
batching.primitive_batchers[_all_no_vmap_prim] = _all_without_vmap_batch  # batching rule
# XLA translation rule: registered only when this JAX exposes ``xla.lower_fun``.
if hasattr(xla, "lower_fun"):
    xla.register_translation(
        _all_no_vmap_prim,
        xla.lower_fun(_all_without_vmap_imp, multiple_results=False, new_style=True),
    )
# Kept outside the guard so a lowering is registered even when
# ``xla.lower_fun`` is absent.
mlir.register_lowering(
    _all_no_vmap_prim,
    mlir.lower_fun(_all_without_vmap_imp, multiple_results=False),
)