Skip to content

Commit eea8579

Browse files
author
zhangyue
committed
feat(ascend): add embedding operator
1 parent cc0bc83 commit eea8579

5 files changed

Lines changed: 258 additions & 12 deletions

File tree

.github/ci_config.yml

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ platforms:
1515
BASE_IMAGE: nvcr.io/nvidia/pytorch:24.10-py3
1616
SKIP_APT: "1"
1717
PIP_INDEX_URL: https://pypi.tuna.tsinghua.edu.cn/simple
18-
setup: pip install .[dev] --no-build-isolation
18+
setup: pip install .[dev] --no-build-isolation --config-settings=cmake.define.AUTO_DETECT_DEVICES=OFF --config-settings=cmake.define.WITH_CPU=ON --config-settings=cmake.define.WITH_NVIDIA=ON
1919
jobs:
2020
gpu:
2121
type: unittest
@@ -50,22 +50,21 @@ platforms:
5050
- /lib/firmware:/lib/firmware
5151
- /usr/src:/usr/src
5252
- /lib/modules:/lib/modules
53-
setup: python -m pip install packaging exceptiongroup typing-extensions pygments pybind11 libclang && python -m pip install . --no-build-isolation --no-deps
53+
setup: python -m pip install packaging exceptiongroup typing-extensions pygments pybind11 libclang && python -m pip install . --no-build-isolation --no-deps --config-settings=cmake.define.AUTO_DETECT_DEVICES=OFF --config-settings=cmake.define.WITH_CPU=ON --config-settings=cmake.define.WITH_ILUVATAR=ON
5454
jobs:
5555
gpu:
5656
type: unittest
5757
resources:
58-
gpu_ids: "0"
5958
ngpus: 1
6059
gpu_style: none
6160
memory: 32GB
6261
shm_size: 16g
63-
timeout: 7200
64-
queue_timeout: 7200
62+
timeout: 14400
63+
queue_timeout: 14400
6564
junit_path: test-results.xml
6665
stages:
6766
- name: test
68-
run: pytest tests/ --devices iluvatar -n 4 -v --tb=short --junitxml=/workspace/results/test-results.xml
67+
run: pytest tests/ --devices iluvatar -n 2 -v --tb=short --junitxml=/workspace/results/test-results.xml
6968

7069
metax:
7170
runner_label: Metax
@@ -80,7 +79,7 @@ platforms:
8079
- "--privileged"
8180
- "--ulimit=memlock=-1"
8281
- "--ulimit=stack=67108864"
83-
setup: pip install .[dev] --no-build-isolation
82+
setup: pip install .[dev] --no-build-isolation --config-settings=cmake.define.AUTO_DETECT_DEVICES=OFF --config-settings=cmake.define.WITH_CPU=ON --config-settings=cmake.define.WITH_METAX=ON
8483
jobs:
8584
gpu:
8685
type: unittest
@@ -107,7 +106,7 @@ platforms:
107106
PIP_INDEX_URL: https://pypi.org/simple
108107
docker_args:
109108
- "--privileged"
110-
setup: pip install .[dev] --no-build-isolation
109+
setup: pip install .[dev] --no-build-isolation --config-settings=cmake.define.AUTO_DETECT_DEVICES=OFF --config-settings=cmake.define.WITH_CPU=ON --config-settings=cmake.define.WITH_MOORE=ON
111110
jobs:
112111
gpu:
113112
type: unittest
@@ -133,7 +132,7 @@ platforms:
133132
PIP_INDEX_URL: https://pypi.org/simple
134133
docker_args:
135134
- "--privileged"
136-
setup: pip install .[dev] --no-build-isolation
135+
setup: pip install .[dev] --no-build-isolation --config-settings=cmake.define.AUTO_DETECT_DEVICES=OFF --config-settings=cmake.define.WITH_CPU=ON --config-settings=cmake.define.WITH_CAMBRICON=ON
137136
jobs:
138137
gpu:
139138
type: unittest
@@ -168,7 +167,7 @@ platforms:
168167
- "--group-add=video"
169168
volumes:
170169
- /opt/hyhal:/opt/hyhal:ro
171-
setup: pip install .[dev] --no-build-isolation
170+
setup: pip install .[dev] --no-build-isolation --config-settings=cmake.define.AUTO_DETECT_DEVICES=OFF --config-settings=cmake.define.WITH_CPU=ON --config-settings=cmake.define.WITH_HYGON=ON
172171
jobs:
173172
gpu:
174173
type: unittest
@@ -205,7 +204,7 @@ platforms:
205204
- /usr/local/bin/npu-smi:/usr/local/bin/npu-smi:ro
206205
env:
207206
ASCEND_HOME_PATH: /usr/local/Ascend/ascend-toolkit/latest
208-
setup: pip install .[dev] --no-build-isolation
207+
setup: pip install .[dev] --no-build-isolation --config-settings=cmake.define.AUTO_DETECT_DEVICES=OFF --config-settings=cmake.define.WITH_CPU=ON --config-settings=cmake.define.WITH_ASCEND=ON
209208
jobs:
210209
npu:
211210
type: unittest

