Skip to content

Commit 5494a67

Browse files
author
zhangyue
committed
feat(ascend): add embedding operator
1 parent 76094ad commit 5494a67

3 files changed

Lines changed: 217 additions & 0 deletions

File tree

src/base/embedding.h

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
#ifndef INFINI_OPS_BASE_EMBEDDING_H_
2+
#define INFINI_OPS_BASE_EMBEDDING_H_
3+
4+
#include <cassert>
5+
6+
#include "data_type.h"
7+
#include "operator.h"
8+
9+
namespace infini::ops {
10+
11+
// Embedding performs a token embedding lookup.
12+
//
13+
// Interface follows the inference-time vLLM/PyTorch convention:
14+
// `out = weight[input_ids]`.
15+
//
16+
// The input layout is:
17+
// `input_ids`: Any shape, `int32` or `int64`.
18+
// `weight`: `[vocab_size, hidden_size]`.
19+
// `out`: `input_ids.shape + [hidden_size]`.
20+
//
21+
// This is the inference subset of `torch.nn.functional.embedding`; options
22+
// such as `padding_idx`, `max_norm`, `scale_grad_by_freq`, and `sparse` are
23+
// intentionally not part of this operator.
24+
class Embedding : public Operator<Embedding> {
25+
public:
26+
Embedding(const Tensor input_ids, const Tensor weight, Tensor out)
27+
: num_tokens_{input_ids.numel()},
28+
vocab_size_{weight.size(0)},
29+
hidden_size_{weight.size(1)},
30+
input_dtype_{input_ids.dtype()},
31+
weight_dtype_{weight.dtype()} {
32+
assert((input_dtype_ == DataType::kInt32 ||
33+
input_dtype_ == DataType::kInt64) &&
34+
"`Embedding` requires `input_ids` to be `int32` or `int64`");
35+
assert(
36+
weight.ndim() == 2 &&
37+
"`Embedding` requires `weight` to be 2D `[vocab_size, hidden_size]`");
38+
assert(out.dtype() == weight.dtype() &&
39+
"`Embedding` requires `out` and `weight` to have the same dtype");
40+
assert(out.ndim() == input_ids.ndim() + 1 &&
41+
"`Embedding` requires `out.ndim == input_ids.ndim + 1`");
42+
assert(out.size(-1) == hidden_size_ &&
43+
"`Embedding` requires `out.shape[-1] == weight.shape[-1]`");
44+
45+
for (std::size_t i = 0; i < input_ids.ndim(); ++i) {
46+
assert(out.size(i) == input_ids.size(i) &&
47+
"`Embedding` requires `out` prefix shape to match `input_ids`");
48+
}
49+
}
50+
51+
virtual void operator()(const Tensor input_ids, const Tensor weight,
52+
Tensor out) const = 0;
53+
54+
protected:
55+
Tensor::Size num_tokens_{0};
56+
57+
Tensor::Size vocab_size_{0};
58+
59+
Tensor::Size hidden_size_{0};
60+
61+
const DataType input_dtype_;
62+
63+
const DataType weight_dtype_;
64+
};
65+
66+
} // namespace infini::ops
67+
68+
#endif // INFINI_OPS_BASE_EMBEDDING_H_
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
#ifndef INFINI_OPS_ASCEND_EMBEDDING_KERNEL_H_
2+
#define INFINI_OPS_ASCEND_EMBEDDING_KERNEL_H_
3+
4+
#include <cassert>
5+
6+
#include "acl/acl.h"
7+
#include "aclnn/aclnn_base.h"
8+
#include "aclnnop/aclnn_embedding.h"
9+
#include "base/embedding.h"
10+
#include "native/ascend/common.h"
11+
#include "native/ascend/workspace_pool_.h"
12+
#include "operator.h"
13+
14+
namespace infini::ops {
15+
16+
template <>
17+
class Operator<Embedding, Device::Type::kAscend> : public Embedding {
18+
public:
19+
Operator(const Tensor input_ids, const Tensor weight, Tensor out)
20+
: Embedding(input_ids, weight, out),
21+
input_ids_cache_(input_ids),
22+
weight_cache_(weight),
23+
out_cache_(out) {
24+
assert((weight_dtype_ == DataType::kFloat16 ||
25+
weight_dtype_ == DataType::kBFloat16 ||
26+
weight_dtype_ == DataType::kFloat32) &&
27+
"`Embedding`: Ascend path supports `float16`, `bfloat16`, and "
28+
"`float32` weights");
29+
}
30+
31+
~Operator() {
32+
if (!ascend::IsAclRuntimeAlive()) return;
33+
34+
input_ids_cache_.release();
35+
weight_cache_.release();
36+
out_cache_.release();
37+
}
38+
39+
void operator()(const Tensor input_ids, const Tensor weight,
40+
Tensor out) const override {
41+
auto stream = static_cast<aclrtStream>(stream_);
42+
43+
auto t_weight = weight_cache_.get(const_cast<void*>(weight.data()));
44+
auto t_input_ids =
45+
input_ids_cache_.get(const_cast<void*>(input_ids.data()));
46+
auto t_out = out_cache_.get(out.data());
47+
48+
if (!executor_) {
49+
auto ret = aclnnEmbeddingGetWorkspaceSize(t_weight, t_input_ids, t_out,
50+
&ws_size_, &executor_);
51+
assert(ret == ACL_SUCCESS && "`aclnnEmbeddingGetWorkspaceSize` failed");
52+
aclSetAclOpExecutorRepeatable(executor_);
53+
} else {
54+
aclSetInputTensorAddr(executor_, 0, t_weight,
55+
const_cast<void*>(weight.data()));
56+
aclSetInputTensorAddr(executor_, 1, t_input_ids,
57+
const_cast<void*>(input_ids.data()));
58+
aclSetOutputTensorAddr(executor_, 0, t_out, out.data());
59+
}
60+
61+
auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size_);
62+
auto ret = aclnnEmbedding(arena.buf, ws_size_, executor_, stream);
63+
assert(ret == ACL_SUCCESS && "`aclnnEmbedding` failed");
64+
}
65+
66+
private:
67+
mutable ascend::AclTensorCache input_ids_cache_;
68+
69+
mutable ascend::AclTensorCache weight_cache_;
70+
71+
mutable ascend::AclTensorCache out_cache_;
72+
73+
mutable aclOpExecutor* executor_ = nullptr;
74+
75+
mutable uint64_t ws_size_ = 0;
76+
};
77+
78+
} // namespace infini::ops
79+
80+
#endif // INFINI_OPS_ASCEND_EMBEDDING_KERNEL_H_

