Skip to content

Commit 8eacfef

Browse files
author
zhangyue
committed
fix(ascend): adopt PR #63/#60 master API — GetWorkspacePool/Ensure rename + drop registry.h (SFINAE autodetect)
1 parent f45f9da commit 8eacfef

34 files changed

Lines changed: 47 additions & 242 deletions

File tree

src/ascend/add/kernel.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ class Operator<Add, Device::Type::kAscend> : public Add {
6565
aclSetOutputTensorAddr(executor_, 0, t_out, out.data());
6666
}
6767

68-
auto& arena = ascend::workspacePool().ensure(stream, ws_size_);
68+
auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size_);
6969
aclnnAdd(arena.buf, ws_size_, executor_, stream);
7070
}
7171

src/ascend/add_rms_norm/kernel.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
#include "aclnn/aclnn_base.h"
88
#include "aclnn_add.h"
99
#include "aclnn_rms_norm.h"
10-
#include "ascend/add_rms_norm/registry.h"
1110
#include "ascend/common.h"
1211
#include "ascend/workspace_pool_.h"
12+
#include "base/add_rms_norm.h"
1313
#include "operator.h"
1414

1515
namespace infini::ops {
@@ -74,12 +74,12 @@ class Operator<AddRmsNorm, Device::Type::kAscend, 0> : public AddRmsNorm {
7474
aclSetInputTensorAddr(add_exec_, 1, t_x2, const_cast<void*>(x2.data()));
7575
aclSetOutputTensorAddr(add_exec_, 0, t_x_out, x_out.data());
7676
}
77-
auto& add_arena = ascend::workspacePool().ensure(stream, add_ws_);
77+
auto& add_arena = ascend::GetWorkspacePool().Ensure(stream, add_ws_);
7878
aclnnAdd(add_arena.buf, add_ws_, add_exec_, stream);
7979

8080
// Obtain shared rstd buffer from pool.
8181
auto& rstd_arena =
82-
ascend::workspacePool().ensure(stream, rstd_size_, "temp");
82+
ascend::GetWorkspacePool().Ensure(stream, rstd_size_, "temp");
8383

8484
// Lazily create rstd tensor descriptor on first call.
8585
if (!rstd_tensor_) {
@@ -102,7 +102,7 @@ class Operator<AddRmsNorm, Device::Type::kAscend, 0> : public AddRmsNorm {
102102
aclSetOutputTensorAddr(norm_exec_, 0, t_y_out, y_out.data());
103103
aclSetOutputTensorAddr(norm_exec_, 1, rstd_tensor_, rstd_arena.buf);
104104
}
105-
auto& norm_arena = ascend::workspacePool().ensure(stream, norm_ws_);
105+
auto& norm_arena = ascend::GetWorkspacePool().Ensure(stream, norm_ws_);
106106
aclnnRmsNorm(norm_arena.buf, norm_ws_, norm_exec_, stream);
107107
}
108108

src/ascend/add_rms_norm/kernel_custom.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
#include "acl/acl.h"
1111
#include "aclnn/aclnn_base.h"
1212
#include "aclnnop/aclnn_cast.h"
13-
#include "ascend/add_rms_norm/registry.h"
1413
#include "ascend/common.h"
1514
#include "ascend/workspace_pool_.h"
1615
#include "base/add_rms_norm.h"
@@ -121,7 +120,7 @@ class Operator<AddRmsNorm, Device::Type::kAscend, 2> : public AddRmsNorm {
121120
aclSetOutputTensorAddr(cast_exec_, 0, t_dst, weight_fp32_data_);
122121
}
123122

124-
auto& arena = ascend::workspacePool().ensure(stream, cast_ws_);
123+
auto& arena = ascend::GetWorkspacePool().Ensure(stream, cast_ws_);
125124
aclnnCast(arena.buf, cast_ws_, cast_exec_, stream);
126125
last_weight_ptr_ = cur_weight;
127126
}

src/ascend/add_rms_norm/kernel_fused.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
#include "acl/acl.h"
77
#include "aclnn/aclnn_base.h"
88
#include "aclnnop/aclnn_add_rms_norm.h"
9-
#include "ascend/add_rms_norm/registry.h"
109
#include "ascend/common.h"
1110
#include "ascend/workspace_pool_.h"
11+
#include "base/add_rms_norm.h"
1212
#include "operator.h"
1313

1414
namespace infini::ops {
@@ -98,7 +98,7 @@ class Operator<AddRmsNorm, Device::Type::kAscend, 1> : public AddRmsNorm {
9898
aclSetOutputTensorAddr(executor_, 2, t_x_out, x_out.data());
9999
}
100100

101-
auto& arena = ascend::workspacePool().ensure(stream, ws_size_);
101+
auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size_);
102102
aclnnAddRmsNorm(arena.buf, ws_size_, executor_, stream);
103103
}
104104

src/ascend/add_rms_norm/registry.h

Lines changed: 0 additions & 19 deletions
This file was deleted.

src/ascend/apply_rotary_pos_emb/kernel.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ class Operator<ApplyRotaryPosEmb, Device::Type::kAscend>
117117
aclSetInputTensorAddr(v2_exec_, 3, t_sin, const_cast<void*>(sin.data()));
118118
}
119119

120-
auto& arena = ascend::workspacePool().ensure(stream, v2_ws_);
120+
auto& arena = ascend::GetWorkspacePool().Ensure(stream, v2_ws_);
121121
auto exec_ret =
122122
aclnnApplyRotaryPosEmbV2(arena.buf, v2_ws_, v2_exec_, stream);
123123
assert(exec_ret == 0 && "aclnnApplyRotaryPosEmbV2 failed");

src/ascend/apply_rotary_pos_emb/kernel_atb.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
#include <vector>
1010

1111
#include "acl/acl.h"
12-
#include "ascend/apply_rotary_pos_emb/registry.h"
1312
#include "ascend/atb_common_.h"
1413
#include "ascend/common.h"
1514
#include "ascend/workspace_pool_.h"
@@ -141,7 +140,7 @@ class Operator<ApplyRotaryPosEmb, Device::Type::kAscend, 1>
141140
uint8_t* ws_ptr = nullptr;
142141

143142
if (ws_size > 0) {
144-
auto& arena = ascend::workspacePool().ensure(stream, ws_size);
143+
auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size);
145144
ws_ptr = static_cast<uint8_t*>(arena.buf);
146145
}
147146

src/ascend/apply_rotary_pos_emb/registry.h

Lines changed: 0 additions & 21 deletions
This file was deleted.

src/ascend/cast/kernel.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class Operator<Cast, Device::Type::kAscend> : public Cast {
4343
aclSetOutputTensorAddr(executor_, 0, t_out, out.data());
4444
}
4545

46-
auto& arena = ascend::workspacePool().ensure(stream, ws_size_);
46+
auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size_);
4747
aclnnCast(arena.buf, ws_size_, executor_, stream);
4848
}
4949

src/ascend/cat/kernel.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ class Operator<Cat, Device::Type::kAscend> : public Cat {
7777
aclSetOutputTensorAddr(executor_, 0, t_out, out.data());
7878
}
7979

80-
auto& arena = ascend::workspacePool().ensure(stream, ws_size_);
80+
auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size_);
8181
aclnnCat(arena.buf, ws_size_, executor_, stream);
8282
}
8383

0 commit comments

Comments
 (0)