src/CMakeLists.txt

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -542,10 +542,39 @@ if(GENERATE_OPERATOR_CALL_INSTANTIATIONS)
542542
file(GLOB_RECURSE OPERATOR_CALL_INSTANTIATION_SOURCES CONFIGURE_DEPENDS
543543
"${PROJECT_SOURCE_DIR}/generated/src/operator_call_instantiations_*.cc")
544544

545+
set(_operator_call_instantiation_job_pool_arg)
546+
if(WITH_TORCH AND CMAKE_GENERATOR MATCHES "Ninja")
547+
set(INFINIOPS_OPERATOR_CALL_INSTANTIATION_COMPILE_JOBS "2" CACHE STRING
548+
"Maximum concurrent generated operator call instantiation compilations")
549+
set_property(GLOBAL APPEND PROPERTY JOB_POOLS
550+
operator_call_instantiation_compile=${INFINIOPS_OPERATOR_CALL_INSTANTIATION_COMPILE_JOBS})
551+
set(_operator_call_instantiation_job_pool_arg
552+
JOB_POOL operator_call_instantiation_compile)
553+
endif()
554+
545555
if(WITH_NVIDIA OR WITH_HYGON)
546556
set_source_files_properties(${OPERATOR_CALL_INSTANTIATION_SOURCES}
547557
PROPERTIES LANGUAGE CUDA)
548-
target_sources(infiniops PRIVATE ${OPERATOR_CALL_INSTANTIATION_SOURCES})
558+
if(WITH_TORCH AND CMAKE_GENERATOR MATCHES "Ninja")
559+
add_library(infiniops_operator_call_instantiation_objs OBJECT
560+
${OPERATOR_CALL_INSTANTIATION_SOURCES})
561+
set_target_properties(infiniops_operator_call_instantiation_objs
562+
PROPERTIES
563+
CUDA_STANDARD 17
564+
CUDA_STANDARD_REQUIRED ON
565+
JOB_POOL_COMPILE operator_call_instantiation_compile
566+
POSITION_INDEPENDENT_CODE ON)
567+
target_include_directories(infiniops_operator_call_instantiation_objs PRIVATE
568+
$<TARGET_PROPERTY:infiniops,INCLUDE_DIRECTORIES>)
569+
target_compile_definitions(infiniops_operator_call_instantiation_objs PRIVATE
570+
$<TARGET_PROPERTY:infiniops,COMPILE_DEFINITIONS>)
571+
target_compile_options(infiniops_operator_call_instantiation_objs PRIVATE
572+
$<TARGET_PROPERTY:infiniops,COMPILE_OPTIONS>)
573+
target_sources(infiniops PRIVATE
574+
$<TARGET_OBJECTS:infiniops_operator_call_instantiation_objs>)
575+
else()
576+
target_sources(infiniops PRIVATE ${OPERATOR_CALL_INSTANTIATION_SOURCES})
577+
endif()
549578
elseif(WITH_ILUVATAR)
550579
set(_iluvatar_call_instantiation_include_flags
551580
"-I${CMAKE_CURRENT_SOURCE_DIR}"
@@ -591,6 +620,7 @@ if(GENERATE_OPERATOR_CALL_INSTANTIATIONS)
591620
-c "${_src}" -o "${_obj}"
592621
DEPENDS "${_src}"
593622
${_depfile_arg}
623+
${_operator_call_instantiation_job_pool_arg}
594624
COMMENT "Compiling ${_name}.cc with CoreX clang++"
595625
VERBATIM
596626
)

