# Copyright 2024 RecML authors <recommendations-ml@google.com>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Embedding lookup ops."""
from collections.abc import Mapping, Sequence
import dataclasses
import functools
from typing import Any, TypeVar

from etils import epy
import jax
from jax.experimental import shard_map

with epy.lazy_imports():
  # pylint: disable=g-import-not-at-top
  from jax_tpu_embedding.sparsecore.lib.nn import embedding
  # pylint: enable=g-import-not-at-top


T = TypeVar("T")
Nested = T | Sequence[T] | Mapping[str, T]
FeatureSpec = Any


@dataclasses.dataclass
class SparsecoreParams:
  """Embedding parameters."""

  feature_specs: Nested[FeatureSpec]
  mesh: jax.sharding.Mesh | jax.sharding.AbstractMesh
  data_axes: Sequence[str | None]
  embedding_axes: Sequence[str | None]
  sharding_strategy: str
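

# A minimal construction sketch (illustrative only): the mesh, the axis
# names, and `feature_specs` below are hypothetical placeholders rather than
# values defined by this module; real feature specs come from the
# jax_tpu_embedding sparsecore library.
#
#   mesh = jax.sharding.Mesh(np.asarray(jax.devices()), ("data",))
#   params = SparsecoreParams(
#       feature_specs=feature_specs,  # Nested FeatureSpec pytree.
#       mesh=mesh,
#       data_axes=("data",),          # Inputs/activations sharded over batch.
#       embedding_axes=(None,),       # Hypothetical table partitioning.
#       sharding_strategy="MOD",      # Assumed row-sharding strategy name.
#   )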


@functools.partial(jax.custom_vjp, nondiff_argnums=(0,))
def sparsecore_lookup(
    sparsecore_params: SparsecoreParams,
    tables: Mapping[str, tuple[jax.Array, ...]],
    csr_inputs: tuple[jax.Array, ...],
):
  """Looks up embedding activations for the given CSR inputs."""
  return shard_map.shard_map(
      functools.partial(
          embedding.tpu_sparse_dense_matmul,
          global_device_count=sparsecore_params.mesh.size,
          feature_specs=sparsecore_params.feature_specs,
          sharding_strategy=sparsecore_params.sharding_strategy,
      ),
      mesh=sparsecore_params.mesh,
      in_specs=(
          jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
          jax.sharding.PartitionSpec(*sparsecore_params.embedding_axes),
      ),
      out_specs=jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
      check_rep=False,
  )(csr_inputs, tables)


def _emb_lookup_fwd(
    sparsecore_params: SparsecoreParams,
    tables: Mapping[str, tuple[jax.Array, ...]],
    csr_inputs: tuple[jax.Array, ...],
):
  """Forward pass for embedding lookup."""
  out = sparsecore_lookup(sparsecore_params, tables, csr_inputs)
  return out, (tables, csr_inputs)


def _emb_lookup_bwd(
    sparsecore_params: SparsecoreParams,
    res: tuple[Mapping[str, tuple[jax.Array, ...]], tuple[jax.Array, ...]],
    gradients: Nested[jax.Array],
) -> tuple[Nested[jax.Array], None]:
  """Backward pass for embedding lookup."""
  (tables, csr_inputs) = res
  emb_table_grads = shard_map.shard_map(
      functools.partial(
          embedding.tpu_sparse_dense_matmul_grad,
          feature_specs=sparsecore_params.feature_specs,
          sharding_strategy=sparsecore_params.sharding_strategy,
      ),
      mesh=sparsecore_params.mesh,
      in_specs=(
          jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
          jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
          jax.sharding.PartitionSpec(*sparsecore_params.embedding_axes),
      ),
      out_specs=jax.sharding.PartitionSpec(*sparsecore_params.data_axes),
      check_rep=False,
  )(gradients, csr_inputs, tables)

  # `tpu_sparse_dense_matmul_grad` returns a general mapping (usually a dict).
  # It may not be the same type as the embedding table (e.g. FrozenDict).
  # Here we use flatten / unflatten to ensure the types are the same.
  emb_table_grads = jax.tree.unflatten(
      jax.tree.structure(tables), jax.tree.leaves(emb_table_grads)
  )
  return emb_table_grads, None


sparsecore_lookup.defvjp(_emb_lookup_fwd, _emb_lookup_bwd)
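

# Usage sketch (illustrative only; `params`, `tables`, `csr_inputs`, and the
# loss computation are hypothetical placeholders, not part of this module).
# Because of the custom VJP registered above, differentiating through
# `sparsecore_lookup` routes table gradients through
# `tpu_sparse_dense_matmul_grad` and returns None for the non-differentiable
# CSR inputs.
#
#   def loss_fn(tables):
#     activations = sparsecore_lookup(params, tables, csr_inputs)
#     return compute_loss(activations)  # Hypothetical downstream loss.
#
#   table_grads = jax.grad(loss_fn)(tables)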