tests/test_embedding.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import infini.ops
2+
import pytest
3+
import torch
4+
5+
from tests.utils import Payload, get_stream
6+
7+
8+
@pytest.mark.auto_act_and_assert
9+
@pytest.mark.parametrize(
10+
"input_shape, vocab_size, hidden_size",
11+
(
12+
((5,), 17, 8),
13+
((2, 3), 23, 16),
14+
),
15+
)
16+
@pytest.mark.parametrize("index_dtype", (torch.int32, torch.int64))
17+
@pytest.mark.parametrize(
18+
("dtype", "rtol", "atol"),
19+
(
20+
(torch.float32, 0.0, 0.0),
21+
(torch.float16, 0.0, 0.0),
22+
(torch.bfloat16, 0.0, 0.0),
23+
),
24+
)
25+
def test_embedding(
26+
input_shape,
27+
vocab_size,
28+
hidden_size,
29+
index_dtype,
30+
implementation_index,
31+
dtype,
32+
device,
33+
rtol,
34+
atol,
35+
):
36+
input_ids = torch.randint(
37+
0, vocab_size, input_shape, dtype=index_dtype, device=device
38+
)
39+
weight = torch.randn((vocab_size, hidden_size), dtype=dtype, device=device)
40+
out = torch.empty((*input_shape, hidden_size), dtype=dtype, device=device)
41+
42+
return Payload(
43+
lambda *args, **kwargs: _embedding(
44+
*args, **kwargs, implementation_index=implementation_index
45+
),
46+
_ref_embedding,
47+
(input_ids, weight, out),
48+
{},
49+
rtol=rtol,
50+
atol=atol,
51+
)
52+
53+
54+
def _embedding(input_ids, weight, out, *, implementation_index=0):
55+
infini.ops.embedding(
56+
input_ids,
57+
weight,
58+
out,
59+
implementation_index=implementation_index,
60+
stream=get_stream(input_ids.device),
61+
)
62+
63+
return out
64+
65+
66+
def _ref_embedding(input_ids, weight, out):
67+
del out
68+
69+
return torch.nn.functional.embedding(input_ids.long(), weight)

0 commit comments

Comments
 (0)