src/base/embedding.h

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
#ifndef INFINI_OPS_BASE_EMBEDDING_H_
2+
#define INFINI_OPS_BASE_EMBEDDING_H_
3+
4+
#include <cassert>
5+
6+
#include "data_type.h"
7+
#include "operator.h"
8+
9+
namespace infini::ops {
10+
11+
// Embedding performs a token embedding lookup.
12+
//
13+
// Interface follows the inference-time vLLM/PyTorch convention:
14+
// `out = weight[input_ids]`.
15+
//
16+
// The input layout is:
17+
// `input_ids`: Any shape, `int32` or `int64`.
18+
// `weight`: `[vocab_size, hidden_size]`.
19+
// `out`: `input_ids.shape + [hidden_size]`.
20+
//
21+
// This is the inference subset of `torch.nn.functional.embedding`; options
22+
// such as `padding_idx`, `max_norm`, `scale_grad_by_freq`, and `sparse` are
23+
// intentionally not part of this operator.
24+
class Embedding : public Operator<Embedding> {
25+
public:
26+
Embedding(const Tensor input_ids, const Tensor weight, Tensor out)
27+
: num_tokens_{input_ids.numel()},
28+
vocab_size_{weight.size(0)},
29+
hidden_size_{weight.size(1)},
30+
input_dtype_{input_ids.dtype()},
31+
weight_dtype_{weight.dtype()} {
32+
assert((input_dtype_ == DataType::kInt32 ||
33+
input_dtype_ == DataType::kInt64) &&
34+
"`Embedding` requires `input_ids` to be `int32` or `int64`");
35+
assert(
36+
weight.ndim() == 2 &&
37+
"`Embedding` requires `weight` to be 2D `[vocab_size, hidden_size]`");
38+
assert(out.dtype() == weight.dtype() &&
39+
"`Embedding` requires `out` and `weight` to have the same dtype");
40+
assert(out.ndim() == input_ids.ndim() + 1 &&
41+
"`Embedding` requires `out.ndim == input_ids.ndim + 1`");
42+
assert(out.size(-1) == hidden_size_ &&
43+
"`Embedding` requires `out.shape[-1] == weight.shape[-1]`");
44+
45+
for (std::size_t i = 0; i < input_ids.ndim(); ++i) {
46+
assert(out.size(i) == input_ids.size(i) &&
47+
"`Embedding` requires `out` prefix shape to match `input_ids`");
48+
}
49+
}
50+
51+
virtual void operator()(const Tensor input_ids, const Tensor weight,
52+
Tensor out) const = 0;
53+
54+
protected:
55+
Tensor::Size num_tokens_{0};
56+
57+
Tensor::Size vocab_size_{0};
58+
59+
Tensor::Size hidden_size_{0};
60+
61+
const DataType input_dtype_;
62+
63+
const DataType weight_dtype_;
64+
};
65+
66+
} // namespace infini::ops
67+
68+
#endif // INFINI_OPS_BASE_EMBEDDING_H_
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
#ifndef INFINI_OPS_ASCEND_EMBEDDING_KERNEL_H_
2+
#define INFINI_OPS_ASCEND_EMBEDDING_KERNEL_H_
3+
4+
#include <cassert>
5+
6+
#include "acl/acl.h"
7+
#include "aclnn/aclnn_base.h"
8+
#include "aclnnop/aclnn_embedding.h"
9+
#include "base/embedding.h"
10+
#include "native/ascend/common.h"
11+
#include "native/ascend/workspace_pool_.h"
12+
#include "operator.h"
13+
14+
namespace infini::ops {
15+
16+
template <>
17+
class Operator<Embedding, Device::Type::kAscend> : public Embedding {
18+
public:
19+
Operator(const Tensor input_ids, const Tensor weight, Tensor out)
20+
: Embedding(input_ids, weight, out),
21+
input_ids_cache_(input_ids),
22+
weight_cache_(weight),
23+
out_cache_(out) {
24+
assert((weight_dtype_ == DataType::kFloat16 ||
25+
weight_dtype_ == DataType::kBFloat16 ||
26+
weight_dtype_ == DataType::kFloat32) &&
27+
"`Embedding`: Ascend path supports `float16`, `bfloat16`, and "
28+
"`float32` weights");
29+
}
30+
31+
~Operator() {
32+
if (!ascend::IsAclRuntimeAlive()) return;
33+
34+
input_ids_cache_.release();
35+
weight_cache_.release();
36+
out_cache_.release();
37+
}
38+
39+
void operator()(const Tensor input_ids, const Tensor weight,
40+
Tensor out) const override {
41+
auto stream = static_cast<aclrtStream>(stream_);
42+
43+
auto t_weight = weight_cache_.get(const_cast<void*>(weight.data()));
44+
auto t_input_ids =
45+
input_ids_cache_.get(const_cast<void*>(input_ids.data()));
46+
auto t_out = out_cache_.get(out.data());
47+
48+
if (!executor_) {
49+
auto ret = aclnnEmbeddingGetWorkspaceSize(t_weight, t_input_ids, t_out,
50+
&ws_size_, &executor_);
51+
assert(ret == ACL_SUCCESS && "`aclnnEmbeddingGetWorkspaceSize` failed");
52+
aclSetAclOpExecutorRepeatable(executor_);
53+
} else {
54+
aclSetInputTensorAddr(executor_, 0, t_weight,
55+
const_cast<void*>(weight.data()));
56+
aclSetInputTensorAddr(executor_, 1, t_input_ids,
57+
const_cast<void*>(input_ids.data()));
58+
aclSetOutputTensorAddr(executor_, 0, t_out, out.data());
59+
}
60+
61+
auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size_);
62+
auto ret = aclnnEmbedding(arena.buf, ws_size_, executor_, stream);
63+
assert(ret == ACL_SUCCESS && "`aclnnEmbedding` failed");
64+
}
65+
66+
private:
67+
mutable ascend::AclTensorCache input_ids_cache_;
68+
69+
mutable ascend::AclTensorCache weight_cache_;
70+
71+
mutable ascend::AclTensorCache out_cache_;
72+
73+
mutable aclOpExecutor* executor_ = nullptr;
74+
75+
mutable uint64_t ws_size_ = 0;
76+
};
77+
78+
} // namespace infini::ops
79+
80+
#endif // INFINI_OPS_ASCEND_EMBEDDING_KERNEL_H_

tests/test_embedding.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import infini.ops
2+
import pytest
3+
import torch
4+
5+
from tests.utils import Payload, get_stream
6+
7+
8+
@pytest.mark.auto_act_and_assert
9+
@pytest.mark.parametrize(
10+
"input_shape, vocab_size, hidden_size",
11+
(
12+
((5,), 17, 8),
13+
((2, 3), 23, 16),
14+
),
15+
)
16+
@pytest.mark.parametrize("index_dtype", (torch.int32, torch.int64))
17+
@pytest.mark.parametrize(
18+
("dtype", "rtol", "atol"),
19+
(
20+
(torch.float32, 0.0, 0.0),
21+
(torch.float16, 0.0, 0.0),
22+
(torch.bfloat16, 0.0, 0.0),
23+
),
24+
)
25+
def test_embedding(
26+
input_shape,
27+
vocab_size,
28+
hidden_size,
29+
index_dtype,
30+
implementation_index,
31+
dtype,
32+
device,
33+
rtol,
34+
atol,
35+
):
36+
input_ids = torch.randint(
37+
0, vocab_size, input_shape, dtype=index_dtype, device=device
38+
)
39+
weight = torch.randn((vocab_size, hidden_size), dtype=dtype, device=device)
40+
out = torch.empty((*input_shape, hidden_size), dtype=dtype, device=device)
41+
42+
return Payload(
43+
lambda *args, **kwargs: _embedding(
44+
*args, **kwargs, implementation_index=implementation_index
45+
),
46+
_ref_embedding,
47+
(input_ids, weight, out),
48+
{},
49+
rtol=rtol,
50+
atol=atol,
51+
)
52+
53+
54+
def _embedding(input_ids, weight, out, *, implementation_index=0):
55+
infini.ops.embedding(
56+
input_ids,
57+
weight,
58+
out,
59+
implementation_index=implementation_index,
60+
stream=get_stream(input_ids.device),
61+
)
62+
63+
return out
64+
65+
66+
def _ref_embedding(input_ids, weight, out):
67+
del out
68+
69+
return torch.nn.functional.embedding(input_ids.long(), weight)

0 commit comments

Comments
 (0)